2 回答

TA貢獻1796條經(jīng)驗 獲得超4個贊
這可能與 GPU 可以并行計算a @ b.t()
. 這意味著 GPU 實際上不必等待每個行列乘法計算完成即可計算下一個乘法。如果您檢查 CPU,您會發(fā)現(xiàn)它torch.diag(a @ b.t())
比torch.einsum('ij,ij->i',a,b)
大型a
和b
.

TA貢獻1865條經(jīng)驗 獲得超7個贊
我不能代表,但幾年前曾在一些細節(jié)上torch合作過。np.einsum然后它根據(jù)索引字符串構(gòu)造一個自定義迭代器,僅執(zhí)行必要的計算。從那時起,它以各種方式進行了重新設(shè)計,顯然將問題轉(zhuǎn)化為@可能的情況,從而利用了 BLAS(等)庫調(diào)用。
In [147]: a = np.arange(12).reshape(3,4)
In [148]: b = a
In [149]: np.einsum('ij,ij->i', a,b)
Out[149]: array([ 14, 126, 366])
我不能確定在這種情況下使用了什么方法。通過“j”求和,還可以通過以下方式完成:
In [150]: (a*b).sum(axis=1)
Out[150]: array([ 14, 126, 366])
正如您所注意到的,最簡單的方法dot創(chuàng)建一個更大的數(shù)組,我們可以從中拉出對角線:
In [151]: (a@b.T).shape
Out[151]: (3, 3)
但這不是正確的使用方法@。 通過提供高效的“批量”處理@進行擴展。np.dot所以i維度是批次一,也是j一dot。
In [152]: a[:,None,:]@b[:,:,None]
Out[152]:
array([[[ 14]],
[[126]],
[[366]]])
In [156]: (a[:,None,:]@b[:,:,None])[:,0,0]
Out[156]: array([ 14, 126, 366])
換句話說,它使用 (3,1,4) 和 (3,4,1) 生成 (3,1,1),在共享大小 4 維度上進行乘積之和。
一些采樣時間:
In [162]: timeit np.einsum('ij,ij->i', a,b)
7.07 μs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [163]: timeit (a*b).sum(axis=1)
9.89 μs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [164]: timeit np.diag(a@b.T)
10.6 μs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [165]: timeit (a[:,None,:]@b[:,:,None])[:,0,0]
5.18 μs ± 197 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
添加回答
舉報