1 回答

TA貢獻(xiàn)1877條經(jīng)驗(yàn) 獲得超6個(gè)贊
模型需要一批/樣本列表。您可以通過(guò)在創(chuàng)建數(shù)據(jù)集時(shí)簡(jiǎn)單地設(shè)置批處理屬性來(lái)做到這一點(diǎn),如下所示:
ds = tf.data.Dataset.from_generator(data_generator, output_types=(tf.float32, tf.int32),
output_shapes=(tf.TensorShape([100, 3]), tf.TensorShape([5])))
ds = ds.batch(16)
您也可以在準(zhǔn)備樣品時(shí)采用另一種方式。這樣,您需要擴(kuò)展樣本維度,以便樣本充當(dāng)批次(您也可以傳遞樣本列表)并且您必須在output_shapes數(shù)據(jù)集和create_timeseries_element函數(shù)中進(jìn)行以下修改
def create_timeseries_element():
# returns a random time series of 100 intervals, each with 3 features,
# and a random one-hot array of 5 entries
# Expand dimensions to create a batch of single sample
data = np.expand_dims(np.random.rand(100, 3), axis=0)
label = np.expand_dims(np.eye(5, dtype='int')[np.random.choice(5)], axis=0)
return data, label
ds = tf.data.Dataset.from_generator(data_generator, output_types=(tf.float32, tf.int32), output_shapes=(tf.TensorShape([None, 100, 3]), tf.TensorShape([None, 5])))
上述更改將為數(shù)據(jù)集的每個(gè)時(shí)期僅提供一個(gè)批次(第一個(gè)解決方案的樣本)。您可以通過(guò)在定義數(shù)據(jù)集時(shí)將參數(shù)傳遞給data_generator函數(shù)來(lái)生成所需的批次(第一個(gè)解決方案的樣本)(例如 25 個(gè)),如下所示:
def data_generator(count=1):
for _ in range(count):
d, l = create_timeseries_element()
yield (d, l)
ds = tf.data.Dataset.from_generator(data_generator, args=[25], output_types=(tf.float32, tf.int32), output_shapes=(tf.TensorShape([None, 100, 3]), tf.TensorShape([None, 5])))
添加回答
舉報(bào)