在 TensorFlow 之中進(jìn)行數(shù)據(jù)增強(qiáng)
在我們之前的學(xué)習(xí)之中,我們所使用的數(shù)據(jù)都是進(jìn)行一些 “簡(jiǎn)單的處理”,比如正則化、歸一化、分批次等基本操作;這些操作都有一些特點(diǎn),那就是在固定的數(shù)據(jù)集上進(jìn)行處理,也就是說(shuō)這些處理并不會(huì)改變數(shù)據(jù)的數(shù)量(甚至可能會(huì)減少數(shù)據(jù)的數(shù)量,比如數(shù)據(jù)篩選)。
那么這節(jié)課我們便來(lái)學(xué)習(xí)一下如何在 TensorFlow 之中數(shù)據(jù)增強(qiáng),它可以增加數(shù)據(jù)量,從而可以使用更多樣的數(shù)據(jù)來(lái)訓(xùn)練模型。
1. 什么是數(shù)據(jù)增強(qiáng)
關(guān)于數(shù)據(jù)增強(qiáng),我們可以在 TensorFlow API 之中看到相關(guān)的定義:
A technique to increase the diversity of your training set by applying random (but realistic) transformations.
翻譯一下就是:
數(shù)據(jù)增強(qiáng)是一種通過(guò)應(yīng)用隨機(jī)(但現(xiàn)實(shí))的變換來(lái)增加訓(xùn)練集的多樣性的技術(shù)。
簡(jiǎn)單來(lái)說(shuō),通過(guò)數(shù)據(jù)增強(qiáng),我們可以將一些已經(jīng)存在的數(shù)據(jù)進(jìn)行相應(yīng)的變換(可以選擇將這些變換之后的數(shù)據(jù)增加到新的原來(lái)的數(shù)據(jù)集之中,也可以直接在原來(lái)的數(shù)據(jù)集上進(jìn)行變換),從而實(shí)現(xiàn)數(shù)據(jù)種類多樣性的增加。
數(shù)據(jù)增強(qiáng)常見于圖像領(lǐng)域,因此這節(jié)課我們會(huì)以圖像處理為例來(lái)解釋如何在 TensorFlow 之中進(jìn)行數(shù)據(jù)增強(qiáng)。
對(duì)于圖片數(shù)據(jù),常見的數(shù)據(jù)增強(qiáng)方式包括:
- 隨機(jī)水平翻轉(zhuǎn):
- 隨機(jī)的裁剪;
- 隨機(jī)調(diào)整明亮程度;
- 其他方式等。
2. 如何在 TensorFlow 之中進(jìn)行圖像數(shù)據(jù)增強(qiáng)
在 TensorFlow 之中進(jìn)行圖像數(shù)據(jù)增強(qiáng)的方式主要有兩種:
- 使用 tf.keras 的預(yù)處理層進(jìn)行圖像數(shù)據(jù)增強(qiáng);
- 使用 tf.image 進(jìn)行數(shù)據(jù)增強(qiáng)。
這兩種各有不同的特點(diǎn),但是因?yàn)槲覀円捎?tf.keras 進(jìn)行模型的構(gòu)建,因此我們重點(diǎn)學(xué)習(xí)如何使用 tf.keras 的預(yù)處理層進(jìn)行圖像數(shù)據(jù)增強(qiáng)。
1. 如何使用 tf.keras 的預(yù)處理層進(jìn)行圖像數(shù)據(jù)增強(qiáng)
使用 tf.keras 的預(yù)處理層進(jìn)行圖像數(shù)據(jù)增強(qiáng)要使用的最主要的 API 包括在一下包之中:
tf.keras.layers.experimental.preprocessing
在這個(gè)包之中,我們最常用的數(shù)據(jù)增強(qiáng) API 包括:
- tf.keras.layers.experimental.preprocessing.RandomFlip(mode): 將輸入的圖片進(jìn)行隨機(jī)翻轉(zhuǎn),一般我們會(huì)取 mode=“horizontal” ,因?yàn)檫@代表水平旋轉(zhuǎn);而 mode=“vertical” 則代表隨機(jī)進(jìn)行上下翻轉(zhuǎn);
- tf.keras.layers.experimental.preprocessing.RandomRotation§: 按照旋轉(zhuǎn)角度(單位為弧度) p 將輸入的圖片進(jìn)行隨機(jī)的旋轉(zhuǎn);
- tf.keras.layers.experimental.preprocessing.RandomContrast§:按照 P 的概率將輸入的圖片進(jìn)行隨機(jī)的圖像色相翻轉(zhuǎn);
- tf.keras.layers.experimental.preprocessing.CenterCrop(height, width):使用 height * width 的大小的裁剪框,在數(shù)據(jù)的中心進(jìn)行裁剪。
以上介紹的是我們?cè)跀?shù)據(jù)增強(qiáng)處理之中使用的最多的增強(qiáng)方式,在接下來(lái)的學(xué)習(xí)之中,我們會(huì)以該方式為例進(jìn)行程序的演示。
在使用的過(guò)程之中,我們只需要將這些數(shù)據(jù)增強(qiáng)的網(wǎng)絡(luò)層添加到網(wǎng)絡(luò)的最底層即可。
2. 使用 tf.image 進(jìn)行數(shù)據(jù)增強(qiáng)
使用 tf.image 是 TensorFlow 最原生的一種增強(qiáng)方式,使用這種方式可以實(shí)現(xiàn)更多、更加個(gè)性化的數(shù)據(jù)增強(qiáng)。
其中包含的數(shù)據(jù)增強(qiáng)方式主要包括:
- tf.image.flip_left_right (img):將圖片進(jìn)行水平翻轉(zhuǎn);
- tf.image.rgb_to_grayscale (img):將 RGB 圖像轉(zhuǎn)化為灰度圖像;
- tf.image.adjust_saturation (image, f):將 image 圖像按照 f 參數(shù)進(jìn)行飽和度的調(diào)節(jié);
- tf.image.adjust_brightness (image, f):將 image 圖像按照 f 參數(shù)進(jìn)行亮度的調(diào)節(jié);
- tf.image.central_crop (image, central_fraction):按照 p 的比例進(jìn)行圖片的中心裁剪,比如如果 p 是 0.5 ,那么裁剪后的長(zhǎng)、寬就是原來(lái)圖像的一半;
- tf.image.rot90 (image):將 image 圖像逆時(shí)針旋轉(zhuǎn) 90 度。
可以看到,很多的 tf.image 數(shù)據(jù)增強(qiáng)方式并不提供隨機(jī)化選項(xiàng),因此我們需要手動(dòng)進(jìn)行隨機(jī)化。
也正是因?yàn)樯鲜鎏匦裕瑃f.image 數(shù)據(jù)增強(qiáng)主要用在一些自定義的模型之中,從而可以實(shí)現(xiàn)數(shù)據(jù)增強(qiáng)的自定義化。
3. 使用 tf.keras 的預(yù)處理層進(jìn)行數(shù)據(jù)增強(qiáng)的實(shí)例
在這里,我們?nèi)匀徊捎梦覀兪煜さ呢埞贩诸惖睦觼?lái)進(jìn)行程序的演示,我們的代碼和之前的代碼相同,只是我們新增加了兩個(gè)數(shù)據(jù)增強(qiáng)的處理層:
tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical",
input_shape=(Height, Width ,3)),
tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
其中第一個(gè)層表示進(jìn)行隨機(jī)的水平和垂直翻轉(zhuǎn),而第二個(gè)層表示按照 0.2 的弧度值進(jìn)行隨機(jī)旋轉(zhuǎn)。
整體的網(wǎng)絡(luò)程序?yàn)椋?/p>
import tensorflow as tf
import os
import matplotlib.pyplot as plt
dataset_url = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_download = os.path.dirname(tf.keras.utils.get_file('cats_and_dogs.zip', origin=dataset_url, extract=True))
train_dataset_dir = path_download + '/cats_and_dogs_filtered/train'
valid_dataset_dir = path_download + '/cats_and_dogs_filtered/validation'
BATCH_SIZE = 64
TRAIN_NUM = 2000
VALID_NUM = 1000
EPOCHS = 15
Height = 128
Width = 128
train_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
valid_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
train_data_generator = train_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
directory=train_dataset_dir,
shuffle=True,
target_size=(Height, Width),
class_mode='binary')
valid_data_generator = valid_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
directory=valid_dataset_dir,
shuffle=True,
target_size=(Height, Width),
class_mode='binary')
model = tf.keras.models.Sequential([
tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical",
input_shape=(Height, Width ,3)),
tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
tf.keras.layers.Conv2D(16, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
model.summary()
history = model.fit_generator(
train_data_generator,
steps_per_epoch=TRAIN_NUM // BATCH_SIZE,
epochs=EPOCHS,
validation_data=valid_data_generator,
validation_steps=VALID_NUM // BATCH_SIZE)
acc = history.history['accuracy']
loss=history.history['loss']
val_acc = history.history['val_accuracy']
val_loss=history.history['val_loss']
epochs_ran = range(EPOCHS)
plt.plot(epochs_ran, acc, label='Train Acc')
plt.plot(epochs_ran, val_acc, label='Valid Acc')
plt.show()
plt.plot(epochs_ran, loss, label='Train Loss')
plt.plot(epochs_ran, val_loss, label='Valid Loss')
plt.show()
在訓(xùn)練結(jié)束后,我們可以得到如下結(jié)果,而這個(gè)結(jié)果與我們之前的結(jié)果有了一個(gè)良好的提升,最高達(dá)到了 79% 的準(zhǔn)確率,因此我們認(rèn)為我們的數(shù)據(jù)增強(qiáng)起到了一定的作用。
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
random_flip_1 (RandomFlip) (None, 128, 128, 3) 0
_________________________________________________________________
random_rotation_1 (RandomRot (None, 128, 128, 3) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 126, 126, 16) 448
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 63, 63, 16) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 61, 61, 32) 4640
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 30, 30, 32) 0
_________________________________________________________________
conv2d_8 (Conv2D) (None, 28, 28, 64) 18496
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 14, 14, 64) 0
_________________________________________________________________
flatten_2 (Flatten) (None, 12544) 0
_________________________________________________________________
dense_4 (Dense) (None, 512) 6423040
_________________________________________________________________
dense_5 (Dense) (None, 1) 513
=================================================================
Total params: 6,447,137
Trainable params: 6,447,137
Non-trainable params: 0
_________________________________________________________________
Epoch 1/15
31/31 [==============================] - 40s 1s/step - loss: 0.7372 - accuracy: 0.5052 - val_loss: 0.6700 - val_accuracy: 0.5583
......
Epoch 11/15
31/31 [==============================] - 41s 1s/step - loss: 0.5219 - accuracy: 0.8213 - val_loss: 0.5480 - val_accuracy: 0.7900
......
同時(shí)我們的程序還會(huì)輸出以下兩個(gè)圖片:
準(zhǔn)確率變化曲線:
損失變化曲線:
4. 小結(jié)
通過(guò)這節(jié)課的學(xué)習(xí),我們了解到了什么是數(shù)據(jù)增強(qiáng),同時(shí)也明白了如何在 TensorFlow 之中進(jìn)行數(shù)據(jù)增強(qiáng)(兩種不同的實(shí)現(xiàn)方式)。最后我們會(huì)很據(jù)以前的程序進(jìn)行改進(jìn),得到了一個(gè)完整的程序示例。