1 回答
TA貢獻1798條經(jīng)驗 獲得超7個贊
定義一個函數(shù),它返回一個鍵:每行的值字典,鍵是標(biāo)簽,值是基于閾值的 1/0
def get_label_score_dict(row, threshold):
result_dict = dict()
for _label, _score in zip(row['labels'], row['scores']):
if _score > threshold:
result_dict.update({_label: 1})
else:
result_dict.update({_label: 0})
return result_dict
現(xiàn)在,如果您有一個list_of_rows,其中每一行都采用如上所示的形式,那么您可以使用map函數(shù)為每一行獲取上述字典。一旦你得到它,將它轉(zhuǎn)換成一個 DataFrame。
th = 0.5 #whatever threshold value you want
result = list(map(lambda x: get_label_score_dict(x, th), list_of_rows))
result_df = pd.DataFrame(result)
添加回答
舉報
