在 Keras 中進行模型的評估
在我們進行模型訓練的過程之中總是避免不了進行模型的評估,從而判別一個模型是好是壞,進而幫助我們選擇較好的模型進行使用。
那么我們這節(jié)課就來學習一下如何進行模型的評估。這節(jié)課是在之前的 fashion_mnist 數(shù)據(jù)集合實驗的基礎上進行進一步的改進,常用模型的評估方法大致分為兩類:
- 根據(jù)訓練過程中的指標進行評估;
- 訓練結束后在驗證集合或者測試集合上進行測試。
總體而言,第二種方法是采用較多的方法,因為模型訓練的目的不是在已知的數(shù)據(jù)集合上表現(xiàn)良好,而是要在未知的數(shù)據(jù)集合上表現(xiàn)良好。
1. 根據(jù)訓練過程中的指標進行評估
在訓練過程中根據(jù)指標進行評估的時候大致可以分為兩個類別:
- 根據(jù)損失函數(shù)進行評價;
- 根據(jù)普通的指標進行評價。
1· 根據(jù)損失函數(shù)進行評價
根據(jù)損失函數(shù)評價比較簡單,因為損失函數(shù)是所有的訓練過程都需要定義的,而損失函數(shù)也會在訓練的過程之中自動記錄與保存。
對于所有的損失函數(shù)而言,損失函數(shù)越小,表示我們的模型越精確。我們平常一些常見的損失函數(shù)包括:
- MAE:均絕對誤差,用于回歸任務學習的損失函數(shù);直觀地可以理解為誤差的的均值;
- MSE:均方誤差,用于回歸任務學習的損失函數(shù);與 MAE 相似,直觀地可以理解為誤差的平方的均值;
- Binary_CrossEntropy:二元交叉熵,用于二分類學習的損失函數(shù),描述的是標簽和預測值的差距;
- Categorical_CrossEntropy:交叉熵,與二元交叉熵類似,只是用于多分類的任務的損失函數(shù)。
在使用的時候要首先在模型編譯的時候指定損失函數(shù),在后面的訓練過程中 TensorFlow 會幫助我們自動記錄損失函數(shù)的變化。比如以下的示例:
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy')
在訓練的過程之中,我們可以根據(jù)日志的輸出來查看損失函數(shù)的變化:
......
Epoch 2/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.3616
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3256
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3006
......
一般而言損失函數(shù)的值會不斷變小,如果損失函數(shù)變大或者不變,則表示我們的模型出錯,抑或是獲得的數(shù)據(jù)出錯。
2. 根據(jù)普通的指標進行評價
如果要使用普通的評價指標,比如準確率,那么我們需要在模型的編譯過程之中使用 metrics 參數(shù)來設置我們需要追蹤的指標。比如如下例子:
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
在上面的代碼之中,我們規(guī)定了準確率這個指標,那么在訓練的過程之中 TensorFlow 便會幫助我們評價該指標,并將結果在日志中輸出。比如如下輸出:
Epoch 2/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.3616 - accuracy: 0.8679
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3256 - accuracy: 0.8795
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3006 - accuracy: 0.8883
當然我們可以追蹤的指標不止準確率這一個,我們在TensorFlow可以經(jīng)常用到的指標主要有:
- Accuracy:準確率,用于分類任務;
- Mean: 平均值;
- TruePositives:真正例的數(shù)量,用于二分類任務,(真正例:實際類別和預測類別都為正,簡寫 TP );
- TrueNegatives:真負例的數(shù)量,用于二分類任務,(真反例:實際類別和預測類別都為負,簡寫 TN )
- FalsePositives:假正例的數(shù)量,用于二分類任務,(假正例:預測為正,實際為負,簡寫 FP );
- FalseNegatives:假負例的數(shù)量,用于二分類任務,(假反例:預測為負,實際為正,簡寫 FN );
- Precision:精確率,用于二分類任務,Precision = TP/(TP+FP));
- Recall:召回率,用于二分類任務,Recall = TP/(TP+FN);
- AUC:用于二分類任務的一個指標,可以理解為正樣本的預測值大于負樣本的概率;
- MSE:均方誤差,用于回歸任務,可作為損失函數(shù);
- MAE:均絕對誤差,用于回歸任務,可以作為損失函數(shù);
- RMSE:均方根誤差,用于回歸任務,可作為損失函數(shù),由MSE開方即可得到;
2. 在驗證集合或者測試集合上進行測試
在測試集合上進行驗證需要我們首先在數(shù)據(jù)集合中預留出一定的測試集合,一般而言,我們會將所有數(shù)據(jù)的 80% 用于訓練集合,而剩下的 20% 用于測試集合。
在測試集合上我們主要是使用 evaluate 方法進行模型的評估:
model.evaluate(x_test, y_test)
其中假設我們的模型已經(jīng)經(jīng)過,而 x_test 與 y_test 分別是測試集合的數(shù)據(jù)和標簽。
我們可以得到如下輸出:
313/313 - 0s - loss: 0.3444
0.34437522292137146
如果我們在模型編譯的過程之中添加了指標,比如準確率:
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
那么我們會得到輸出:
313/313 - 0s - loss: 0.3836 - accuracy: 0.8792
[0.3835919201374054, 0.8791999816894531]
可以看到,evaluate 方法會記錄損失函數(shù)和編譯中指定的指標。而在 evaluate 函數(shù)中返回的指標一般可以作為我們參考的較為可靠的依據(jù)。
3. 小結
在這節(jié)課之中我們學習了如何進行模型的評估,主要包括如何在訓練的過程中進行評估以及如何在訓練結束后進行評估。