在 TensorBoard 之中查看模型結(jié)構(gòu)圖
在之前的學(xué)習(xí)過程之中,我們學(xué)習(xí)了如何自定義查看訓(xùn)練過程之中的各項指標(biāo)。在實際的應(yīng)用過程之中,為了保證模型構(gòu)建的準(zhǔn)確性,我們也會經(jīng)常查看網(wǎng)絡(luò)的模型結(jié)構(gòu)圖。那么這節(jié)課我們就來看一下如何在 TensorBoard 之中查看模型圖。
1. 如何在 TensorBoard 之中生成 Keras 模型結(jié)構(gòu)圖
倘若我們通過 tf.keras API 來自定義了一個網(wǎng)絡(luò)模型,那么我們在 TensorBoard 來查看模型圖是非常簡單的一件事情。
當(dāng)我們使用 tf.keras 的模型的 fit() 方法的時候,框架會自動幫我們繪制模型結(jié)構(gòu)圖。
如下代碼所示:
首先我們定義模型、數(shù)據(jù)與相應(yīng)的參數(shù)。
import tensorflow as tf
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model = create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=[])
然后我們定義相應(yīng)的 TensorBoard 日志目錄,同時對模型使用 fit() 進(jìn)行訓(xùn)練:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')
model.fit(x=x_train, y=y_train,
epochs=3,
validation_data=(x_test, y_test),
callbacks=[tensorboard_callback])
最后我們就可以打開 TensorBoard 并在瀏覽器查看:
tensorboard --logdir logs
我們就可以在瀏覽器的 Graph 標(biāo)簽頁之中看到模型圖了:
2. 如何在 TensorBoard 之中生成使用 tf.function 函數(shù)定義的圖
在實際的應(yīng)用過程之中,有很多的情況下,我們需要使用 tf.function 來加速模型的速度并自定義訓(xùn)練過程。那么這個時候我們要如何才能查看網(wǎng)絡(luò)的模型結(jié)構(gòu)圖呢?
其實也很簡單,我們只需經(jīng)過如下幾個步驟:
- 確保 tf.function 函數(shù)修飾了我們需要進(jìn)行可視化的操作,這邊就是模型的過程;
- 創(chuàng)建一個 TensorBoard 的日志寫入器 tf.summary.create_file_writer() ;
- 通過 tf.summary.trace_on() API 進(jìn)行變量路徑的追蹤;
- 執(zhí)行我們需要可視化的操作;
- 使用 tf.summary.trace_export() API 將圖寫入日志。
在這里,我們可以使用一個很簡單的例子來查看操作的結(jié)構(gòu):
# 定義網(wǎng)絡(luò)的操作
@tf.function
def test_func(x, y):
z = tf.matmul(x, y)
z = z * 5.0
z = tf.nn.relu(z)
return z
# 創(chuàng)建寫入器
writer = tf.summary.create_file_writer('./logs/3')
# 創(chuàng)建初試數(shù)據(jù)
x = tf.random.uniform((5, 5))
y = tf.random.uniform((5, 5))
# 開啟變量追蹤
tf.summary.trace_on(graph=True, profiler=True)
# 運行程序
z = test_func(x, y)
# 將日志輸出
with writer.as_default():
tf.summary.trace_export(
name="test_func_graph",
step=1,
profiler_outdir='./logs/3')
在這里,我們首先定義了一個基本的模型操作,該模型操作由一個矩陣乘法、一個常量乘法、外加一個 Relu 激活層組成。
在運行完操作之后,我們便使用 tf.summary.trace_export() API 來將模型圖輸入道日志之中。
然后我們便可以在瀏覽器之中查看到相應(yīng)的模型圖:
可以看到,該模型圖完整的反映了我們的操作。
3. TensorBoard 之中基本、基本的操作
既然了解了如何將模型圖輸出到日志,那么接下來我們就應(yīng)該查看在 TensorBoard 之中對模型圖的基本操作。
3.1 平移、縮放以及詳細(xì)信息的查看
在 TensorBoard 之中,使用鼠標(biāo)滾輪即可實現(xiàn)模型圖的縮放,當(dāng)我們一直放大,會看到操作內(nèi)部的細(xì)節(jié)。
并且按住鼠標(biāo)左鍵,移動鼠標(biāo),即可實現(xiàn)模型圖的移動操作。
雙擊網(wǎng)絡(luò)節(jié)點,即可展開網(wǎng)絡(luò)節(jié)點,從而查看到網(wǎng)絡(luò)內(nèi)部的細(xì)節(jié)操作。
3.2 模型的節(jié)點的搜索
在左側(cè)的最上方,可以搜索自己想要查看的節(jié)點,這里是支持正則表達(dá)式的。
3.3 模型的下載
點擊左側(cè)的 Download PNG 即可下載帶有透明度的、網(wǎng)絡(luò)模型的圖片。
3.4 切換網(wǎng)絡(luò)模型
點擊左側(cè)的 Run 按鈕,即可選擇不同的網(wǎng)絡(luò)模型進(jìn)行查看,前提是我們已經(jīng)將網(wǎng)絡(luò)模型輸入到日志之中去。
3.5 切換查看方式
點擊左側(cè)的 Tag 選項,即可查看網(wǎng)絡(luò)的查看方式。
默認(rèn)是查看中粒度的網(wǎng)絡(luò)模型,如果我們的模型是使用 Keras 定義的,那么我們可以選擇查看 Keras 結(jié)構(gòu),這是一個總體的概覽,可以幫助我們掌握大體的網(wǎng)絡(luò)結(jié)構(gòu)。
3.6 模型圖的圖例
當(dāng)我們遇到一些不理解的圖標(biāo)的時候,我們可以通過左下角的圖例進(jìn)行查詢:
4. 小結(jié)
在這節(jié)課之中,我們學(xué)習(xí)率如何在 TensorBoard 之中查看 Keras 模型,同時也了解了如何產(chǎn)看自定義的操作過程,最后我們了解了 TensorBoard 的一些基本操作。 TensorBoard 也在持續(xù)更新,未來一定會有更多新的功能。