2 回答

TA貢獻(xiàn)1876條經(jīng)驗(yàn) 獲得超6個(gè)贊
您需要進(jìn)行一些更改才能使其正常工作:
使用 : 為 Numba 添加維度
arr[:, :, None]
,看起來(lái)getitem
更喜歡使用reshape
使用
np.abs
而不是內(nèi)置abs
argmin
withaxis
關(guān)鍵字參數(shù)未實(shí)現(xiàn)。更喜歡使用 Numba 旨在優(yōu)化的循環(huán)。
修復(fù)所有這些后,您可以運(yùn)行 jited 函數(shù):
from numba import jit
import numpy as np
inflow = np.arange(1,0,-0.01) # Dim [T]
actions = np.arange(0,1,0.05) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]
@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
for i in range(0,100):
# Calculate new level at time i
new_lvl = start_lvl + inflow[i] + actions # Dim [N x M]
# For each new_level element, find closest discretized level
new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)
diff = np.abs(disc_lvl - new_lvl_3d) # Dim [N x M x O]
idx_lvl = np.empty(new_lvl.shape)
for i in range(diff.shape[0]):
for j in range(diff.shape[1]):
idx_lvl[i, j] = diff[i, j, :].argmin()
return True
# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)

TA貢獻(xiàn)1875條經(jīng)驗(yàn) 獲得超3個(gè)贊
在我的第一篇文章的更正代碼下方找到,您可以在使用和不使用 numba 庫(kù)的 jitted 模式的情況下執(zhí)行(通過(guò)刪除以 @jit 開(kāi)頭的行)。我觀察到這個(gè)例子的速度增加了 2 倍。
from numba import jit
import numpy as np
import datetime as dt
inflow = np.arange(1,0,-0.01) # Dim [T]
nbTime = np.shape(inflow)[0]
actions = np.arange(0,1,0.01) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]
@jit(nopython=True)
def my_func(nbTime, disc_lvl, actions, start_lvl, inflow):
# Initialize result
res = np.empty((nbTime,np.shape(start_lvl)[0],np.shape(actions)[0]))
for t in range(0,nbTime):
# Calculate new level at time t
new_lvl = start_lvl + inflow[t] + actions # Dim [N x M]
print(t)
# For each new_level element, find closest discretized level
new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)
diff = np.abs(disc_lvl - new_lvl_3d) # Dim [N x M x O]
idx_lvl = np.empty(new_lvl.shape)
for i in range(diff.shape[0]):
for j in range(diff.shape[1]):
idx_lvl[i, j] = diff[i, j, :].argmin()
res[t,:,:] = idx_lvl
return res
# Call function and print running time
start_time = dt.datetime.now()
result = my_func(nbTime, disc_lvl, actions, start_lvl, inflow)
print('Execution time :',(dt.datetime.now() - start_time))
添加回答
舉報(bào)