2 Answers

TF can't do this on the fly while you're encoding. You should:
Export the bottlenecks from the original network to a file.
Train another network on your data, using the bottleneck results as input (a rough sketch of both steps is shown below).
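
A minimal sketch of that two-step workflow, assuming a tf.keras setup. The MobileNetV2 extractor, the .npy file names, and the 10-class head are illustrative placeholders, not something the answer specifies:

import numpy as np
import tensorflow as tf

# Step 1: run the original (frozen) network once and save its bottleneck outputs.
base = tf.keras.applications.MobileNetV2(include_top=False, pooling='avg',
                                         input_shape=(224, 224, 3))
base.trainable = False

images = np.random.rand(32, 224, 224, 3).astype('float32')  # stand-in for your images
labels = np.random.randint(0, 10, size=(32,))                # stand-in for your labels

bottlenecks = base.predict(images)                           # shape (32, 1280)
np.save('bottlenecks.npy', bottlenecks)
np.save('labels.npy', labels)

# Step 2: train a small second network that takes the saved bottlenecks as input.
clf = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu', input_shape=(bottlenecks.shape[1],)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
clf.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
clf.fit(np.load('bottlenecks.npy'), np.load('labels.npy'), epochs=5)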

Something like this should work (untested):
# Serialize the data into two tfrecord files
tf.enable_eager_execution()

feature_extractor = ...

features_file = tf.python_io.TFRecordWriter('features.tfrec')
label_file = tf.python_io.TFRecordWriter('labels.tfrec')

for images, labels in dataset:
    features = feature_extractor(images)
    # serialize_tensor returns an eager tensor; .numpy() gives the raw bytes to write
    features_file.write(tf.serialize_tensor(features).numpy())
    label_file.write(tf.serialize_tensor(labels).numpy())

features_file.close()
label_file.close()

# Parse the files and zip them together
def make_parser(dtype, shape):
    def parse(x):
        result = tf.parse_tensor(x, out_type=dtype)
        result = tf.reshape(result, shape)
        return result
    return parse

features_ds = tf.data.TFRecordDataset('features.tfrec')
features_ds = features_ds.map(make_parser(tf.float32, FEATURE_SHAPE), num_parallel_calls=AUTOTUNE)

labels_ds = tf.data.TFRecordDataset('labels.tfrec')
labels_ds = labels_ds.map(make_parser(tf.float32, LABEL_SHAPE), num_parallel_calls=AUTOTUNE)

ds = tf.data.Dataset.zip((features_ds, labels_ds))
ds = ds.unbatch().shuffle(SHUFFLE_BUFFER).repeat().batch(BATCH_SIZE).prefetch(AUTOTUNE)  # ...any further pipeline steps
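
The resulting ds yields (features, labels) batches, so with a recent TF it can be fed straight into training; for example, with a small Keras head (NUM_CLASSES, STEPS_PER_EPOCH and EPOCHS below are placeholders, not from the answer):

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(ds, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS)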
You could also do it with Dataset.cache, but I'm not 100% sure of the details.
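
For reference, a rough sketch of that Dataset.cache variant (untested; dataset, feature_extractor, SHUFFLE_BUFFER and the cache filename are placeholders): map the frozen extractor over the images once and cache the results, so later epochs read the cached features instead of re-running the extractor.

cached_ds = (dataset
             .map(lambda images, labels: (feature_extractor(images), labels),
                  num_parallel_calls=AUTOTUNE)
             .cache('bottlenecks.cache')  # with a filename the cache goes to disk; omit it to cache in memory
             .shuffle(SHUFFLE_BUFFER)
             .repeat()
             .prefetch(AUTOTUNE))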