在 TensorFlow 之中使用注意力模型
在之前的學(xué)習(xí)之中,我們學(xué)習(xí)了很多的網(wǎng)絡(luò)模型,比如 CNN、RNN 等基本的網(wǎng)絡(luò)模型,雖然這些模型是根據(jù)人的信息處理方式來進(jìn)行設(shè)計并實現(xiàn)的,但是這些模型都有一些特點,那就是只是會根據(jù)輸入的數(shù)據(jù)進(jìn)行機(jī)械地輸出。那么我們這節(jié)課便要來學(xué)習(xí)一下更加 “貼近人的信息處理方式的方法”———— 注意力機(jī)制。
1. 什么是注意力機(jī)制
顧名思義,注意力機(jī)制,“Attention”,就是模仿人的注意力來進(jìn)行網(wǎng)絡(luò)模型的設(shè)計與實現(xiàn)。
我們每個人在日常生活之中,無時無刻不在使用著注意力,比如:
- 我們在看電視的時候會忽略掉電視周圍的環(huán)境;
- 我們在學(xué)習(xí)的時候會對書本的注意力集中度較高;
- 我們在聽音樂的時候?qū)σ魳繁旧淼淖⒁饬^高,反而對周圍的噪音注意力較小。
在神經(jīng)網(wǎng)絡(luò)之中采用注意力可以機(jī)制可以通過模仿人類的注意力行為,來對數(shù)據(jù)之中的重要的細(xì)節(jié)賦予更高的權(quán)重,反而對于一些不重要的細(xì)節(jié)來賦予較低的權(quán)重。
舉個例子,如下圖所示,有一只手拿著一朵花在以堆草叢之前,那么我們?nèi)嗽谟^察這種圖片的時候,一般會將更多的注意力集中在這朵花和這只手上,而不是將注意力放在背景的草叢中。因此我們要讓我們的網(wǎng)絡(luò)模型學(xué)會如何使用注意力機(jī)制,從而其實將注意力更多地放在花和手上。
2. 注意力的分類
注意力按照存在的地方大概可以分為四類:
- 空間注意力,就是我們上述圖片所表述的注意力,它主要是強(qiáng)調(diào)我們在空間之上要注意哪些地方;
- 時間注意力,圖片沒有時間注意力,像音頻、視頻等連續(xù)的數(shù)據(jù)會使用到時間注意力,表示我們在哪個時間段要提高注意力;
- 通道注意力,眾所周知,一般的圖片包含三個通道:R、G、B,那么通道注意力就是強(qiáng)調(diào)在哪個通道之上給予更高的注意力權(quán)重;
- 混合注意力,使用上述兩種及其以上的注意力,從而達(dá)到更好的效果。
在接下來的例子之中,我們會以通道注意力為例子進(jìn)行演示如何使用注意力機(jī)制。
3. 通道上的注意力機(jī)制的實現(xiàn) ——SELayer
SENet 是一個使用通道注意力的模型,它可以對不同的通道求得不同的權(quán)重,進(jìn)而對他們加權(quán),從而實現(xiàn)通道域上的注意力機(jī)制。
SELayer 是 SENet 之中的一個網(wǎng)絡(luò)層,是 SENet 的核心部分,我們可以將其單獨摘出來作為一個通道域上的注意力。
SELayer 的網(wǎng)絡(luò)圖如下圖所示:
在上圖之中,我們可以發(fā)現(xiàn),對于已經(jīng)求得的特征(第二個正方體),SELayer 首先使用卷積網(wǎng)絡(luò),將其變?yōu)?1 * 1 * C 的特征,然后對于該特征進(jìn)行一定的處理,處理結(jié)束之后的每一個通道的一個數(shù)字就代表著原特征圖的相應(yīng)通道的權(quán)重。最后我們將求得的權(quán)重乘到原特征上去便可以得到加權(quán)后的特征,這就表示我們已經(jīng)在通道域上實現(xiàn)了注意力機(jī)制。
在 TensorFlow 之中,我們可以通過繼承 tf.keras.laysers.Layer 類來定義自己的網(wǎng)絡(luò)層,于是我們可以將我們的 SELayer 定義為如下:
class SELayer(tf.keras.Model):
def __init__(self, filters, reduction=16):
super(SELayer, self).__init__()
self.filters = filters
self.reduction = reduction
self.GAP = tf.keras.layers.GlobalAveragePooling2D()
self.FC = tf.keras.models.Sequential([
tf.keras.layers.Dense(units=self.filters // self.reduction, input_shape=(self.filters, )),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Dense(units=filters),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('sigmoid')
])
self.Multiply = tf.keras.layers.Multiply()
def call(self, inputs, training=None, mask=None):
x = self.GAP(inputs)
x = self.FC(x)
x = self.Multiply([x, inputs])
return x
在初始化的函數(shù)之中,我們定義了我們需要用到的網(wǎng)絡(luò)層以及相應(yīng)的結(jié)構(gòu),通過 call 函數(shù)與初始化函數(shù),我們可以得到該層的執(zhí)行方式:
- 首先數(shù)據(jù)會經(jīng)過一個全局平均池化,來變成一個 1* 1 * c 形狀的特征;
- 然后經(jīng)過我們定義的 FC 層,來計算出一個 1 * 1 * c 的權(quán)重,其中 FC 層包括;
- 一個全連接層;
- 一個 DropOut 層用于避免過擬合;
- 一個批次正則化層,這是便于更好地進(jìn)行訓(xùn)練;
- 一個 relu 激活函數(shù);
- 另外一個全連接層;
- 另外一個 DropOut 層;
- 另外一個批次正則化層;
- 一個 sigmoid 激活函數(shù);
- 在得到權(quán)重之后,我們便使用矩陣的乘法,將原來的輸出與權(quán)重相乘,從而得到在最終的結(jié)果。
4. 使用通道注意力機(jī)制的完整代碼
在定義了我們的注意力層之后,我們便可以著手將注意力機(jī)制應(yīng)用到我們之前的任務(wù)之中,在這里我們以以前學(xué)習(xí)過的貓和狗分類為例子,添加我們的 Attention 機(jī)制,并且查看最終的結(jié)果:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
dataset_url = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_download = os.path.dirname(tf.keras.utils.get_file('cats_and_dogs.zip', origin=dataset_url, extract=True))
train_dataset_dir = path_download + '/cats_and_dogs_filtered/train'
valid_dataset_dir = path_download + '/cats_and_dogs_filtered/validation'
BATCH_SIZE = 64
TRAIN_NUM = 2000
VALID_NUM = 1000
EPOCHS = 15
Height = 128
Width = 128
train_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
valid_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
train_data_generator = train_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
directory=train_dataset_dir,
shuffle=True,
target_size=(Height, Width),
class_mode='binary')
valid_data_generator = valid_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
directory=valid_dataset_dir,
shuffle=True,
target_size=(Height, Width),
class_mode='binary')
class SELayer(tf.keras.Model):
def __init__(self, filters, reduction=16):
super(SELayer, self).__init__()
self.filters = filters
self.reduction = reduction
self.GAP = tf.keras.layers.GlobalAveragePooling2D()
self.FC = tf.keras.models.Sequential([
tf.keras.layers.Dense(units=self.filters // self.reduction, input_shape=(self.filters, )),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Dense(units=filters),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('sigmoid')
])
self.Multiply = tf.keras.layers.Multiply()
def call(self, inputs, training=None, mask=None):
x = self.GAP(inputs)
x = self.FC(x)
x = self.Multiply([x, inputs])
return x
def build_graph(self, input_shape):
input_shape_without_batch = input_shape[1:]
self.build(input_shape)
inputs = tf.keras.Input(shape=input_shape_without_batch)
_ = self.call(inputs)
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu',
input_shape=(Height, Width ,3)),
tf.keras.layers.MaxPooling2D(),
SELayer(16),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
tf.keras.layers.MaxPooling2D(),
SELayer(32),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
tf.keras.layers.MaxPooling2D(),
SELayer(64),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
model.summary()
history = model.fit_generator(
train_data_generator,
steps_per_epoch=TRAIN_NUM // BATCH_SIZE,
epochs=EPOCHS,
validation_data=valid_data_generator,
validation_steps=VALID_NUM // BATCH_SIZE)
acc = history.history['accuracy']
loss=history.history['loss']
val_acc = history.history['val_accuracy']
val_loss=history.history['val_loss']
epochs_ran = range(EPOCHS)
plt.plot(epochs_ran, acc, label='Train Acc')
plt.plot(epochs_ran, val_acc, label='Valid Acc')
plt.show()
plt.plot(epochs_ran, loss, label='Train Loss')
plt.plot(epochs_ran, val_loss, label='Valid Loss')
plt.show()
通過運行代碼,我們可以得到運行的結(jié)果:
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
Model: "sequential_7"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_4 (Conv2D) (None, 128, 128, 16) 448
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 64, 64, 16) 0
_________________________________________________________________
se_layer_3 (SELayer) (None, 64, 64, 16) 117
_________________________________________________________________
dropout_8 (Dropout) (None, 64, 64, 16) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 64, 64, 32) 4640
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 32, 32, 32) 0
_________________________________________________________________
se_layer_4 (SELayer) (None, 32, 32, 32) 298
_________________________________________________________________
dropout_11 (Dropout) (None, 32, 32, 32) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 32, 32, 64) 18496
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 16, 16, 64) 0
_________________________________________________________________
se_layer_5 (SELayer) (None, 16, 16, 64) 852
_________________________________________________________________
dropout_14 (Dropout) (None, 16, 16, 64) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 16384) 0
_________________________________________________________________
dropout_15 (Dropout) (None, 16384) 0
_________________________________________________________________
dense_14 (Dense) (None, 512) 8389120
_________________________________________________________________
dropout_16 (Dropout) (None, 512) 0
_________________________________________________________________
dense_15 (Dense) (None, 1) 513
=================================================================
Total params: 8,414,484
Trainable params: 8,414,246
Non-trainable params: 238
_________________________________________________________________
Epoch 1/15
31/31 [==============================] - 56s 2s/step - loss: 0.7094 - accuracy: 0.5114 - val_loss: 0.6931 - val_accuracy: 0.5310
Epoch 2/15
31/31 [==============================] - 48s 2s/step - loss: 0.6930 - accuracy: 0.4990 - val_loss: 0.6927 - val_accuracy: 0.5869
......
Epoch 14/15
31/31 [==============================] - 54s 2s/step - loss: 0.6174 - accuracy: 0.6348 - val_loss: 0.6309 - val_accuracy: 0.7240
Epoch 15/15
31/31 [==============================] - 47s 2s/step - loss: 0.6030 - accuracy: 0.6446 - val_loss: 0.6195 - val_accuracy: 0.7565
于是我們可以發(fā)現(xiàn),我們的模型最終達(dá)到了 75% 的準(zhǔn)確率,大家可以和之前的模型的結(jié)果做一個比較。
同時大家也可以根據(jù)自己對 CNN 和 MaxPooling 的理解來調(diào)整模型以及相應(yīng)的參數(shù),從而達(dá)到一個更好的效果。
5. 小結(jié)
通過這節(jié)課的學(xué)習(xí),我們了解了什么是注意力機(jī)制,并且了解了注意力的分類(空間、時間、通道、混合),并且手動實現(xiàn)了一個通道域的注意力機(jī)制,并且最后進(jìn)行了實現(xiàn)。