使用 Keras 進(jìn)行文本分類
上節(jié)課我們學(xué)習(xí)了如何進(jìn)行圖片分類,在此過程之中我們學(xué)習(xí)到了如何對(duì)圖片數(shù)據(jù)進(jìn)行處理;而對(duì)于文本數(shù)據(jù)我們應(yīng)該如何處理與訓(xùn)練呢?與圖片數(shù)據(jù)相比,文本數(shù)據(jù)有以下幾個(gè)特點(diǎn):
- 長度不確定;
- 語言之間的差異較大,編碼方式各不相同;
- 同一種語言的處理方式也不盡相同;
- 特征提取方式不統(tǒng)一。
因?yàn)槲谋緮?shù)據(jù)的不確定性,因此我們這節(jié)課采用最常用的數(shù)據(jù)處理方式(單詞嵌入)與最常用的文本分類數(shù)據(jù)集( IMBD? 評(píng)價(jià)數(shù)據(jù)集)。
1. 數(shù)據(jù)集合概覽
IMDB? 數(shù)據(jù)集合一共包含 50000 條數(shù)據(jù),每條數(shù)據(jù)都是從 IMDB? 電影的評(píng)價(jià)中選取,同時(shí)每個(gè)評(píng)論都被歸類為**“正面評(píng)價(jià)”或“負(fù)面評(píng)價(jià)”**。比如:
x: [1, 778, 128, 74, 12, 630, 163, 15, 4, 1766, 7982, 1051, 2, 32, 85, 156, 45, 40, 148, 139, 121, 664, 665, 10, 10, 1361, 173, 4, 749, 2, 16, 3804, 8, 4, 226, 65, 12, 43, 127, 24, 2, 10, 10]
y: 0
其中評(píng)論是被編碼之后所得到的數(shù)組,每個(gè)英文單詞對(duì)應(yīng)一個(gè)固定的數(shù)字。而標(biāo)簽用 0 和 1 來表示“負(fù)面評(píng)價(jià)”和“證明評(píng)價(jià)”。
將上述例子還原一下就是:
x: "begins better than it ends funny that the russian submarine crew <UNK> all other actors it's like those scenes where documentary shots br br spoiler part the message <UNK> was contrary to the whole story it just does not <UNK> br br"
y: "Negative"
這 50000 條數(shù)據(jù)它們具體的分布如下:
- 訓(xùn)練集包含 25000 條訓(xùn)練數(shù)據(jù),其中正負(fù)數(shù)據(jù)各 12500 條;
- 測(cè)試集包含 25000 條測(cè)試數(shù)據(jù),其中正負(fù)數(shù)據(jù)各 12500 條。
換句話說,該數(shù)據(jù)集合上面的數(shù)據(jù)是**“平衡的”**,因?yàn)樗恼龢颖九c負(fù)樣本的數(shù)目相同。
在 TensorFlow 之中,我們可以直接通過調(diào)用內(nèi)部 API 的方式來獲取該數(shù)據(jù)集:
(train_data, train_labels), (test_data, test_labels) = \
tf.keras.datasets.imdb.load_data(num_words=words_num)
2. 如何對(duì)文本數(shù)據(jù)進(jìn)行處理
在機(jī)器學(xué)習(xí)之中,我們對(duì)于文本數(shù)據(jù)的處理大致分為以下幾步:
- 數(shù)據(jù)清洗,清理掉無用的數(shù)據(jù);
- 文本編碼,將每一個(gè)單詞轉(zhuǎn)化為一個(gè)數(shù)字來表示;
- 將編碼后的文本轉(zhuǎn)化為定長表示;
- 將文本提取為特征向量進(jìn)行下一步的訓(xùn)練。
其中在這個(gè)例子之中,我們加載的數(shù)據(jù)集合已經(jīng)由 TensorFlow 進(jìn)行過數(shù)據(jù)清洗與文本編碼了,因此我們只需要將其轉(zhuǎn)化為定長表示并且提取其特征向量即可。
2.1 如何將文本數(shù)組填充到定長
在 TensorFlow 之中我們可以采用預(yù)處理的方式來將編碼后的文本轉(zhuǎn)化為定長:
train_data = tf.keras.preprocessing.sequence.pad_sequences(
train_data,
value=0,
padding='post',
maxlen=10
)
其中的各個(gè)參數(shù)的解釋如下:
- trian_data:我們要處理的、編碼后的數(shù)據(jù);
- maxlen:將每個(gè)文本樣本處理后的長度,如果原長度不足 maxlen ,那么便會(huì)使用 value 進(jìn)行填充;如果原長度超過了 maxlen ,那么便會(huì)將文本截?cái)啵?/li>
- value:用來填充文本的數(shù)字,一般我們使用0即可;
- padding:填充的模式,post 表示填充的 value 位置在原文之后。
我們舉個(gè)簡單的例子,如果處理前的文本數(shù)組為:
[1, 2, 3]
當(dāng)我們使用上述方式填充之后的數(shù)據(jù)就會(huì)變?yōu)椋?/p>
[1, 2, 3, 0, 0 ,0, 0, 0, 0, 0]
2.2 如何將文本數(shù)組進(jìn)行嵌入并提取特征向向量
在 TensorFlow 之中,我們最常用的提取文本特征的網(wǎng)絡(luò)層是:
tf.keras.layers.Embedding(vocab_size, dim),
其中 vocab_size 表示的是詞匯量的總數(shù),dim 表示特征向量的維度。
通過輸入編碼后的文本數(shù)組,我們可以得到該文本的特征向量(embedding vector)。
3. 模型的完整表示
當(dāng)我們知道了如何對(duì)文本數(shù)據(jù)進(jìn)行處理之后,我們便可以編寫我們的文本分類模型的程序了。
具體的程序如下:
import tensorflow as tf
import numpy as np
# 定義基本參數(shù)
words_num = 10000
val_num = 12500
EPOCHS = 30
pad_max_length = 256
BATCH_SIZE = 64
# 獲取數(shù)據(jù)
(train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.imdb.load_data(num_words=words_num)
word_index = tf.keras.datasets.imdb.get_word_index()
# 添加特殊字符
word_index = {k:(v+3) for k,v in word_index.items()}
word_index["<pad>"] = 0
word_index["<start>"] = 1
word_index["<unknown>"] = 2
word_index["<unused>"] = 3
# 數(shù)據(jù)預(yù)處理
train_data = tf.keras.preprocessing.sequence.pad_sequences(train_data, value=0, padding='post', maxlen=pad_max_length)
test_data = tf.keras.preprocessing.sequence.pad_sequences(test_data, value=0, padding='post', maxlen=pad_max_length)
# 劃分訓(xùn)練集合與驗(yàn)證集合
x_val, x_train = train_data[:val_num], train_data[val_num:]
y_val, y_train = train_labels[:val_num], train_labels[val_num:]
# 模型構(gòu)建
model = tf.keras.Sequential([
tf.keras.layers.Embedding(words_num, 32),
tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.summary()
# 編譯模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 訓(xùn)練
history = model.fit(x_train, y_train, epochs=EPOCHS,
batch_size=BATCH_SIZE, validation_data=(x_val, y_val))
# 測(cè)試
results = model.evaluate(test_data, test_labels)
print(results)
在該程序之中有幾個(gè)需要注意的地方:
- 在添加特殊字符字符處我們添加了四個(gè)特殊字符,其中
- 0 表示填充所使用的字符;
- 1 表示句子的開始;
- 2 表示未知單詞,因?yàn)槲覀円?guī)定只使用 10000 個(gè)最常用的單詞;
- 3 表示未使用的單詞。
- 在劃分驗(yàn)證集合的時(shí)候,我們按照 50% 的比例劃分訓(xùn)練集合與驗(yàn)證集合;
- 在模型的第二層,我們采用了一維全局池化,該層沒有可訓(xùn)練的參數(shù),該層是為了降低訓(xùn)練所需要數(shù)據(jù)量,輸出是一個(gè)固定長度的向量;
- 模型的最后一層的激活函數(shù)為 “Sigmoid” ,這個(gè)激活函數(shù)將輸出分為 0 或者 1 ,通常用于二分類的任務(wù)。
- 在編譯過程之中我們采用了**“二元交叉熵”(binary_crossentropy)**的損失函數(shù),該損失函數(shù)通常用作二元分類問題
- 因?yàn)樵跀?shù)據(jù)處理過程中我們沒有劃分 Batch ,因此我們要在訓(xùn)練(fit)的過程之中來定義 Batch_Size。
4. 程序的結(jié)果
運(yùn)行上面的程序,我們可以得到如下的輸出:
Model: "sequential_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding_4 (Embedding) (None, None, 32) 320000
_________________________________________________________________
global_average_pooling1d_3 ( (None, 32) 0
_________________________________________________________________
dense_8 (Dense) (None, 64) 2112
_________________________________________________________________
dense_9 (Dense) (None, 1) 65
=================================================================
Total params: 322,177
Trainable params: 322,177
Non-trainable params: 0
_________________________________________________________________
Epoch 1/30
196/196 [==============================] - 2s 10ms/step - loss: 0.6428 - accuracy: 0.6598 - val_loss: 0.5054 - val_accuracy: 0.8246
Epoch 2/30
196/196 [==============================] - 2s 10ms/step - loss: 0.3655 - accuracy: 0.8654 - val_loss: 0.3217 - val_accuracy: 0.8741
Epoch 3/30
196/196 [==============================] - 2s 10ms/step - loss: 0.2429 - accuracy: 0.9084 - val_loss: 0.2956 - val_accuracy: 0.8763
Epoch 4/30
196/196 [==============================] - 2s 10ms/step - loss: 0.1869 - accuracy: 0.9322 - val_loss: 0.2870 - val_accuracy: 0.8842
Epoch 5/30
196/196 [==============================] - 2s 10ms/step - loss: 0.1468 - accuracy: 0.9498 - val_loss: 0.2978 - val_accuracy: 0.8820
Epoch 6/30
196/196 [==============================] - 2s 10ms/step - loss: 0.1167 - accuracy: 0.9622 - val_loss: 0.3121 - val_accuracy: 0.8835
Epoch 7/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0915 - accuracy: 0.9737 - val_loss: 0.3375 - val_accuracy: 0.8786
Epoch 8/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0720 - accuracy: 0.9805 - val_loss: 0.3668 - val_accuracy: 0.8784
Epoch 9/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0558 - accuracy: 0.9870 - val_loss: 0.3917 - val_accuracy: 0.8747
Epoch 10/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0437 - accuracy: 0.9924 - val_loss: 0.4241 - val_accuracy: 0.8729
Epoch 11/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0345 - accuracy: 0.9946 - val_loss: 0.4539 - val_accuracy: 0.8696
Epoch 12/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0272 - accuracy: 0.9956 - val_loss: 0.4948 - val_accuracy: 0.8703
Epoch 13/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0201 - accuracy: 0.9974 - val_loss: 0.5199 - val_accuracy: 0.8679
Epoch 14/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0150 - accuracy: 0.9984 - val_loss: 0.5517 - val_accuracy: 0.8662
Epoch 15/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0122 - accuracy: 0.9987 - val_loss: 0.5818 - val_accuracy: 0.8646
Epoch 16/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0098 - accuracy: 0.9991 - val_loss: 0.6114 - val_accuracy: 0.8642
Epoch 17/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0080 - accuracy: 0.9993 - val_loss: 0.6514 - val_accuracy: 0.8632
Epoch 18/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0063 - accuracy: 0.9996 - val_loss: 0.6680 - val_accuracy: 0.8621
Epoch 19/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0047 - accuracy: 0.9997 - val_loss: 0.6967 - val_accuracy: 0.8620
Epoch 20/30
196/196 [==============================] - 2s 11ms/step - loss: 0.0039 - accuracy: 0.9998 - val_loss: 0.7308 - val_accuracy: 0.8611
Epoch 21/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0027 - accuracy: 1.0000 - val_loss: 0.7511 - val_accuracy: 0.8608
Epoch 22/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0023 - accuracy: 0.9999 - val_loss: 0.7780 - val_accuracy: 0.8601
Epoch 23/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0018 - accuracy: 1.0000 - val_loss: 0.8057 - val_accuracy: 0.8590
Epoch 24/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0016 - accuracy: 0.9999 - val_loss: 0.8214 - val_accuracy: 0.8606
Epoch 25/30
196/196 [==============================] - 2s 11ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 0.8376 - val_accuracy: 0.8602
Epoch 26/30
196/196 [==============================] - 2s 11ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 0.8689 - val_accuracy: 0.8592
Epoch 27/30
196/196 [==============================] - 2s 12ms/step - loss: 8.3966e-04 - accuracy: 1.0000 - val_loss: 0.8716 - val_accuracy: 0.8592
Epoch 28/30
196/196 [==============================] - 2s 10ms/step - loss: 7.2445e-04 - accuracy: 1.0000 - val_loss: 0.8918 - val_accuracy: 0.8588
Epoch 29/30
196/196 [==============================] - 2s 12ms/step - loss: 6.1936e-04 - accuracy: 1.0000 - val_loss: 0.9143 - val_accuracy: 0.8591
Epoch 30/30
196/196 [==============================] - 2s 10ms/step - loss: 5.2330e-04 - accuracy: 1.0000 - val_loss: 0.9336 - val_accuracy: 0.8596
782/782 [==============================] - 1s 2ms/step - loss: 0.9893 - accuracy: 0.8468
[0.9892528653144836, 0.8467599749565125]
由此可以看到,我們的網(wǎng)絡(luò)最終在測(cè)試集合上達(dá)到了 84.68% 的準(zhǔn)確率,同時(shí)它的損失為 0.9893 。
5. 小結(jié)
在這節(jié)課之中,我們學(xué)會(huì)了如何在機(jī)器學(xué)習(xí)之中處理文本數(shù)據(jù),同時(shí)了解了對(duì)文本進(jìn)行分類的基本步驟。
通過自己的動(dòng)手實(shí)現(xiàn),我們實(shí)現(xiàn)了一個(gè)分類準(zhǔn)確率接近 85% 的文本分類器。