2 Answers

TA contributed 1834 experience points, received 8+ likes
For this exact use case, also consider
a * (b <= 0.5)
which appears to be the fastest of the following:
In [1]: import torch
...: a = torch.rand(3**10)
...: b = torch.rand(3**10)
In [2]: %timeit a[b > 0.5] = 0.
553 μs ± 17.2 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [3]: a = torch.rand(3**10)
In [4]: %timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)
...:
49 μs ± 391 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [5]: a = torch.rand(3**10)
In [6]: %timeit temp = (a * (b <= 0.5))
44 μs ± 381 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [7]: %timeit a.masked_fill_(b > 0.5, 0.)
244 μs ± 3.48 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
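Before picking one of these on speed alone, it is worth checking that they give the same result. Below is a minimal sketch, assuming a recent PyTorch where multiplying a float tensor by a bool tensor promotes to float. The names `indexed`, `where_result`, `mul_result`, `filled`, `c` and `m` are my own and only illustrative. It also shows one case where the multiplication trick behaves differently: NaN and inf entries are not zeroed, because NaN * 0 and inf * 0 are both NaN.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable
a = torch.rand(3**10)
b = torch.rand(3**10)
mask = b > 0.5

# 1. Boolean-index assignment (in place, so work on a copy)
indexed = a.clone()
indexed[mask] = 0.

# 2. torch.where (out of place)
where_result = torch.where(mask, torch.tensor(0.), a)

# 3. Multiplication by the inverted mask (out of place)
mul_result = a * (b <= 0.5)

# 4. masked_fill_ (in place, on a copy)
filled = a.clone().masked_fill_(mask, 0.)

# For finite inputs all four approaches agree exactly
all_equal = (
    torch.equal(indexed, where_result)
    and torch.equal(indexed, mul_result)
    and torch.equal(indexed, filled)
)
print("all four agree:", all_equal)

# Caveat: multiplication does not zero out NaN or inf entries
c = torch.tensor([float("nan"), float("inf"), 1.0])
m = torch.tensor([True, True, False])
mul_c = c * (~m)                                # nan * 0 and inf * 0 are nan
where_c = torch.where(m, torch.tensor(0.), c)   # masked entries become 0
print(mul_c, where_c)
```

So if `a` can contain NaN or inf, `torch.where` or `masked_fill_` is the safer choice even if the multiplication is marginally faster.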

TA contributed 1934 experience points, received 2+ likes
I think torch.where will be faster. Here are my measurements on CPU:
import torch
a = torch.rand(3**10)
b = torch.rand(3**10)
%timeit a[b > 0.5] = 0.
852 μs ± 30.2 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)
294 μs ± 4.51 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
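The absolute numbers in both answers depend on the machine and the PyTorch build, so it may be easier to rerun the comparison yourself. Here is a rough sketch that does this outside IPython with the standard-library `timeit` module. The helper function names, the `candidates` dict and the choice of 200 calls are my own assumptions, not part of the original measurements.

```python
import timeit
import torch

a = torch.rand(3**10)
b = torch.rand(3**10)
zero = torch.tensor(0.)

def index_assign():
    # In place: a stays modified between calls, but the work per call is the same
    a[b > 0.5] = 0.

def where_copy():
    return torch.where(b > 0.5, zero, a)

def multiply_mask():
    return a * (b <= 0.5)

def masked_fill_inplace():
    a.masked_fill_(b > 0.5, 0.)

candidates = {
    "index assign": index_assign,
    "torch.where": where_copy,
    "a * mask": multiply_mask,
    "masked_fill_": masked_fill_inplace,
}

number = 200  # calls per measurement; raise this for more stable numbers
results = {}
for name, fn in candidates.items():
    seconds = timeit.timeit(fn, number=number)
    results[name] = seconds / number * 1e6  # mean microseconds per call
    print(f"{name:>14}: {results[name]:8.1f} μs per call")
```

On a GPU the calls are asynchronous, so a measurement like this would also need `torch.cuda.synchronize()` around the timed region. The sketch above only covers the CPU case discussed here.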