第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

為了賬號(hào)安全,請(qǐng)及時(shí)綁定郵箱和手機(jī)立即綁定
已解決430363個(gè)問(wèn)題,去搜搜看,總會(huì)有你想問(wèn)的

Numba 和多維添加 - 不適用于 numpy.newaxis?

Numba 和多維添加 - 不適用于 numpy.newaxis?

慕斯王 2022-10-18 16:32:36
嘗試在 python 上加速 DP 算法,numba 似乎是一個(gè)合適的候選者。我正在用提供 3D 數(shù)組的 1D 數(shù)組減去 2D 數(shù)組。然后我使用.argmin()第三維來(lái)獲得一個(gè)二維數(shù)組。這適用于 numpy,但不適用于 numba。重現(xiàn)問(wèn)題的玩具代碼:from numba import jitimport numpy as npinflow      = np.arange(1,0,-0.01)                  # Dim [T]actions     = np.arange(0,1,0.05)                   # Dim [M]start_lvl   = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]disc_lvl    = np.arange(0,1000)                     # Dim [O]@jit(nopython=True)def my_func(disc_lvl, actions, start_lvl, inflow):    for i in range(0,100):        # Calculate new level at time i        new_lvl = start_lvl + inflow[i] + actions       # Dim [N x M]        # For each new_level element, find closest discretized level        diff    = (disc_lvl-new_lvl[:,:,np.newaxis])    # Dim [N x M x O]        idx_lvl = abs(diff).argmin(axis=2)              # Dim [N x M]        return True# function works fine without numbasuccess = my_func(disc_lvl, actions, start_lvl, inflow)為什么上面的代碼不運(yùn)行?取出時(shí)會(huì)這樣@jit(nopython=True)。是否有一個(gè)工作回合可以使以下計(jì)算與 numba 一起工作?我嘗試了帶有 numpy repeats 和 expand_dims 的變體,以及明確定義 jit 函數(shù)的輸入類型但沒(méi)有成功。
查看完整描述

2 回答

?
HUX布斯

TA貢獻(xiàn)1876條經(jīng)驗(yàn) 獲得超6個(gè)贊

您需要進(jìn)行一些更改才能使其正常工作:

  1. 使用 : 為 Numba 添加維度arr[:, :, None],看起來(lái)getitem更喜歡使用reshape

  2. 使用np.abs而不是內(nèi)置abs

  3. argminwithaxis關(guān)鍵字參數(shù)未實(shí)現(xiàn)。更喜歡使用 Numba 旨在優(yōu)化的循環(huán)。

修復(fù)所有這些后,您可以運(yùn)行 jited 函數(shù):

from numba import jit

import numpy as np


inflow = np.arange(1,0,-0.01)  # Dim [T]

actions = np.arange(0,1,0.05)  # Dim [M]

start_lvl = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]

disc_lvl = np.arange(0,1000)  # Dim [O]


@jit(nopython=True)

def my_func(disc_lvl, actions, start_lvl, inflow):

    for i in range(0,100):

        # Calculate new level at time i

        new_lvl = start_lvl + inflow[i] + actions  # Dim [N x M]


        # For each new_level element, find closest discretized level

        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)

        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]


        idx_lvl = np.empty(new_lvl.shape)

        for i in range(diff.shape[0]):

            for j in range(diff.shape[1]):

                idx_lvl[i, j] = diff[i, j, :].argmin()


        return True


# function works fine without numba

success = my_func(disc_lvl, actions, start_lvl, inflow)


查看完整回答
反對(duì) 回復(fù) 2022-10-18
?
翻過(guò)高山走不出你

TA貢獻(xiàn)1875條經(jīng)驗(yàn) 獲得超3個(gè)贊

在我的第一篇文章的更正代碼下方找到,您可以在使用和不使用 numba 庫(kù)的 jitted 模式的情況下執(zhí)行(通過(guò)刪除以 @jit 開(kāi)頭的行)。我觀察到這個(gè)例子的速度增加了 2 倍。


from numba import jit

import numpy as np

import datetime as dt


inflow = np.arange(1,0,-0.01)                       # Dim [T]

nbTime = np.shape(inflow)[0]

actions = np.arange(0,1,0.01)                       # Dim [M]

start_lvl = np.random.rand(500).reshape(-1,1)*49    # Dim [Nx1]

disc_lvl = np.arange(0,1000)                        # Dim [O]


@jit(nopython=True)

def my_func(nbTime, disc_lvl, actions, start_lvl, inflow):

    # Initialize result 

    res = np.empty((nbTime,np.shape(start_lvl)[0],np.shape(actions)[0]))


    for t in range(0,nbTime):

        # Calculate new level at time t

        new_lvl = start_lvl + inflow[t] + actions  # Dim [N x M]      

        print(t)


        # For each new_level element, find closest discretized level

        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)

        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]


        idx_lvl = np.empty(new_lvl.shape)

        for i in range(diff.shape[0]):

            for j in range(diff.shape[1]):

                idx_lvl[i, j] = diff[i, j, :].argmin()


        res[t,:,:] = idx_lvl


    return res


# Call function and print running time

start_time = dt.datetime.now()

result = my_func(nbTime, disc_lvl, actions, start_lvl, inflow)

print('Execution time :',(dt.datetime.now() - start_time))


查看完整回答
反對(duì) 回復(fù) 2022-10-18
  • 2 回答
  • 0 關(guān)注
  • 143 瀏覽
慕課專欄
更多

添加回答

舉報(bào)

0/150
提交
取消
微信客服

購(gòu)課補(bǔ)貼
聯(lián)系客服咨詢優(yōu)惠詳情

幫助反饋 APP下載

慕課網(wǎng)APP
您的移動(dòng)學(xué)習(xí)伙伴

公眾號(hào)

掃描二維碼
關(guān)注慕課網(wǎng)微信公眾號(hào)