第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

全部開發(fā)者教程

TensorFlow 入門教程

首頁 慕課教程 TensorFlow 入門教程 TensorFlow 入門教程 將Keras模型轉(zhuǎn)化為Estimator模型

將 Keras 模型轉(zhuǎn)化為 Estimator 模型

在上一節(jié)課之中,我們學(xué)習(xí)了使用 Estimator 來進(jìn)行模型訓(xùn)練的一系列步驟,同時也通過一個示例了解了如何使用內(nèi)置的 Estimator 來進(jìn)行模型的訓(xùn)練。

而在實際的應(yīng)用過程之中,我們難免會遇到要使用自定義模型進(jìn)行訓(xùn)練的情況,而在之前的學(xué)習(xí)之中我們曾經(jīng)學(xué)習(xí)過如何使用 Keras 來自定義模型,那么這節(jié)課我們就來學(xué)習(xí)一下如何使用自定義的 Keras 模型轉(zhuǎn)化為 Estimator 模型并進(jìn)行訓(xùn)練。

那么這節(jié)課我們便使用對 Fashion_Mnist 數(shù)據(jù)集進(jìn)行分類的示例來學(xué)習(xí)如何 Keras 模型轉(zhuǎn)化為 Estimator 模型。

1. 采用的方法

在 TensorFlow 之中,若要將一個 Keras 模型轉(zhuǎn)化為一個 Estimator 模型,我們只需要調(diào)用一個步驟,那就是 tf.keras.estimator.model_to_estimator 接口。

可以看出,該 API 是一個 keras 的 API ,而該接口會將一個 Keras 模型轉(zhuǎn)化為一個估算器(Estimator)。

對于該接口的詳細(xì)參數(shù),我們會在具體實現(xiàn)的時候進(jìn)行說明。

由于轉(zhuǎn)化完成的接口仍然是一個 Estimator 模型,因此我們依然需要之前的四步驟:

  • 定義特征列,在這里可以省略,因為輸入的特征我們可以在 Keras 模型之中處理;
  • 定義輸入函數(shù);
  • 創(chuàng)建 Estimator 模型,在這里就是將Keras模型轉(zhuǎn)化為 Estimator 模型;
  • 進(jìn)行訓(xùn)練與評估等操作。

2. 具體示例

首先我們需要獲取數(shù)據(jù)集,和之前一樣,為了方便,我們直接采用官方 API 進(jìn)行數(shù)據(jù)集的獲?。?/p>

import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0

通過上面的代碼,我們便得到了數(shù)據(jù)集,同時對圖片數(shù)據(jù)進(jìn)行了歸一化處理,也就是所將他們的數(shù)值規(guī)范化到 [0, 1] 之間。

然后我們就需要定義我們的模型:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(256, activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(), 
        optimizer='adam', metrics=['accuracy'])

在這里我們使用了一個和之前一樣的模型:

  • 第一層為一個扁平化層,將二維數(shù)據(jù)一維化;
  • 第二層為一個隱含層,其中包括256個隱含節(jié)點;
  • 最后一層為輸出層,包含十個節(jié)點,因為我們的分類一共有10個。

然后我們對模型進(jìn)行了編譯,我們采用的是 adam 優(yōu)化器,因為我們是多分類任務(wù),因此我們采用了 SparseCategoricalCrossentropy 損失函數(shù)。

然后我們需要進(jìn)行非常重要的一步:定義輸入函數(shù),在這里我們要定義兩個輸入函數(shù),第一個為訓(xùn)練時的輸入函數(shù),另一個為測試時的輸入函數(shù)。

def train_input_fn(x_train, y_train):
  dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  dataset = dataset.batch(16).repeat()
  return dataset

def test_input_fn(x_test, y_test):
  dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  dataset = dataset.batch(16)
  return dataset

在這里我們依然采用的是 tf.data.Dataset.from_tensor_slices 切片的方法來構(gòu)建數(shù)據(jù)集,同時我們進(jìn)行了分批處理,其中批次的大小為16,而對于訓(xùn)練集,我們進(jìn)行了重復(fù)處理,因為我們的訓(xùn)練過程可能需要多次遍歷數(shù)據(jù)集。

然后我們便將我們創(chuàng)建的Keras模型轉(zhuǎn)化為Estimator模型:

keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model)

僅僅需要一行代碼,我們的Keras模型就轉(zhuǎn)化為了 Estimator 模型,在這里我們僅僅使用了一個參數(shù),那就是 keras_model 參數(shù),其實該函數(shù)還包括幾個非常有用的參數(shù):

  • keras_model_path:與 keras_model 互斥,也就是說不能與 keras_model 同時不為 None ,用來指定從磁盤上加載一個模型;
  • model_dir:用來保存各種日志、參數(shù)、圖表的目錄;
  • checkpoint_format:用來保存模型的格式。

最后我們進(jìn)行模型的訓(xùn)練與測試:

keras_estimator.train(input_fn=lambda: train_input_fn(x_train, y_train), steps=5000)
eval_result = keras_estimator.evaluate(input_fn=lambda: test_input_fn(x_test, y_test), steps=len(x_test)//16)
print(eval_result)

在這里,我們在訓(xùn)練的時候采用了一個參數(shù) steps ,這個參數(shù)指明要訓(xùn)練多少個 batch_size (在這里是16 );因為在測試的時候我們只需要進(jìn)行一次數(shù)據(jù)遍歷,因此我們的 steps 設(shè)置為 len(x_test)//16 。

舉個例子,加入我們數(shù)據(jù)一共有 10 條,我們想要訓(xùn)練 20 個 epoch ,那么我們就需要設(shè)置 steps 為10*20=200 。而如果 steps 為 300 ,那我們的 epoch 就是 300 // 10 = 30 。

而最后我們得到的 eval_result 是一個字典類型的對象,其中包含了 loss、step 以及我們定義的 metrics 等字段。

最后在測試結(jié)束后我們可以得到結(jié)果為:

INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.8086, global_step = 5000, loss = 1.652609
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmp/tmptu8apzc8/model.ckpt-5000
{'accuracy': 0.8086, 'loss': 1.652609, 'global_step': 5000}

可以看到我們的模型最終達(dá)到了 0.8 的準(zhǔn)確率,如果提升 steps ,我們便可以在一定程度上進(jìn)一步提升準(zhǔn)確率。

3. 小結(jié)

在這節(jié)課之中,我們學(xué)習(xí)了如何將 Keras 轉(zhuǎn)化為 Estimator 模型,總而言之,我們依然需要使用訓(xùn)練 Estimator 模型的一貫步驟。同時我們也了解了 tf.keras.estimator.model_to_estimator 的幾個常用的參數(shù)。

圖片描述