第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

全部開發(fā)者教程

TensorFlow 入門教程

首頁 慕課教程 TensorFlow 入門教程 TensorFlow 入門教程 TensorFlow 中的回調(diào)函數(shù)

TensorFlow 中的回調(diào)函數(shù)

回調(diào)函數(shù)是 TensorFlow 訓(xùn)練之中非常重要的一部分,我們在之前的學(xué)習(xí)之中或多或少地用到了回調(diào)函數(shù)。比如在之前的過擬合一節(jié)之中,我們就曾經(jīng)用到了早?;卣{(diào)。那么這節(jié)課我們就來學(xué)習(xí)以下 TensorFlow 之中的回調(diào)函數(shù)。

1. 什么是回調(diào)函數(shù)

簡單來說,回調(diào)函數(shù)就是在訓(xùn)練到一定階段的時(shí)候而執(zhí)行的函數(shù),我們最常采用的策略是每個(gè)Epoch結(jié)束之后執(zhí)行一次回調(diào)函數(shù)

回調(diào)函數(shù)的絕大多數(shù) API 集中在 tf.keras.callbacks 之中,也就是說這是 Keras 之中的一個(gè) API 。由于之前已經(jīng)學(xué)習(xí)過早?;卣{(diào),這節(jié)課我們來學(xué)習(xí)一下其他的幾個(gè)常用的回調(diào):

  • 模型保存回調(diào):tf.keras.callbacks.ModelCheckpoint;
  • 學(xué)習(xí)率回調(diào);tf.keras.callbacks.LearningRateScheduler;
  • 自定義回調(diào):tf.keras.callbacks.CallBack。

對于回調(diào)的使用方法,也是非常簡單的,假設(shè)以下的數(shù)組之中定義了我們所需要的全部回調(diào)函數(shù):

callbacks = [......]

那么我們在使用回調(diào)的時(shí)候,之中只需要在訓(xùn)練函數(shù)中指定回調(diào)即可:

model.fit(..., ..., callbacks=callbacks)

對于要介紹的回調(diào),我們會(huì)首先給出介紹,然后再在統(tǒng)一的代碼之中示例使用。

2. 模型保存回調(diào)

模型保存的回調(diào)函數(shù)為:

tf.keras.callbacks.ModelCheckpoint(
    path, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, save_freq='epoch')

這里只列出來了我們常用的參數(shù),對于其中的每個(gè)參數(shù),它們的作用如下:

  • path: 保存模型的路徑;
  • monitor: 用哪個(gè)指標(biāo)來評價(jià)模型的好壞,默認(rèn)是驗(yàn)證集上的損失;
  • verbose: 輸出日志的等級,只能為 0 或 1;
  • save_best_only: 是否只保存最好的模型,模型的好壞由 monitor 指定;
  • save_weights_only: 是否只保存權(quán)重,默認(rèn) False ,也就是保存整個(gè)模型;
  • save_freq: 保存的頻率,可以為 ‘Epoch’ 或者一個(gè)整數(shù),默認(rèn)為每個(gè) Epoch 保存一次模型;若是一個(gè)整數(shù)N,則是每訓(xùn)練 N 個(gè) Batch 保存一次模型。

3. 學(xué)習(xí)率回調(diào)

學(xué)習(xí)率回調(diào)函數(shù)為:

tf.keras.callbacks.LearningRateScheduler(
    schedule, verbose=0
)

其中 verbose 參數(shù)仍然是日志輸出的等級,默認(rèn)為 0 ;而 schedule 則是一個(gè)函數(shù),用來定義一個(gè)學(xué)習(xí)率的變化。其中 schedule 函數(shù)的一個(gè)示例如下所示:

def my_schedule(epoch, lr):
  if epoch < 20:
    return lr
  else:
    return lr * 0.1

該學(xué)習(xí)率回調(diào)是在 20 個(gè) Epoch 之前學(xué)習(xí)率保持不變,而在 20 個(gè) Epoch 之后,每個(gè) Epoch 學(xué)習(xí)率變?yōu)樵瓉淼?0.1 。

可以看出,該 schedule 函數(shù)由嚴(yán)格的形式,其中第一個(gè)參數(shù)為訓(xùn)練的 Epoch ,第二個(gè)參數(shù)為當(dāng)前的學(xué)習(xí)率。

4. 自定義回調(diào)

我們在使用回調(diào)的過程之中難免會(huì)遇到要自定義回調(diào)的情況,這時(shí)我們便需要編寫類來繼承 tf.keras.callbacks.CallBack 類,從而實(shí)現(xiàn)我們的自定義回調(diào)。

在自定義回調(diào)的過程之中,你可以覆寫不同的函數(shù),從而可以實(shí)現(xiàn)在不同的時(shí)間來運(yùn)行我們自定義的函數(shù),這些函數(shù)包括:

  • on_train_begin(self, logs=None): 在訓(xùn)練開始時(shí)調(diào)用;
  • on_test_begin(self, logs=None): 在測試開始時(shí)調(diào)用;
  • on_predict_begin(self, logs=None): 在預(yù)測開始時(shí)調(diào)用;
  • on_train_end(self, logs=None) 在訓(xùn)練結(jié)束時(shí)調(diào)用;
  • on_test_end(self, logs=None) 在測試結(jié)束時(shí)調(diào)用;
  • on_predict_end(self, logs=None) 在預(yù)測結(jié)束時(shí)調(diào)用;
  • on_train_batch_begin(self, batch, logs=None) 在訓(xùn)練期間的每個(gè)批次之前調(diào)用;
  • on_test_batch_begin(self, batch, logs=None) 在測試期間的每個(gè)批次之前調(diào)用;
  • on_predict_batch_begin(self, batch, logs=None) 在預(yù)測期間的每個(gè)批次之前調(diào)用;
  • on_train_batch_end(self, batch, logs=None) 在訓(xùn)練期間的每個(gè)批次之后調(diào)用;
  • on_test_batch_end(self, batch, logs=None) 在測試期間的每個(gè)批次之后調(diào)用;
  • on_predict_batch_end(self, batch, logs=None) 在預(yù)測期間的每個(gè)批次之后調(diào)用;
  • on_epoch_begin(self, epoch, logs=None) 在每次迭代訓(xùn)練開始時(shí)調(diào)用;
  • on_epoch_end(self, epoch, logs=None) 在每次迭代訓(xùn)練結(jié)束時(shí)調(diào)用。

我們可以來使用其中兩個(gè)簡單的函數(shù)來做一個(gè)簡單的示例:

class MyCallback(tf.keras.callbacks.Callback):

    def on_epoch_begin(self, epoch, logs=None):
        print("Start epoch {}.".format(epoch))

    def on_train_begin(self, logs=None):
        print("Starting training.")

這個(gè)樣子,我們便可以在每次訓(xùn)練開始,以及每個(gè) Epoch 開始之時(shí)進(jìn)行輸出日志。

5. 程序示例

在這里,我們將同時(shí)使用模型保存回調(diào)、學(xué)習(xí)率回調(diào)以及自定義回調(diào)來做一個(gè)簡單的示例:

import tensorflow as tf

model = tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
])

