1 回答

TA貢獻(xiàn)1848條經(jīng)驗(yàn) 獲得超2個(gè)贊
這里有2個(gè)錯(cuò)誤:
如果您想重用
zip
對(duì)象,請(qǐng)創(chuàng)建一個(gè)列表。該物體在使用一次后就會(huì)耗盡。你可以這樣修復(fù)它:
custom_cv?=?[*zip(train_index_list,?test_index_list)]
交叉驗(yàn)證列表
cross_val_predict
應(yīng)該是實(shí)際數(shù)組的分區(qū)(每個(gè)樣本應(yīng)該只屬于一個(gè)測(cè)試集)。就你而言,事實(shí)并非如此。如果您考慮一下,堆疊交叉驗(yàn)證列表的輸出將產(chǎn)生長度為6 的數(shù)組,而原始y的長度為 5。您可以像這樣實(shí)現(xiàn)自定義交叉驗(yàn)證預(yù)測(cè):
def custom_cross_val_predict(clf, X, y, cv):
? ? y_pred, y_true = [], []
? ? for tr_idx, vl_idx in cv:
? ? ? ? X_tr, y_tr = X[tr_idx], y[tr_idx]
? ? ? ? X_vl, y_vl = X[vl_idx], y[vl_idx]
? ? ? ? clf.fit(X_tr, y_tr)
? ? ? ? y_true.extend(y_vl)
? ? ? ? y_pred.extend(clf.predict(X_vl))
? ? ? ??
? ? return y_true, y_pred
labels, predicted = custom_cross_val_predict(clf,X,y,cv=custom_cv)
print('Confusion matrix:',confusion_matrix(labels, predicted))
添加回答
舉報(bào)