第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

全部開發(fā)者教程

TensorFlow 入門教程

首頁 慕課教程 TensorFlow 入門教程 TensorFlow 入門教程 在 TensorFlow 之中進行遷移學習

在 TensorFlow 之中進行遷移學習

在之前的學習之中,我們都是從定義模型開始,逐步的獲取數(shù)據(jù)并且對數(shù)據(jù)進行處理,最終訓練模型以達到一個良好的效果。這些任務都是從零開始訓練的例子,那么我們能不能使用別人已經(jīng)訓練好的模型來幫助我們來進行相似的工作呢?答案是肯定的,這就是我們這節(jié)課要學習到的 “遷移學習”。

1. 什么是遷移學習

遷移學習,顧名思義,就是將學習任務遷移的意思。在實際的應用之中,我們遇到的好多學習任務都具有很強的相似性,比如圖片分割任務和圖片分類任務就很相似,因為他們都是對圖片進行處理的任務。

對相似數(shù)據(jù)類型進行處理的任務的模型往往可以互相遷移使用,而不必重新訓練一個新的模型,從而節(jié)省時間和空間的開支

在遷移學習的領域之中,圖片處理的任務往往占據(jù)大多數(shù),因為圖片任務的處理往往都含有相似的部分 —— 提取特征。在實際的任務之中,我們往往會使用已經(jīng)在大型數(shù)據(jù)集(比如 ImageNet )上訓練得到的模型作為遷移學習的基本模型,以此來提取圖片的特征,從而進行下一步的處理

簡單來說就是:使用別人訓練好的模型來做自己的學習任務。

2. 遷移學習的基本思路

遷移學習是一個非常寬泛的概念,其的種類包括很多,我們這里以圖片任務為例來講解遷移學習的基本思路:

  • 選擇遷移學習的基本模型,一般為在大型數(shù)據(jù)集上訓練的大型網(wǎng)絡,比如:
    • ResNet 網(wǎng)絡;
    • GoogLeNet 網(wǎng)絡;
    • Xception 網(wǎng)絡;
  • 然后選擇使用網(wǎng)絡的哪些部分,一般使用除了頂層的所有部分;
  • 編寫剩余的部分,也就是自己接下來的處理過程;
  • 訓練自己編寫的處理過程。

這幾個步驟看起來非常簡單,在實際過程之中也是非常簡單的,接下來我們就以在 ImageNet 超大數(shù)據(jù)集上訓練的 Xception 模型作為基本模型進行遷移學習的演示。

3. 使用遷移學習的實例

這次,我們依然使用貓狗分類的例子來進行實現(xiàn),具體的代碼如下所示:

注意:部分代碼來自 TensorFlow 官方 API 。

import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np


train_data, validation_data = tfds.load(
    "cats_vs_dogs",
    split=["train[:80%]", "train[80%:]"],
    as_supervised=True,
)

# 重新調整大小
train_data    = train_data.map(lambda x, y: (tf.image.resize(x, (150, 150)), y))
validation_data = validation_data.map(lambda x, y: (tf.image.resize(x, (150, 150)), y))

# 分批次
train_data = train_data.batch(32)
validation_data = validation_data.batch(32)

# 遷移模型
base_model = tf.keras.applications.Xception(
    weights="imagenet",
    input_shape=(150, 150, 3),
    include_top=False,
)

base_model.trainable = False

# 定義輸入
inputs = tf.keras.Input(shape=(150, 150, 3))
# 數(shù)據(jù)正則化
norm_layer = tf.keras.layers.experimental.preprocessing.Normalization()
x = norm_layer(inputs)
mean = np.array([127.5] * 3)
norm_layer.set_weights([mean, mean ** 2])

# 數(shù)據(jù)經(jīng)過遷移模型
x = base_model(x, training=False)
# 數(shù)據(jù)經(jīng)過自定義網(wǎng)絡
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

model.summary()

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy()],
)

model.fit(train_ds, epochs=20, validation_data=validation_ds)

在這里的代碼之中,我們有幾處需要注意的地方:

  • 在數(shù)據(jù)獲取方面,我們采用了 tfds.load 函數(shù),該函數(shù)能夠直接獲取相應的內置數(shù)據(jù)集,同時進行相應的分割,這里我們按照 8:2 的比例來進行訓練集、測試集的劃分;
  • 我們使用 map 函數(shù),來將所有的數(shù)據(jù)的圖片重新調整至(150, 150)大小,我們將圖片調整至相同大小是為了方便后面的處理;
  • 使用 tf.keras.applications.Xception API 來獲取已經(jīng)預訓練的 Xception 模型,在該 API 之中,包含三個參數(shù):
    • weights:表示在哪個數(shù)據(jù)集上訓練;
    • input_shape:表示輸入圖片的形狀;
    • include_top=False:表示不含頂層網(wǎng)絡,因為我們要定義自己的網(wǎng)絡。
  • 然后我們使用 base_model.trainable=False 語句來將基本模型的訓練參數(shù)凍結,這樣我們就不能訓練 Xception 的參數(shù)。
  • 我們使用了 tf.keras.layers.experimental.preprocessing.Normalization 這個 API 來進行數(shù)據(jù)的正則化,我們需要通過 norm_layer.set_weights () 設定它的權重:
    • 第一個參數(shù)是輸入的每個通道的平均值,這里是 255/2=127.5;
    • 第二個參數(shù)是第一個參數(shù)的平方;
  • 最后我們采用了一種新的定義模型的方式:先定義一個 Input ,然后將該 Input 逐次經(jīng)過自己需要處理的網(wǎng)絡層得到 output,最后通過 tf.keras.Model (inputs, output) 來讓 TensorFlow s 根據(jù)數(shù)據(jù)的流動過程來自動生成網(wǎng)絡模型。

最終我們可以得到結果:

Model: "functional_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_10 (InputLayer)        [(None, 150, 150, 3)]     0         
_________________________________________________________________
normalization_3 (Normalizati (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d_2 ( (None, 2048)              0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 2048)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________

Epoch 1/20
291/291 [==============================] - 9s 31ms/step - loss: 0.1607 - binary_accuracy: 0.9313 - val_loss: 0.0872 - val_binary_accuracy: 0.9703
Epoch 2/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1181 - binary_accuracy: 0.9501 - val_loss: 0.0869 - val_binary_accuracy: 0.9690
......
Epoch 20/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0914 - binary_accuracy: 0.9841 - val_loss: 0.0875 - val_binary_accuracy: 0.9765

我們可以看到,我們的模型最終達到了 97% 的分類準確率,這是一個非常高的準確率,而這得益于 Xception 模型強大的特征提取能力。

4. 小結

在這節(jié)課之中,我們學習了什么是遷移學習,同時了解了遷移學習的一般思路,同時我們有手動實現(xiàn)了一個使用遷移學習進行分類的例子。在示例之中,我們學習到了一種新的模型定義方式。

圖片描述