在 Keras 中保存與加載模型
無論是在學(xué)習(xí)還是在工作的過程之中,我們都會遇到保存數(shù)據(jù)的情形。
在我們之前的學(xué)習(xí)之中,我們所訓(xùn)練到的模型都沒有經(jīng)過保存,也就是所我們得到的模型的結(jié)構(gòu)和參數(shù)都是存在于內(nèi)存之中的,當(dāng)我們關(guān)閉程序的時候這些模型和參數(shù)都會消失;如果我們想要使用該模型的話就需要再次訓(xùn)練模型。
這顯然是不可取的,因此我們要學(xué)會如何保存模型與加載模型。
1. 定義模型結(jié)構(gòu)
由于我們這節(jié)課的重點(diǎn)在模型的保存,而不是網(wǎng)絡(luò)的結(jié)構(gòu),因此我們使用之前的網(wǎng)絡(luò)結(jié)構(gòu): fashion_mnist 分類的網(wǎng)絡(luò)結(jié)構(gòu)。
具體的網(wǎng)絡(luò)代碼為:
import tensorflow as tf
# 使用內(nèi)置的數(shù)據(jù)集合來加載數(shù)據(jù)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# 預(yù)處理圖片數(shù)據(jù),使其歸一化
x_train, x_test = x_train / 255.0, x_test / 255.0
def get_model():
# 定義網(wǎng)絡(luò)結(jié)構(gòu)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
return model
modle = get_model()
# 編譯模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
在這里我們不僅僅定義了網(wǎng)絡(luò)的基本結(jié)構(gòu),同時也載入了基本的圖片數(shù)據(jù),從而便于后面的訓(xùn)練以及模型保存等操作。
2. 在訓(xùn)練結(jié)束后保存模型參數(shù)、加載模型參數(shù)
我們可以在訓(xùn)練之前直接保存模型參數(shù),但是因?yàn)檫@樣的參數(shù)是未經(jīng)過訓(xùn)練的,因此沒有太有價值的意義,因此我們在保存模型之前要先訓(xùn)練模型。
我們可以通過以下代碼來訓(xùn)練模型:
# 訓(xùn)練模型
model.fit(x_train, y_train, epochs=5)
訓(xùn)練的過程之中我們可以得到如下的輸出:
Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4870 - accuracy: 0.8288
Epoch 2/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.3616 - accuracy: 0.8679
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3256 - accuracy: 0.8795
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3006 - accuracy: 0.8883
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2867 - accuracy: 0.8931
2.1 保存模型參數(shù)
在訓(xùn)練結(jié)束之后我們可以手動進(jìn)行模型參數(shù)的保存:
model.save_weights('./checkpoints/ckpt')
通過這樣的操作,我們便可以將我們模型的參數(shù)保存至當(dāng)前目錄的 “checkpoints” 文件夾下面,并且名命為 ckpt 。
我們可以查看該文件夾下面的文件,可以看到文件夾下面包括三個文件:
79 checkpoint
1.2K checkpoints.index
2.4M checkpoints.data-00000-of-00001
這三個文件之中保存的就是我們的模型的參數(shù)。
2.2 加載模型參數(shù)
如果我們需要加載我們的模型,我們只需要經(jīng)過以下兩步即可:
- 定義網(wǎng)絡(luò)結(jié)構(gòu);
- 按照保存路徑來載入?yún)?shù)。
具體代碼如下:
# 創(chuàng)建模型結(jié)構(gòu)
model = get_model()
# 加載參數(shù)
model.load_weights('./checkpoints/ckpt')
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 評估模型
model.evaluate(x_test, y_test, verbose=2)
我們可以看到模型的輸出為:
313/313 - 0s - loss: 0.3448 - accuracy: 0.8758
[0.34482061862945557, 0.8758000135421753]
說明我們的模型參數(shù)已經(jīng)成功加載。
3. 使用回調(diào)保存模型參數(shù)
前面我們知道了如何在模型訓(xùn)練結(jié)束后保存模型,那么如何讓模型在訓(xùn)練的過程中自動保存模型呢?
那便就需要用到 TensorFlow 的**“回調(diào)函數(shù)”**這個功能,這個功能允許我們定義一系列的事件,并讓其在訓(xùn)練的過程之中執(zhí)行。
在這個例子之中,我們可以讓它在每個 Epoch 結(jié)束的時候保存模型參數(shù)。
于是我們首先定義了模型保存的回調(diào)函數(shù),然后我們又在在 fit 函數(shù)之中使用 callbacks 參數(shù)來將其傳入。
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath='./checkpoints2/ckpt', save_weights_only=True)
model.fit(x_train,
y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[cp_callback])
我們可以看到,在每個 Epoch 結(jié)束后,模型都會進(jìn)行模型參數(shù)的保存:
Epoch 6/10
1868/1875 [============================>.] - ETA: 0s - loss: 0.2704 - accuracy: 0.8985
Epoch 00006: saving model to ./checkpoints2/ckpt
INFO:tensorflow:Assets written to: ./checkpoints2/ckpt/assets
于是我們便可以使得模型能夠自動地保存模型參數(shù)。
cp_callback 中的幾個參數(shù)大家需要注意一下:
- file_path: 與手動保存模型一樣,定義了模型參數(shù)保存的路徑;
- save_weights_only: 是否只保存模型參數(shù),一般而言只保存參數(shù)的文件會比全部保存的文件小很多,因此我們一般只是保存網(wǎng)絡(luò)參數(shù)。
這樣可以避免因?yàn)橐馔馇闆r導(dǎo)致程序意外停止時,前面所有的訓(xùn)練都前功盡棄的情況。因?yàn)槲覀兛梢约虞d最近一次保存的模型繼續(xù)訓(xùn)練。
如果想要加載模型,那么便和手動加載模型一樣即可:
model.load_weights('./checkpoints2/ckpt')
4. 保存模型與保存參數(shù)
前面的保存都是只保存網(wǎng)絡(luò)中的各種參數(shù),而沒有保存網(wǎng)絡(luò)的模型。相比較而言而這主要有以下差別:
- 保存參數(shù)的文件較小,而保存整個模型的文件較大;
- 加載參數(shù)速度較快,而加載整個模型較慢;
- 保存參數(shù)不包含網(wǎng)絡(luò)結(jié)構(gòu),而保存整個模型則包含網(wǎng)絡(luò)的結(jié)構(gòu)。
4.1 在訓(xùn)練結(jié)束后手動保存與加載整個模型
和之前的操作一樣,只是我們需要換一下保存的API函數(shù):
model.save('saved_model/model1')
當(dāng)我們需要加載模型的時候,我們需要使用以下方法來加載模型:
model = tf.keras.models.load_model('saved_model/model1')
4.2 在回調(diào)之中保存整個模型
在回調(diào)之中保存整個模型比較簡單,我們只需要將 save_weights_only 參數(shù)設(shè)置為 False 即可:
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath='./checkpoints3/ckpt', save_weights_only=False)
model.fit(x_train,
y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[cp_callback])
5. 小結(jié)
這節(jié)課之中我們主要學(xué)習(xí)了如何進(jìn)行模型的保存與加載,同時我們又深入了解了保存模型與保存參數(shù)的區(qū)別以及它們具體的實(shí)現(xiàn)方式。