1 回答

TA貢獻(xiàn)1805條經(jīng)驗(yàn) 獲得超9個贊
你可以得到3維指數(shù)的最大值從max_idx。中的值max_idx是最大值沿軸 1 的索引。有六個值,因?yàn)槟钠渌S是 3 和 2 (3 x 2 = 6)。您只需要了解 numpy 通過它們獲取其他每個軸的索引的順序。您首先遍歷最后一個軸:
d0, d1, d2 = A.shape
a0 = [i for i in range(d0) for _ in range(d2)] # [0, 0, 1, 1, 2, 2]
a1 = max_idx.flatten() # [2, 2, 0, 2, 0, 1]
a2 = [k for _ in range(d0) for k in range(d2)] # [0, 1, 0, 1, 0, 1]
B[a0, a1, a2] = A[a0, a1, a2]
輸出:
array([[[0. , 0. ],
[0. , 0. ],
[0.94485653, 0.9264881 ]],
[[0.95446736, 0. ],
[0. , 0. ],
[0. , 0.36436023]],
[[0.56911013, 0. ],
[0. , 0.96278067],
[0. , 0. ]]])
添加回答
舉報(bào)