我有以下簡單的例子:import tensorflow as tftensor1 = tf.constant(value = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])tensor2 = tf.constant(value = [20, 21, 22, 23])print(tensor1.shape)print(tensor2.shape)dataset = tf.data.Dataset.from_tensor_slices((tensor1, tensor2))print('Original dataset')for i in dataset: print(i)dataset = dataset.repeat(3)print('Repeated dataset')for i in dataset: print(i)如果我然后將其批處理dataset為:dataset = dataset.batch(3)print('Batched dataset')for i in dataset: print(i)正如預(yù)期的那樣,我收到:Batched dataset(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([20, 21, 22], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[10, 11, 12], [ 1, 2, 3], [ 4, 5, 6]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([23, 20, 21], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[ 7, 8, 9], [10, 11, 12], [ 1, 2, 3]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([22, 23, 20], dtype=int32)>)(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=array([[ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([21, 22, 23], dtype=int32)>)批處理數(shù)據(jù)集采用連續(xù)的元素。但是,當(dāng)我先進(jìn)行混音,然后進(jìn)行批處理時(shí):dataset = dataset.shuffle(3)print('Shuffled dataset')for i in dataset: print(i)dataset = dataset.batch(3)print('Batched dataset')for i in dataset: print(i)我正在使用 Google Colab 和TensorFlow 2.x.我的問題是:為什么在批處理之前進(jìn)行洗牌會(huì)導(dǎo)致batch返回非連續(xù)元素?感謝您的任何答復(fù)。
1 回答

12345678_0001
TA貢獻(xiàn)1802條經(jīng)驗(yàn) 獲得超5個(gè)贊
這就是洗牌的作用。你是這樣開始的:
[[1,?2,?3],?[4,?5,?6],?[7,?8,?9],?[10,?11,?12]]
您已指定,buffer_size=3
因此它會(huì)創(chuàng)建前 3 個(gè)元素的緩沖區(qū):
[[1,?2,?3],?[4,?5,?6],?[7,?8,?9]]
您指定了batch_size=3
,因此它將從此樣本中隨機(jī)選擇一個(gè)元素,并將其替換為初始緩沖區(qū)之外的第一個(gè)元素。假設(shè)[1, 2, 3]
已被選中,您的批次現(xiàn)在是:
[[1,?2,?3]]
現(xiàn)在你的緩沖區(qū)是:
[[10,?11,?12],?[4,?5,?6],?[7,?8,?9]]
對于 的第二個(gè)元素batch=3
,它將從此緩沖區(qū)中隨機(jī)選擇。假設(shè)[7, 8, 9]
已挑選,您的批次現(xiàn)在是:
[[1,?2,?3],?[7,?8,?9]]
現(xiàn)在你的緩沖區(qū)是:
[[10,?11,?12],?[4,?5,?6]]
沒有什么新內(nèi)容可以填充緩沖區(qū),因此它將隨機(jī)選擇這些元素之一,例如[10, 11, 12]
。您的批次現(xiàn)在是:
[[1,?2,?3],?[7,?8,?9],?[10,?11,?12]]
下一批將只是[4, 5, 6]
因?yàn)槟J(rèn)情況下,?batch(drop_remainder=False)
.
添加回答
舉報(bào)
0/150
提交
取消