TensorFlow 快速入門示例
因為本課程是以案例為驅(qū)動進(jìn)行框架的應(yīng)用與講解,因此我們這一節(jié)以一個簡單的示例來幫助大家了解 TensorFlow 框架的使用方法以及基本的程序框架與流程。
既然要進(jìn)行模型的構(gòu)建與訓(xùn)練,那么數(shù)據(jù)集便是必不可少的一部分。因為所有的模型都是建立在一定的數(shù)據(jù)集合之上的。這節(jié)課我們便采用一個叫做 fashion_mnist 數(shù)據(jù)集合進(jìn)行模型的構(gòu)建與訓(xùn)練。
1. 什么是 fashion_mnist 數(shù)據(jù)集合
作為機(jī)器學(xué)習(xí)中最基本的數(shù)據(jù)集合之一, fashion_mnist 數(shù)據(jù)集一直是入門者做程序測試的首選的數(shù)據(jù)集,相比較傳統(tǒng)的 mnist 數(shù)據(jù)集而言,fashion_mnist 數(shù)據(jù)集更加豐富,能夠更好的反映網(wǎng)絡(luò)模型的構(gòu)建的效果。
fashion_mnist 數(shù)據(jù)集合是一個包含 70000 個數(shù)據(jù)的數(shù)據(jù)集合,其中60000 條數(shù)據(jù)為訓(xùn)練集合,10000 條數(shù)據(jù)為測試集合;每個數(shù)據(jù)都是 28*28 的灰度圖片數(shù)據(jù),而每個數(shù)據(jù)的標(biāo)簽分為 10 個類別。其中的幾條數(shù)據(jù)具體如下圖所示(圖片來自于 TensorFlow 官方 API 文檔)。
我們要做的就是如何根據(jù)輸入的圖片訓(xùn)練模型,從而使得模型可以根據(jù)輸入的圖片來預(yù)測其屬于哪一個類別。
fashion_mnist 數(shù)據(jù)集合的 10 個類別為:
["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
2. TensorFlow一般程序結(jié)構(gòu)
TensorFlow 一般的程序的結(jié)構(gòu)都是以下的順序:
- 引入所需要的包
- 加載并預(yù)處理數(shù)據(jù)
- 編寫模型結(jié)構(gòu)
- 編譯模型或 Build 模型
- 訓(xùn)練模型與保存模型
- 評估模型
在這個簡單的示例之中我們不會涉及到模型的保存與加載,我們只是帶領(lǐng)大家熟習(xí)一下程序的整體結(jié)構(gòu)即可。
具體的程序代碼為:
import tensorflow as tf
# 使用內(nèi)置的數(shù)據(jù)集合來加載數(shù)據(jù)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# 預(yù)處理圖片數(shù)據(jù),使其歸一化
x_train, x_test = x_train / 255.0, x_test / 255.0
# 定義網(wǎng)絡(luò)結(jié)構(gòu)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
# 編譯模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 訓(xùn)練模型
model.fit(x_train, y_train, epochs=5)
# 評估模型
model.evaluate(x_test, y_test)
接下來讓我們仔細(xì)地看一下這些代碼到底干了什么。
在該程序之中,我們首先使用 tf.keras 中的 datasets 載入了fashion_mnist 數(shù)據(jù)集合,該函數(shù)返回的是兩個元組:
- 第一個元組為(訓(xùn)練數(shù)據(jù)的圖片,訓(xùn)練數(shù)據(jù)的標(biāo)簽)
- 第二個元組為(測試數(shù)據(jù)的圖片,測試數(shù)據(jù)的標(biāo)簽)
因此我們能夠使用兩個元組來接收我們要訓(xùn)練的數(shù)據(jù)集合。
然后我們對圖片進(jìn)行了預(yù)處理:
x_train, x_test = x_train / 255.0, x_test / 255.0
在機(jī)器學(xué)習(xí)之中,我們一般將我們的輸入數(shù)據(jù)規(guī)范到 [0 ,1] 之間,因為這樣會讓模型的訓(xùn)練效果更好。又因為圖片數(shù)據(jù)的每個像素都是 [0, 255] 的整數(shù),因此我們可以將所有的圖片數(shù)據(jù)除以 255 ,從而進(jìn)行歸一化。
接下來我們便構(gòu)建了我們的模型,我們的模型由三層組成:
- Flatten 層,這一層負(fù)責(zé)將二維的圖片數(shù)據(jù)變成一維的數(shù)組數(shù)據(jù),比如我們輸入的圖片數(shù)據(jù)為 28*28 的二維數(shù)組,那么 Flatten 層將會把其變?yōu)殚L度為 784 的一維數(shù)組。
- Dense 層,全連接層,這一層的單元數(shù)為 10 個,分別對應(yīng)著我們的 10 個類別標(biāo)簽,激活函數(shù)為 “softmax” ,表示它會計算每個類別的可能性,從而取可能性最大的類別作為輸出的結(jié)果。
然后我們便進(jìn)行了模型的編譯工作,在編譯的過程中我們有以下幾點需要注意:
- 優(yōu)化器的選擇,優(yōu)化器代表著如何對網(wǎng)絡(luò)中的參數(shù)進(jìn)行優(yōu)化,這里采用的是 “adam” 優(yōu)化器,也是一種最普遍的優(yōu)化器。
- 損失函數(shù),損失函數(shù)意味著我們?nèi)绾螌Α?strong>模型判斷錯誤的懲罰”的衡量方式;換句話說,我們可以暫且理解成損失函數(shù)表示“模型判斷出錯的程度”。對于這種分類的問題,我們一般采用的是 “sparse_categorical_crossentropy” 交叉熵來衡量。
- Metrics,表示我們在訓(xùn)練的過程中要記錄的數(shù)據(jù),在這里我們記錄了 “accuracy” ,也就是準(zhǔn)確率。
再者我們進(jìn)行模型的訓(xùn)練,我們使用我們預(yù)先加載好的數(shù)據(jù)進(jìn)行模型的訓(xùn)練,在這里我們設(shè)置訓(xùn)練的循環(huán)數(shù)( epoch )為 5,表示我們會在數(shù)據(jù)集上循環(huán) 5 次。
最后我們進(jìn)行模型的評估,我們使用 x_test, y_test 對我們的模型進(jìn)行相應(yīng)的評估。
3. 程序的輸出
通過運行上面的程序,我們可以得到下面的輸出:
Epoch 1/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.4865 - accuracy: 0.8283
Epoch 2/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.3644 - accuracy: 0.8693
Epoch 3/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.3287 - accuracy: 0.8795
Epoch 4/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.3048 - accuracy: 0.8874
Epoch 5/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.2864 - accuracy: 0.8938
313/313 - 0s - loss: 0.3474 - accuracy: 0.8752
[0.3474471867084503, 0.8751999735832214]
由此可以看出,在訓(xùn)練集合上我們可以得到的最高的準(zhǔn)確率為 87.52%,在測試集合上面的準(zhǔn)確率為 87.519997%。
4. 小結(jié)
TesnorFlow 程序的構(gòu)建主要分為三大部分:數(shù)據(jù)預(yù)處理、模型構(gòu)建、模型訓(xùn)練。
而在以后的實踐過程中我們也總是離不開這個程序順序,更加深入的定制化無非就是在這三個大的過程中增添一些細(xì)節(jié),因此我們大家要謹(jǐn)記這個總體步驟。