3 回答

TA貢獻(xiàn)1804條經(jīng)驗 獲得超8個贊
numba
有了numba它可以優(yōu)化這兩個場景。從語法上講,您只需要構(gòu)造一個帶有簡單for循環(huán)的函數(shù):
from numba import njit
@njit
def get_first_index_nb(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
idx = get_first_index_nb(A, 0.9)
Numba通過JIT(“及時”)編譯代碼并利用CPU級別的優(yōu)化來提高性能。一個常規(guī)的 for無環(huán)路@njit裝飾通常會慢比你已經(jīng)嘗試了在條件滿足后期的情況下的方法。
對于Pandas數(shù)值系列df['data'],您可以簡單地將NumPy表示提供給JIT編譯的函數(shù):
idx = get_first_index_nb(df['data'].values, 0.9)
概括
由于numba允許將函數(shù)用作參數(shù),并且假設(shè)傳遞的函數(shù)也可以JIT編譯,則可以找到一種方法來計算第n個索引,其中滿足任意條件的條件func。
@njit
def get_nth_index_count(A, func, count):
c = 0
for i in range(len(A)):
if func(A[i]):
c += 1
if c == count:
return i
return -1
@njit
def func(val):
return val > 0.9
# get index of 3rd value where func evaluates to True
idx = get_nth_index_count(arr, func, 3)
對于第三個最后的值,可以喂相反,arr[::-1]和否定的結(jié)果len(arr) - 1,則- 1需要考慮0索引。
績效基準(zhǔn)
# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999
@njit
def get_first_index_nb(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
def get_first_index_np(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
%timeit get_first_index_nb(arr, m) # 375 ns
%timeit get_first_index_np(arr, m) # 2.71 μs
%timeit next(iter(np.where(arr > m)[0]), -1) # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1) # 2.5 μs
%timeit get_first_index_nb(arr, n) # 204 μs
%timeit get_first_index_np(arr, n) # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1) # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1) # 39.2 ms

TA貢獻(xiàn)1911條經(jīng)驗 獲得超7個贊
我也想做類似的事情,發(fā)現(xiàn)這個問題中提出的解決方案并沒有真正幫助我。特別是,numba對我來說,解決方案比問題本身中介紹的更常規(guī)的方法慢得多。我有一個times_all列表,通常為數(shù)萬個元素的數(shù)量級,并且想要找到第一個元素的索引times_all大于a 的索引time_event。而且我有數(shù)千個time_event。我的解決方案是將其times_all分成例如100個元素的塊,首先確定time_event屬于哪個時間段,保留該時間段的第一個元素的索引,然后找到該時間段中的哪個索引,然后將兩個索引相加。這是最少的代碼。對我來說,它的運(yùn)行速度比本頁中的其他解決方案快幾個數(shù)量級。
def event_time_2_index(time_event, times_all, STEPS=100):
import numpy as np
time_indices_jumps = np.arange(0, len(times_all), STEPS)
time_list_jumps = [times_all[idx] for idx in time_indices_jumps]
time_list_jumps_idx = next((idx for idx, val in enumerate(time_list_jumps)\
if val > time_event), -1)
index_in_jumps = time_indices_jumps[time_list_jumps_idx-1]
times_cropped = times_all[index_in_jumps:]
event_index_rel = next((idx for idx, val in enumerate(times_cropped) \
if val > time_event), -1)
event_index = event_index_rel + index_in_jumps
return event_index
添加回答
舉報