1 Answer

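Changing dataset.batch(16) to dataset.padded_batch(16) resolves the error. batch requires every element in a batch to have the same shape, but the number of bounding boxes differs from image to image; padded_batch pads each component up to the largest shape in the batch.
Here is the same code with that change applied.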
import tensorflow as tf

@tf.function()
def prepare_sample(annotation):
    # Each annotation is "<file name> x1 y1 x2 y2 ...", separated by spaces.
    annotation_parts = tf.strings.split(annotation, sep=' ')
    image_file_name = annotation_parts[0]
    image_file_path = tf.strings.join(["/images/", image_file_name])
    # Reads the raw file contents; the file must exist under /images/.
    depth_image = tf.io.read_file(image_file_path)
    # One row of 4 values per bounding box; the row count varies per image.
    bboxes = tf.reshape(annotation_parts[1:], shape=[-1, 4])
    return depth_image, bboxes

annotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']
dataset = tf.data.Dataset.from_tensor_slices(annotations)
dataset = dataset.shuffle(len(annotations))
dataset = dataset.map(prepare_sample)
# padded_batch pads the variable-length bboxes so they can be stacked.
dataset = dataset.padded_batch(16)

for image, bboxes in dataset:
    pass
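To see what padded_batch does to the bounding boxes, here is a minimal sketch that leaves out the file reading so it runs without any images on disk. The helper name parse_boxes and the conversion to float32 with tf.strings.to_number are my own additions for illustration (the code above keeps the boxes as strings), and it assumes a TensorFlow version where padded_batch can infer the padded shapes, which I believe is 2.2 or later.

```python
import tensorflow as tf

# Same annotation format as above: file name, then 4 numbers per box.
annotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']

def parse_boxes(annotation):
    # Hypothetical helper: returns only the boxes, as floats.
    parts = tf.strings.split(annotation, sep=' ')
    boxes = tf.strings.to_number(parts[1:], out_type=tf.float32)
    return tf.reshape(boxes, [-1, 4])

# No shuffle here, so the element order is deterministic.
dataset = tf.data.Dataset.from_tensor_slices(annotations).map(parse_boxes)

# The first element has shape (1, 4), the second (2, 4).
# padded_batch pads both to (2, 4) before stacking them.
padded = next(iter(dataset.padded_batch(2)))
padded_shape = tuple(padded.shape)

print(padded_shape)
print(padded.numpy())
```

The first image has only one box, so its second row is filled with the default padding value (0 for numeric tensors, an empty string for string tensors as in the answer's code). If real boxes could also be all zeros, a different padding value can be passed through the padding_values argument of padded_batch.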