2 回答

TA貢獻1833條經(jīng)驗 獲得超4個贊
我瀏覽了torchtext源代碼以更好地了解sort_key在做什么,并了解了為什么我的原始想法不起作用。
我不確定這是否是最好的解決方案,但是我想出了一個可行的解決方案。我創(chuàng)建了一個tokenizer函數(shù),如果它比最長的過濾器長度短,則填充文本,然后從那里創(chuàng)建BucketIterator。
FILTER_SIZES = [3,4,5]
spacy_en = spacy.load('en')
def tokenizer(text):
token = [t.text for t in spacy_en.tokenizer(text)]
if len(token) < FILTER_SIZES[-1]:
for i in range(0, FILTER_SIZES[-1] - len(token)):
token.append('<PAD>')
return token
TEXT = Field(sequential=True, tokenize=tokenizer, lower=True, tensor_type=torch.cuda.LongTensor)
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)

TA貢獻1853條經(jīng)驗 獲得超6個贊
盡管@ paul41的方法有效,但還是有些濫用。這樣做的正確方法是使用preprocessing或postprocessing(相應(yīng)地在數(shù)字化之前或之后)。這是一個示例postprocessing:
def get_pad_to_min_len_fn(min_length):
def pad_to_min_len(batch, vocab, min_length=min_length):
pad_idx = vocab.stoi['<pad>']
for idx, ex in enumerate(batch):
if len(ex) < min_length:
batch[idx] = ex + [pad_idx] * (min_length - len(ex))
return batch
return pad_to_min_len
FILTER_SIZES = [3,4,5]
min_len_padding = get_pad_to_min_len_fn(min_length=max(FILTER_SIZES))
TEXT = Field(sequential=True, use_vocab=True, lower=True, batch_first=True,
postprocessing=min_len_padding)
如果在主循環(huán)中定義了嵌套函數(shù)(例如min_length = max(FILTER_SIZES)),則需要將參數(shù)傳遞給內(nèi)部函數(shù),但如果可行,則可以在函數(shù)內(nèi)部對參數(shù)進行硬編碼。
添加回答
舉報