1 回答

TA貢獻(xiàn)1798條經(jīng)驗(yàn) 獲得超7個(gè)贊
定義一個(gè)函數(shù),它返回一個(gè)鍵:每行的值字典,鍵是標(biāo)簽,值是基于閾值的 1/0
def get_label_score_dict(row, threshold):
result_dict = dict()
for _label, _score in zip(row['labels'], row['scores']):
if _score > threshold:
result_dict.update({_label: 1})
else:
result_dict.update({_label: 0})
return result_dict
現(xiàn)在,如果您有一個(gè)list_of_rows,其中每一行都采用如上所示的形式,那么您可以使用map函數(shù)為每一行獲取上述字典。一旦你得到它,將它轉(zhuǎn)換成一個(gè) DataFrame。
th = 0.5 #whatever threshold value you want
result = list(map(lambda x: get_label_score_dict(x, th), list_of_rows))
result_df = pd.DataFrame(result)
添加回答
舉報(bào)