1 Answer

Based on the test results, an improvement of several orders of magnitude is unlikely without dropping to lower-level tools such as numba or even Cython. This is clear from the time needed just to perform the aggregation itself.

That said, two key optimizations are still possible:

1. Reduce the number of explicit passes over the data, chiefly the `df[df['col'] == val]` filtering. In my implementation, your for loop is replaced by (1) aggregating everything at once with `.groupby().agg()`, and (2) checking the thresholds via a lookup table (a `dict`). I am not sure a more efficient approach exists, but any approach still involves at least one full pass over the data, so this can only save a few more seconds at best.
2. Access `df["col"].values` instead of `df["col"]` wherever possible. (Note that this does not copy the data, which is easy to verify with the `tracemalloc` module enabled.)
Benchmark code:

15M records were generated from your sample.
```python
import pandas as pd
import numpy as np
from datetime import datetime

# check memory footprint
# import tracemalloc
# tracemalloc.start()

# data
df = pd.read_csv("/mnt/ramdisk/in.csv", index_col="idx")
del df['measurement_tstamp']
df.reset_index(drop=True, inplace=True)
df["travel_time_minutes"] = df["travel_time_minutes"].astype(np.float64)

# repeat
cols = df.columns
df = pd.DataFrame(np.repeat(df.values, 500000, axis=0))
df.columns = cols

# Aggregation starts
t0 = datetime.now()
print(f"Program begins....")

# 1. aggregate everything at once
df_agg = df.groupby("tmc_code").agg(
    mode=("travel_time_minutes", pd.Series.mode),
    q95=("travel_time_minutes", lambda x: np.quantile(x, .95)),
)
t1 = datetime.now()
print(f"  Aggregation: {(t1 - t0).total_seconds():.2f}s")

# 2. construct a lookup table for the thresholds
threshold = {}
for tmc_code, row in df_agg.iterrows():  # slow, but only ~1.2k rows
    threshold[tmc_code] = np.max(row["mode"]) + row["q95"]
t2 = datetime.now()  # doesn't matter
print(f"  Computing Threshold: {(t2 - t1).total_seconds():.2f}s")

# 3. filtering
def f(tmc_code, travel_time_minutes):
    return travel_time_minutes <= threshold[tmc_code]

df = df[list(map(f, df["tmc_code"].values, df["travel_time_minutes"].values))]
t3 = datetime.now()
print(f"  Filter: {(t3 - t2).total_seconds():.2f}s...")

print(f"Program ends in {(datetime.now() - t0).total_seconds():.2f}s")

# memory footprint
# current, peak = tracemalloc.get_traced_memory()
# print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
# tracemalloc.stop()
print()
```
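As an aside, the Python-level `map` in the filtering step can usually be replaced by a vectorized comparison via `Series.map` against the same threshold dict. A minimal sketch with made-up toy data (column names follow the answer; the numbers and cutoffs are illustrative assumptions):

```python
import pandas as pd

# Toy data mirroring the answer's columns (values are illustrative only).
df = pd.DataFrame({
    "tmc_code": ["A", "A", "B", "B"],
    "travel_time_minutes": [1.0, 9.0, 2.0, 8.0],
})
threshold = {"A": 5.0, "B": 6.0}  # per-code cutoffs, as built in step 2

# Map each row's code to its threshold, then compare column-wise —
# this stays inside pandas/NumPy instead of calling a Python function per row.
mask = df["travel_time_minutes"].values <= df["tmc_code"].map(threshold).values
filtered = df[mask]
print(filtered)  # keeps the rows with travel_time 1.0 and 2.0
```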
Results (3 runs):
| No. | old | new | new(aggr) | new(filter) |
|-----|-------|-------|-----------|-------------|
| 1 | 24.55 | 14.04 | 9.87 | 4.16 |
| 2 | 23.84 | 13.58 | 9.66 | 3.92 |
| 3 | 24.81 | 14.37 | 10.02 | 4.34 |
| avg | 24.40 | 14.00 | | |
=> ~74% faster (24.40s / 14.00s ≈ 1.74)
Tested with Python 3.7 and pandas 1.1.2.