1 回答

TA貢獻(xiàn)1831條經(jīng)驗(yàn) 獲得超10個(gè)贊
您必須像訓(xùn)練時(shí)一樣創(chuàng)建輸入函數(shù)。
代碼:
def my_input_fn(features, batch_size=256):
"""An input function for prediction."""
# Convert the inputs to a Dataset without labels.
return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)
testX = pd.DataFrame(trainX.loc[0]).T
predictions = est.predict(
input_fn=lambda: my_input_fn(testX))
預(yù)測(cè)將為您提供一個(gè)生成器對(duì)象。你必須迭代它才能獲得預(yù)測(cè)
for pred_dict in predictions:
class_id = pred_dict['class_ids'][0]
probability = pred_dict['probabilities'][class_id]
print(class_id, probability)
class_id是預(yù)測(cè)的ID。
請(qǐng)注意,pred_dict 還包含其他信息。
以下是 pred_dict 中包含的信息:
{'all_class_ids': array([0, 1]),
'all_classes': array([b'0', b'1'], dtype=object),
'class_ids': array([0], dtype=int64),
'classes': array([b'0'], dtype=object),
'logistic': array([0.17926924], dtype=float32),
'logits': array([-1.5213063], dtype=float32),
'probabilities': array([0.82073075, 0.17926925], dtype=float32)}
添加回答
舉報(bào)