我想用tensorflow重寫pytorch的torch.nn.functional.unfold函數(shù):#input x:[16, 1, 50, 36]x = torch.nn.functional.unfold(x, kernel_size=(5, 36), stride=3)#output x:[16, 180, 16]我嘗試使用該功能tf.extract_image_patches():x = tf.extract_image_patches(x,ksizes=[1, 1,5, 98],strides=[1, 1, 3, 1], rates=[1, 1, 1, 1],padding='VALID')輸入x.shape:[16,1,64,98]我得到輸出x.shape:[16,1,20,490]然后我將 重塑X為[16,490,20],這正是我所期望的。但是當(dāng)我輸入數(shù)據(jù)時(shí)出現(xiàn)錯(cuò)誤:UnimplementedError (see above for traceback): Only support ksizes across space.[[Node:hcn/ExtractImagePatches = ExtractImagePatches[T=DT_FLOAT, ksizes=[1, 1, 5, 98], padding="VALID", rates=[1, 1, 1, 1], strides=[1, 1, 3, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](hcn/Reshape)]]我如何使用tensorflow重寫pytorchtorch.nn.functional.unfold函數(shù)來(lái)更改X?
1 回答

小怪獸愛(ài)吃肉
TA貢獻(xiàn)1852條經(jīng)驗(yàn) 獲得超1個(gè)贊
x = tf.reshape(x, [16, 50, 36, 1]) x = tf.extract_image_patches(x, ksizes=[1, 4, 98, 1], strides=[1, 4, 1, 1], rates=[1, 1, 1, 1], padding='VALID')
添加回答
舉報(bào)
0/150
提交
取消