使用 tf.function 提升效率
在之前的入門介紹之中,我們曾經(jīng)介紹過 TensorFlow1.x 采用的并不是 Eager execution 執(zhí)行模型;而 TensorFlow2.x 默認采用的是 Eager execution 模式。
這種改變使得我們可以更加容易地學習,但是也會造成性能的損失,因此, TensorFlow 在 2.0 版本之后引入了 tf.function 。
1. 什么是 tf.funtion
在 TensorFlow1.x 之中,如果我們想要運行一個學習任務(wù),那么我們需要首先創(chuàng)建一個 tf.Sesstion (),然后再調(diào)用 Session.run () 進行運行。
其實在 TensorFlow1.x 內(nèi)部,當我們在 TensorFlow 之中進行工作的時候, TensorFlow 會幫助我們創(chuàng)建一個計算圖 tf.graph ,然后通過 tf.Session 對計算圖進行計算。
而在 TensorFlow2.x 之中,其默認采用的是 Eager execution 執(zhí)行方式,在該執(zhí)行方式之中,我們不再需要定義一個計算圖來進行。
這樣就產(chǎn)生了一些問題:
- 使用 tf.Sesstion () 的運行效率非常高,但是代碼很難懂;
- 使用 Eager execution 方式的代碼很簡單,但是執(zhí)行效率比較低。
有什么方法能夠兼顧兩者嗎?
那就是 tf.function 。
2. tf.funtion 的用法
tf.function 是一個函數(shù)標注修飾,也就是如下的形式:
@tf.function
def my_function():
...
其實如你所見,這就是 tf.function 的全部用法。
我們只需要在我們要修飾的函數(shù)之前加上 tf.function 標注既可。
采用 tf.function,TensorFlow 會將該函數(shù)轉(zhuǎn)變?yōu)橛嬎銏D tf.graph 的形式來進行運算,這會使得該函數(shù)在進行大量運算的時候會加速非常多。
是不是所有的函數(shù)都適合 tf.function 進行修飾呢?
答案是否定的,以下兩種情況不適合使用 tf.function 進行修飾:
- 函數(shù)本身的計算非常簡單,那么構(gòu)建計算圖本身的時間就會相對非常浪費;
- 當我們需要在函數(shù)之中定義 tf.Variable 的時候,因為 tf.function 可能會被調(diào)用多次,因此定義 tf.Variable 會產(chǎn)生重復定義的情況。
3. tf.function 的性能
既然了解了 tf.function 的用法,那么我們便來測試一下 tf.function 的性能,我們采用一個簡單的卷積神經(jīng)網(wǎng)絡(luò)來進行測試:
import tensorflow as tf
import timeit
def f1(layer, image):
y = layer(image)
return y
@tf.function
def f2(layer, image):
y = layer(image)
return y
layer = tf.keras.layers.Conv2D(300, 3)
image = tf.zeros([64, 32, 32, 3])
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')])
print(timeit.timeit(lambda: f1(model, image), number=500))
print(timeit.timeit(lambda: f2(model, image), number=500))
在這里,我們定義了兩個相同的函數(shù),其中一個使用了 tf.function 進行修飾,而另外一個沒有。
在這里我們使用 lambda 函數(shù)來讓函數(shù)重復執(zhí)行 500 次,并且使用 timeit 來進行時間的統(tǒng)計,得到兩個函數(shù)的執(zhí)行時間,從而進行比較。
最終,我們可以得到結(jié)果:
17.20403664399987
12.07886406200032
由此可以看出,我們的 tf.function 已經(jīng)提升了一定的速度,但是提升的速度有限,目前大概提升了 25 % 的速度。這是因為我們的計算仍然還是太簡單了,當我們計算非常大的時候,性能會有很大的提升。
4. 小結(jié)
在這節(jié)課之中,我們學習到了什么是 tf.function ,以及 tf.function 的基本原理,然后我們了解了 tf.function 的使用方法;最后我們通過一個簡單的神經(jīng)網(wǎng)絡(luò)來進行了性能的測試,最終我們發(fā)現(xiàn)我們的 tf.function 確實能給我們性能帶來很大的提升。