1 回答

TA貢獻(xiàn)1852條經(jīng)驗(yàn) 獲得超7個(gè)贊
用作inGroupKFold
的參數(shù):cv
cross_val_predict()
scores = cross_val_score(model, X, y, groups, cv=GroupKFold())
請(qǐng)注意,groups
數(shù)組表示我們希望在同一訓(xùn)練/測(cè)試集中的數(shù)據(jù)組。它不是類標(biāo)簽數(shù)組。
例如:
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GroupKFold, cross_val_score
X, y = make_blobs(n_samples=15, random_state=0)
model = LogisticRegression()
groups = [0,0,0,1,1,1,1,2,2,2,2,3,3,3,3]
scores = cross_val_score(model, X, y, groups, cv=GroupKFold(n_splits=3))
print('cross val scores: {}'.format(scores))
- 1 回答
- 0 關(guān)注
- 99 瀏覽
添加回答
舉報(bào)