1 回答

TA貢獻(xiàn)1805條經(jīng)驗(yàn) 獲得超9個(gè)贊
您可以在 numba jitted 函數(shù)中使用 np.prod:
n = 3
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
flat = np.ravel(arr).tolist()
gen = [list(a) for a in product(flat, repeat=n)]
@jit(nopython=True, parallel=True)
def mtp(gen):
results = np.empty(len(gen))
for i in prange(len(gen)):
results[i] = np.prod(gen[i])
return results
或者,您可以使用如下所示的reduce(感謝@stuartarchibald指出這一點(diǎn)),盡管并行化在下面不起作用(至少?gòu)膎umba 0.48開(kāi)始):
import numpy as np
from itertools import product
from functools import reduce
from operator import mul
from numba import njit, prange
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
n = 3
flat = np.ravel(arr).tolist()
gen = np.array([list(a) for a in product(flat, repeat=n)])
@njit
def mul_wrapper(x, y):
return mul(x, y)
@njit
def mtp(gen):
results = np.empty(gen.shape[0])
for i in prange(gen.shape[0]):
results[i] = reduce(mul_wrapper, gen[i], None)
return results
print(mtp(gen))
或者,因?yàn)镹umba內(nèi)部有一點(diǎn)魔力,可以發(fā)現(xiàn)將轉(zhuǎn)義函數(shù)并編譯它們的閉包。(再次感謝@stuartarchibald),你可以這樣,在下面:
@njit
def mtp(gen):
results = np.empty(gen.shape[0])
def op(x, y):
return mul(x, y)
for i in prange(gen.shape[0]):
results[i] = reduce(op, gen[i], None)
return results
但同樣,并行在numba 0.48之前在這里不起作用。
請(qǐng)注意,核心開(kāi)發(fā)團(tuán)隊(duì)成員推薦的方法是采用第一個(gè)使用 .它可以與并行標(biāo)志一起使用,并具有更直接的實(shí)現(xiàn)。np.prod
添加回答
舉報(bào)