事實證明,在這種情況下,純Python循環比NumPy索引(或調用np.where)要快得多。
考慮以下選擇:
import numpy as np
import collections
import itertools as IT
shape = (2600,5200)
# shape = (26,52)
emiss_data = np.random.random(shape)
obj_data = np.random.random_integers(1, 800, size=shape)
UNIQ_IDS = np.unique(obj_data)
def using_where():
max = np.max
where = np.where
MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS]
return MAX_EMISS
def using_index():
max = np.max
MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS]
return MAX_EMISS
def using_max():
MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS]
return MAX_EMISS
def using_loop():
result = collections.defaultdict(list)
for val, idx in IT.izip(emiss_data.ravel(), obj_data.ravel()):
result[idx].append(val)
return [max(result[idx]) for idx in UNIQ_IDS]
def using_sort():
uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
vals = uind.argsort()
count = np.bincount(uind)
start = 0
end = 0
out = np.empty(count.shape[0])
for ind, x in np.ndenumerate(count):
end += x
out[ind] = np.max(np.take(emiss_data, vals[start:end]))
start += x
return out
def using_split():
uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
vals = uind.argsort()
count = np.bincount(uind)
return [np.take(emiss_data, item).max()
for item in np.split(vals, count.cumsum())[:-1]]
for func in (using_index, using_max, using_loop, using_sort, using_split):
assert using_where() == func()
下面是基準,與shape = (2600,5200)
:
In [57]: %timeit using_loop()
1 loops, best of 3: 9.15 s per loop
In [90]: %timeit using_sort()
1 loops, best of 3: 9.33 s per loop
In [91]: %timeit using_split()
1 loops, best of 3: 9.33 s per loop
In [61]: %timeit using_index()
1 loops, best of 3: 63.2 s per loop
In [62]: %timeit using_max()
1 loops, best of 3: 64.4 s per loop
In [58]: %timeit using_where()
1 loops, best of 3: 112 s per loop
因此using_loop
(純Python)原來是超過11倍比using_where
更快。
我不完全確定爲什麼純Python在這裏比NumPy快。我的猜測是,純Python版本通過兩個數組拉一次(是的,雙關)。它利用了這樣一個事實,即儘管所有的花式索引,我們真的只想訪問每個值。因此,它必須確定emiss_data
中的每個值究竟屬於哪個組,但是這只是模糊的猜測。我不知道在我進行基準測試之前它會更快。
你在這個腳本中計算'UNIQ_IDS'還是預先確定的? – Daniel
UNIQ_IDS是預先確定的... len = 800的整數列表。這只是一個代碼片段,對於混淆抱歉。 –