2 回答

TA貢獻(xiàn)1816條經(jīng)驗 獲得超6個贊
這個表達(dá)式的意思是:對一個softmax_output形狀數(shù)組進(jìn)行切片,(N, C)從中只提取與訓(xùn)練標(biāo)簽相關(guān)的值y。
二維numpy.array可以用包含適當(dāng)值的兩個列表進(jìn)行切片(即它們不應(yīng)導(dǎo)致索引錯誤)
range(num_train)為第一個軸創(chuàng)建一個索引,允許使用第二個索引 - 選擇每行中的特定值list(y)。你可以在numpy 的 indexing 文檔中找到它。
第一個索引 range_num 的長度等于softmax_output(= N)的第一個維度。它指向矩陣的每一行;然后對于每一行,它通過索引的第二部分中的相應(yīng)值選擇目標(biāo)值 - list(y)。
例子:
softmax_output = np.array( # dummy values, not softmax
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]]
)
num_train = 4 # length of the array
y = [2, 1, 0, 2] # a labels; values for indexing along the second axis
softmax_output[range(num_train), list(y)]
Out:
[3, 5, 7, 12]
因此,它從第一行中選擇第三個元素,從第二行中選擇第二個元素,等等。這就是它的工作原理。
(ps 我誤解了你,你對“為什么”感興趣,而不是“如何”?)

TA貢獻(xiàn)1801條經(jīng)驗 獲得超16個贊
這里的損失由以下等式定義
這里,對于數(shù)據(jù)點(diǎn)所屬的類,y 為 1,對于所有其他類,y 為 0。因此,我們只對數(shù)據(jù)點(diǎn)類的 softmax 輸出感興趣。因此上面的方程可以改寫為
因此,下面的代碼表示上述等式。
loss = -np.sum(np.log(softmax_output[range(num_train), list(y)]))
該代碼softmax_output[range(num_train), list(y)]
用于為各個類選擇 softmax 輸出。range(num_train)
代表所有訓(xùn)練樣本并list(y)
代表各自的類別。
Mikhail 在他的回答中很好地解釋了這種索引。
添加回答
舉報