1 回答

TA貢獻(xiàn)1821條經(jīng)驗(yàn) 獲得超6個(gè)贊
我認(rèn)為使用它glob2來獲取所有文件名,根據(jù)需要處理它們,然后創(chuàng)建一個(gè)簡單的加載函數(shù)來替換image_dataset_from_directory.
獲取您的所有文件:
files = glob2.glob('class_*\\*.jpg')
然后根據(jù)需要操作該文件名列表。
然后,創(chuàng)建一個(gè)加載圖像的函數(shù):
def load(file_path):
img = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.image.resize(img, size=(299, 299))
label = tf.strings.split(file_path, os.sep)[0]
label = tf.cast(tf.equal(label, 'class_a'), tf.int32)
return img, label
然后創(chuàng)建用于訓(xùn)練的數(shù)據(jù)集:
train_ds = tf.data.Dataset.from_tensor_slices(files).map(load).batch(4)
然后訓(xùn)練:
model.fit(train_ds)
- 1 回答
- 0 關(guān)注
- 152 瀏覽
添加回答
舉報(bào)