lr = 0.01

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    loss="mse"
)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()


def my_schedule(epoch, lr):
  print('Learning rate: ' + str(lr))
  if epoch < 5:
    return lr
  else:
    return lr * 0.1

lr_callback = tf.keras.callbacks.LearningRateScheduler(my_schedule)

save_model_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='/model/', save_weights_only=True, verbose=1,
    monitor='val_loss', mode='min', save_best_only=True)

class MyCallback(tf.keras.callbacks.Callback):

    def on_epoch_begin(self, epoch, logs=None):
        print("Start epoch {}.".format(epoch))

    def on_train_begin(self, logs=None):
        print("Starting training.")

model.fit(x_train, y_train,
    batch_size=64, epochs=10,
    validation_data=(x_test, y_test),
    callbacks=[MyCallback(), lr_callback, save_model_callback],
)

在這里,我們按照之前學(xué)習(xí)的方法定義了三個(gè)回調(diào)函數(shù),分別是模型保存回調(diào)、學(xué)習(xí)率回調(diào)、以及自定義回調(diào)。其中模型保存回調(diào)會(huì)在每次訓(xùn)練后保存模型、學(xué)習(xí)率回調(diào)會(huì)在第五個(gè) Epoch 之后便每個(gè) Epoch 變?yōu)樵瓉淼?0.1 ,而自定義回調(diào)會(huì)在訓(xùn)練開始之前、每個(gè) Epoch 開始之前輸出相應(yīng)的信息。

