第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

為了賬號(hào)安全,請(qǐng)及時(shí)綁定郵箱和手機(jī)立即綁定
已解決430363個(gè)問(wèn)題,去搜搜看,總會(huì)有你想問(wèn)的

使用 TensorFlow 模型評(píng)估輸入的簡(jiǎn)單方法?

使用 TensorFlow 模型評(píng)估輸入的簡(jiǎn)單方法?

猛跑小豬 2023-07-11 16:18:41
在這里,我有一個(gè)使用生成的數(shù)據(jù)訓(xùn)練的增強(qiáng)決策樹,并保存為est:from sklearn.datasets import make_blobsimport pandas as pdimport tensorflow as tf#creates an input function for a tf modeldef make_input_fn(X, Y, n_epochs=None, shuffle=True, verbose=False):    batch_len = len(Y)    def input_fn():        dataset = tf.data.Dataset.from_tensor_slices((dict(X), Y))        if shuffle:            dataset = dataset.shuffle(batch_len)        # For training, cycle thru dataset as many times as need (n_epochs=None).        dataset = dataset.repeat(n_epochs)        #dividing data into batches        dataset = dataset.batch(batch_len)        return dataset    return input_fn#making datatrainX, trainY = make_blobs(n_samples=10, centers=2, n_features=3, random_state=0)#xValstrainX = pd.DataFrame(trainX)trainX.columns = ['feature{}'.format(num) for num in trainX.columns]#yValstrainY = pd.DataFrame(trainY)trainY.columns = ['flag']# Defining input functiontrain_input_fn = make_input_fn(trainX, trainY)#defining tf feature columnsfeature_columns=[]for feature_name in list(trainX.columns):    feature_columns.append(tf.feature_column.numeric_column(feature_name,dtype=tf.float32))    #creating the estimatorn_batches = 1est = tf.estimator.BoostedTreesClassifier(feature_columns, n_batches_per_layer=n_batches)est.train(train_input_fn, max_steps=10)我想使用該模型根據(jù)一行訓(xùn)練數(shù)據(jù)進(jìn)行預(yù)測(cè)以用于測(cè)試目的;像這樣的事情:res = est.predict(trainX.loc[0])但是,我很難弄清楚如何去做。
查看完整描述

1 回答

?
慕哥6287543

TA貢獻(xiàn)1831條經(jīng)驗(yàn) 獲得超10個(gè)贊

您必須像訓(xùn)練時(shí)一樣創(chuàng)建輸入函數(shù)。

代碼:


def my_input_fn(features, batch_size=256):

    """An input function for prediction."""

    # Convert the inputs to a Dataset without labels.

    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)


testX = pd.DataFrame(trainX.loc[0]).T


predictions = est.predict(

    input_fn=lambda: my_input_fn(testX))

預(yù)測(cè)將為您提供一個(gè)生成器對(duì)象。你必須迭代它才能獲得預(yù)測(cè)


for pred_dict in predictions:

    class_id = pred_dict['class_ids'][0]

    probability = pred_dict['probabilities'][class_id]

    print(class_id, probability)

class_id是預(yù)測(cè)的ID。


請(qǐng)注意,pred_dict 還包含其他信息。


以下是 pred_dict 中包含的信息:


{'all_class_ids': array([0, 1]),

 'all_classes': array([b'0', b'1'], dtype=object),

 'class_ids': array([0], dtype=int64),

 'classes': array([b'0'], dtype=object),

 'logistic': array([0.17926924], dtype=float32),

 'logits': array([-1.5213063], dtype=float32),

 'probabilities': array([0.82073075, 0.17926925], dtype=float32)}


查看完整回答
反對(duì) 回復(fù) 2023-07-11
  • 1 回答
  • 0 關(guān)注
  • 112 瀏覽
慕課專欄
更多

添加回答

舉報(bào)

0/150
提交
取消
微信客服

購(gòu)課補(bǔ)貼
聯(lián)系客服咨詢優(yōu)惠詳情

幫助反饋 APP下載

慕課網(wǎng)APP
您的移動(dòng)學(xué)習(xí)伙伴

公眾號(hào)

掃描二維碼
關(guān)注慕課網(wǎng)微信公眾號(hào)