使用圖像數(shù)據(jù)來訓(xùn)練模型
在之前的學(xué)習(xí)中,我們曾經(jīng)學(xué)習(xí)過使用 Keras 進(jìn)行圖片分類。具體來說,我們學(xué)習(xí)了:
- 將二位圖片數(shù)據(jù)進(jìn)行扁平化處理;
- 將圖片數(shù)據(jù)使用卷積神經(jīng)網(wǎng)絡(luò)進(jìn)行處理。
然而在實(shí)際的機(jī)器學(xué)習(xí)之中,當(dāng)我們使用圖片數(shù)據(jù)來訓(xùn)練模型的時(shí)候,我們會(huì)用到更多的操作。因此在這節(jié)課之中我們便整體地了解一下如何使用圖像數(shù)據(jù)來構(gòu)建數(shù)據(jù)集。
在實(shí)際的應(yīng)用過程中,我們最常用的圖片數(shù)據(jù)加載方式一共有三種,因此這節(jié)課我們主要學(xué)習(xí)這三種主要地圖片加載方式:
- 使用 TFRecord 構(gòu)建圖片數(shù)據(jù)集;
- 使用 tf.keras.preprocessing.image.ImageDataGenerator 構(gòu)建圖片數(shù)據(jù)集;
- 使用 tf.data.Dataset 原生方法構(gòu)建數(shù)據(jù)集。
在這節(jié)課之中,我們使用之前用過的貓狗分類的數(shù)據(jù)集之中的貓的訓(xùn)練集的圖片進(jìn)行測試,具體來說,我們可以通過以下代碼準(zhǔn)備具體的數(shù)據(jù)集:
import tensorflow as tf
import os
dataset_url = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_download = os.path.dirname(tf.keras.utils.get_file('cats_and_dogs.zip', origin=dataset_url, extract=True))
cat_train_dir = path_download + '/cats_and_dogs_filtered/train/cats'
這樣,cat_train_dir 就是我們要測試的圖片的路徑。
1. 使用TFRecord構(gòu)建圖片數(shù)據(jù)集
TFRecord 是一種二進(jìn)制的數(shù)據(jù)文件,也正是因?yàn)?TFRecord 是一種二進(jìn)制的數(shù)據(jù)文件,因此他的讀寫速度較快,同時(shí)也不會(huì)產(chǎn)生編碼錯(cuò)誤之類的問題。
使用 TFRecord 主要包括兩個(gè)步驟:
- 生成 TFRecord 文件并進(jìn)行存儲(chǔ);
- 讀取 TFRecord 文件,并用于訓(xùn)練。
1. 生成 TFRecord 文件并進(jìn)行存儲(chǔ)
既然我們已經(jīng)獲得了圖片文件所在的目錄,那么我們便可以生成 TFRecord 文件:
from PIL import Image
# 打開TFRecord文件
writer = tf.io.TFRecordWriter('./cat_data')
for img_path in os.listdir(cat_train_dir):
# 讀取并將圖片Resize
img = os.path.join(cat_train_dir, img_path)
img = Image.open(img)
img = img.convert('RGB').resize((32,32)).tobytes()
# 定義標(biāo)簽,假設(shè)貓的標(biāo)簽是0
label = 0 # 0:cat, 1:dog
# 構(gòu)建一條數(shù)據(jù)
example = tf.train.Example(
features = tf.train.Features(
feature = {
'label': tf.train.Feature(int64_list=tf.train.Int64List (value=[int(label)])),
'data' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[img]))
}
)
)
# 將數(shù)據(jù)寫入
writer.write(example.SerializeToString())
writer.close()
如上述代碼所示,我們首先需要打開 TFRecord 文件,然后再保存結(jié)束時(shí)再將其關(guān)閉。
其次我們首先使用讀取了圖片文件,然后將其進(jìn)行了以下處理:
- 轉(zhuǎn)化為 RGB 模式;
- Resize 到 (32,32 )大小;
- 轉(zhuǎn)化為二進(jìn)制字節(jié)數(shù)據(jù)。
最后我們使用 tf.train.Example 函數(shù)將每一條數(shù)據(jù)按照 label 和 data 的形式進(jìn)行封裝,并寫入到 TFRecord 文件之中。
2. 讀取 TFRecord 文件
在讀取的時(shí)候,我們會(huì)將 TFRecord 文件讀入到內(nèi)存之中,并且轉(zhuǎn)化為 tf.data.Dataset ,以便日后使用。
cat_reader = tf.data.TFRecordDataset('./cat_data')
def decode_image(example):
# 加載單條數(shù)據(jù)
single_example = tf.io.parse_single_example(
example,
{
'data' : tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64)
}
)
img = single_example['data']
label = single_example['label']
# 圖片處理
img = tf.io.decode_raw(img, tf.uint8)
img = tf.reshape(img, [32, 32, 3])
return (img, label)
# 映射并分批次
cat_dataset = cat_reader.map(decode_image).batch(32)
print(cat_dataset)
這其中有幾點(diǎn)需要注意:
- 首先我們需要根據(jù)存儲(chǔ)的路徑來載入 TFRecord ;
- 我們需要使用一個(gè)函數(shù)來處理每一條數(shù)據(jù),這個(gè)函數(shù)可以通過 cat_reader.map() 來調(diào)用;
- 在 decode_image 之中:
- tf.io.parse_single_example 函數(shù)用于加載每一條數(shù)據(jù),它接收兩個(gè)參數(shù),第一個(gè)是當(dāng)前數(shù)據(jù),第二個(gè)是數(shù)據(jù)的格式;
- 我們又采用了 tf.io.decode_raw 函數(shù)來對圖片進(jìn)行了解碼,將其轉(zhuǎn)化為數(shù)字類型。
- 最后我們將圖片數(shù)據(jù)分批次,大小為32 。
于是我們可以得到輸出為:
<BatchDataset shapes: ((None, 32, 32, 3), (None,)), types: (tf.uint8, tf.int64)>
由此可見,我們正確地加載了該數(shù)據(jù)集。
2.使用 tf.keras.preprocessing.image.ImageDataGenerator 構(gòu)建圖片數(shù)據(jù)集
使用這種方式會(huì)非常簡單,我們只需要一條語句即可實(shí)現(xiàn):
cat_generator = tf.keras.preprocessing.image.ImageDataGenerator().flow_from_directory(
directory=path_download + '/cats_and_dogs_filtered/train',
target_size=(32, 32),
batch_size=32,
shuffle = True,
class_mode='binary')
print(cat_generator)
我們可以得到如下輸出:
Found 2000 images belonging to 2 classes.
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x7f28d0c4a048>
在使用的過程中, directory 參數(shù)需要我們注意,該路徑應(yīng)該是圖片路徑之外的一層路徑。
也就是說,如果圖片路徑為“/a/b/c.jpg”,那么我們要傳入的路徑應(yīng)該是“/a”。
其余的參數(shù)為:
- target_size: 圖片的大??;
- batch_size: 批次大小;
- shuffle: 是否亂序;
- class_modle: 若是binary則為二分類,multi則為多分類。
由于我們得到的數(shù)據(jù)集是一個(gè)迭代器,因此我們不能使用常用的 fit 方式來訓(xùn)練,我們可以通過以下方式進(jìn)行訓(xùn)練:
model.fit_generator(cat_generator)
3. 使用 tf.data.Dataset 原生方法構(gòu)建數(shù)據(jù)集
使用這種方法也非常簡單,我們需要兩個(gè)步驟來進(jìn)行數(shù)據(jù)集的構(gòu)建:
- 定義圖片加載函數(shù);
- 使用 tf.data.Dataset 構(gòu)建數(shù)據(jù)集。
于是我們可以使用如下代碼進(jìn)行數(shù)據(jù)集的構(gòu)建:
def load_image(img_path):
label = tf.constant(0,tf.int8)
img = tf.io.read_file(img_path)
img = tf.image.decode_jpeg(img)
img = tf.image.resize(img, (32, 32))
return (img,label)
cat_dataset = tf.data.Dataset.list_files(cat_train_dir).map(load_image).batch(32)
print(cat_dataset)
在這段程序中,我們首先在載入圖片函數(shù)中進(jìn)行了如下處理:
- 定義標(biāo)簽,因?yàn)槿渴秦垼虼宋覀冊O(shè)置為 0 ;
- 使用 tf.io.read_file 讀取文件;
- 因?yàn)槲覀兊膱D片都是 jpeg 格式,因此我們使用 tf.image.decode_jpeg 來解碼圖片;
- 最后使用 tf.image.resize 來對圖片進(jìn)行尺寸調(diào)整,統(tǒng)一為(32, 32)。
然后我們使用 tf.data.Dataset.list_files() 函數(shù)構(gòu)建了數(shù)據(jù)集,它接收的第一個(gè)參數(shù)就是圖片所在的文件夾。
我們可以得到輸出:
<BatchDataset shapes: ((None, 32, 32, None), (None,)), types: (tf.float32, tf.int8)>
可見我們已經(jīng)成功地構(gòu)建了數(shù)據(jù)集。
4. 小結(jié)
在這節(jié)課之中,我們學(xué)習(xí)了三種圖片數(shù)據(jù)加載的方式,他們分別是:
- 使用 TFRecord 構(gòu)建圖片數(shù)據(jù)集;
- 使用 tf.keras.preprocessing.image.ImageDataGenerator 構(gòu)建圖片數(shù)據(jù)集;
- 使用 tf.data.Dataset 原生方法構(gòu)建數(shù)據(jù)集。
其中第一種方式最為快速,而第二種方式更為方便,我們可以根據(jù)自己的實(shí)際需求來進(jìn)行選擇。