于是我們可以得到輸出:

Starting training.
Start epoch 0.
Learning rate: 0.009999999776482582
Epoch 1/10
931/938 [============================>.] - ETA: 0s - loss: 556.1402
Epoch 00001: val_loss improved from inf to 15.96259, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 552.3954 - val_loss: 15.9626
Start epoch 1.
Learning rate: 0.009999999776482582
Epoch 2/10
927/938 [============================>.] - ETA: 0s - loss: 12.4227
Epoch 00002: val_loss improved from 15.96259 to 10.01533, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 12.3927 - val_loss: 10.0153
Start epoch 2.
Learning rate: 0.009999999776482582
Epoch 3/10
914/938 [============================>.] - ETA: 0s - loss: 9.0919
Epoch 00003: val_loss improved from 10.01533 to 8.50834, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 9.0744 - val_loss: 8.5083
Start epoch 3.
Learning rate: 0.009999999776482582
Epoch 4/10
913/938 [============================>.] - ETA: 0s - loss: 8.3514
Epoch 00004: val_loss improved from 8.50834 to 8.26637, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.3450 - val_loss: 8.2664
Start epoch 4.
Learning rate: 0.009999999776482582
Epoch 5/10
920/938 [============================>.] - ETA: 0s - loss: 8.2481
Epoch 00005: val_loss improved from 8.26637 to 8.25048, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2544 - val_loss: 8.2505
Start epoch 5.
Learning rate: 0.009999999776482582
Epoch 6/10
933/938 [============================>.] - ETA: 0s - loss: 8.2504
Epoch 00006: val_loss improved from 8.25048 to 8.25035, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2502 - val_loss: 8.2504
Start epoch 6.
Learning rate: 0.0009999999310821295
Epoch 7/10
932/938 [============================>.] - ETA: 0s - loss: 8.2509
Epoch 00007: val_loss improved from 8.25035 to 8.25034, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 7.
Learning rate: 9.99999901978299e-05
Epoch 8/10
916/938 [============================>.] - ETA: 0s - loss: 8.2600
Epoch 00008: val_loss improved from 8.25034 to 8.25034, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 8.
Learning rate: 9.99999883788405e-06
Epoch 9/10
914/938 [============================>.] - ETA: 0s - loss: 8.2541
Epoch 00009: val_loss did not improve from 8.25034
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 9.
Learning rate: 9.99999883788405e-07
Epoch 10/10
925/938 [============================>.] - ETA: 0s - loss: 8.2446
Epoch 00010: val_loss did not improve from 8.25034
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
<tensorflow.python.keras.callbacks.History at 0x7eff7317f748>

可以看到,我們的三個(gè)回調(diào)函數(shù)都能正確地輸出相應(yīng)的信息,說明我們的回調(diào)函數(shù)已經(jīng)成功生效。

6. 小結(jié)

在這節(jié)課之中,我們學(xué)習(xí)了什么是回調(diào)函數(shù)、模型保存回調(diào)、學(xué)習(xí)率回調(diào)以及如何自定義回調(diào)。同時(shí)我們又通過相應(yīng)的示例演示了如何使用回調(diào)。