2 Answers

I found the problem: if you get unexpected dimensions when using a TensorFlow dataset (tf.data.Dataset), it is probably because you did not call .batch().
So in my case:
features = convert_examples_to_tf_dataset(test_examples, tokenizer)
adding:
features = features.batch(BATCH_SIZE)
made it work as I expected. So this was not an issue with TFBertForSequenceClassification; my input was simply wrong. I also want to credit this answer, which led me to find the problem.
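A minimal sketch of the effect described above (assuming TensorFlow is installed; the toy tensor and shapes are illustrative, not the original data):

```python
import tensorflow as tf

# Toy dataset of 3 "encoded" examples, each a vector of length 4.
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((3, 4)))

# Without .batch(), iterating yields single examples of shape (4,);
# a model expecting a leading batch dimension will misread them.
unbatched_shape = next(iter(ds)).shape  # (4,)

# With .batch(), each element gains the leading batch dimension.
batched = ds.batch(2)
batched_shape = next(iter(batched)).shape  # (2, 4)
```

This is why calling `features = features.batch(BATCH_SIZE)` before `predict` restores the expected output shape.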

Here is my example, where I predict on 3 text samples and get (3, 42) as the output shape:
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, BertConfig, TFBertForSequenceClassification

### define model
config = BertConfig.from_pretrained(
    'bert-base-multilingual-cased',
    num_labels=42,
    output_hidden_states=False,
    output_attentions=False
)
model = TFBertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', config=config)

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-05, epsilon=1e-08)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
model.compile(optimizer=optimizer,
              loss=loss,
              metrics=[metric])

### import tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")

### utility function for text encoding
def return_id(str1, str2, length):
    inputs = tokenizer.encode_plus(str1, str2,
                                   add_special_tokens=True,
                                   max_length=length)

    input_ids = inputs["input_ids"]
    input_masks = [1] * len(input_ids)
    input_segments = inputs["token_type_ids"]

    # pad all three sequences up to the fixed length
    padding_length = length - len(input_ids)
    padding_id = tokenizer.pad_token_id
    input_ids = input_ids + ([padding_id] * padding_length)
    input_masks = input_masks + ([0] * padding_length)
    input_segments = input_segments + ([0] * padding_length)

    return [input_ids, input_masks, input_segments]

### encode 3 sentences
input_ids, input_masks, input_segments = [], [], []
for instance in ['hello hello', 'ciao ciao', 'marco marco']:
    ids, masks, segments = return_id(instance, None, 100)
    input_ids.append(ids)
    input_masks.append(masks)
    input_segments.append(segments)

input_ = [np.asarray(input_ids, dtype=np.int32),
          np.asarray(input_masks, dtype=np.int32),
          np.asarray(input_segments, dtype=np.int32)]

### make prediction
model.predict(input_).shape # ===> (3, 42)
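The (3, 42) output is one row of raw logits per sample (the model was compiled with from_logits=True). A small sketch of turning such logits into per-sample label predictions; the logits array here is fabricated for illustration, not real model output:

```python
import numpy as np

# Stand-in for the (3, 42) logits returned by model.predict above.
logits = np.zeros((3, 42))
logits[0, 7] = 5.0  # make class 7 the top score for the first sample

# Softmax over the label axis turns logits into probabilities;
# argmax picks the predicted label for each sample.
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
pred_labels = probs.argmax(axis=-1)  # shape (3,)
```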