2013-08-26 33 views
10

我在幾個for循環內使用numpy的函數很多次,但它變得太慢了。有什麼方法可以更快地執行此功能?我讀過你應該嘗試做in-line for循環,並且在for循環之前爲函數創建局部變量,但似乎沒有什麼能夠提高速度(< 1%)。 len(UNIQ_IDS)〜800. emiss_dataobj_data是形狀=(2600,5200)的numpy ndarrays。我已經使用import profile來處理瓶頸的位置,並且for環路中的where是一個很大的問題。快python numpy哪裏有功能?

import numpy as np 
max = np.max 
where = np.where 
MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS)] 
+0

你在這個腳本中計算'UNIQ_IDS'還是預先確定的? – Daniel

+0

UNIQ_IDS是預先確定的... len = 800的整數列表。這只是一個代碼片段,對於混淆抱歉。 –

回答

2

你就不能這樣做

emiss_data[obj_data == i] 

?我不知道你爲什麼使用where

+0

那麼確實有效,並且提高了約45%。謝謝。我想我正在使用,因爲我習慣了IDL,並且正在嘗試轉換爲python。但是,它仍然非常緩慢。完成800次需要75秒,而IDL則需要2秒完成。如果您確實需要未來業務的地點/指數呢?如果您在for循環中多次使用它,而不是for循環中的where語句,我不會想象這會非常高效。 –

+0

看來應該有一種方法可以通過'obj_data'值與numpy內置組件組合'emiss_data'值。儘管如此,我還沒有找到。 – user2357112

+0

你可以使用'np.lexsort';然而,'lexsort'本身是導致次優解決方案的瓶頸。 – Daniel

0

根據Are tuples more efficient than lists in Python?,分配元組比分配列表要快得多,所以也許只需構建一個元組而不是一個列表,這樣可以提高效率。

+1

我懷疑它。在某些情況下元組有優勢,但這些元組在這裏都不適用。那個問題(或者說那裏的接受答案)並不表明元組的構造更快,它表明* literal *元組可以被構造一次並被多次使用。即使元組創建*的速度比創建列表的速度快,也不存在瓶頸。 – delnan

+0

感謝您的信息! – Jblasco

7

事實證明,在這種情況下,純Python循環比NumPy索引(或調用np.where)要快得多。

考慮以下選擇:

import numpy as np 
import collections 
import itertools as IT 

shape = (2600,5200) 
# shape = (26,52) 
emiss_data = np.random.random(shape) 
obj_data = np.random.random_integers(1, 800, size=shape) 
UNIQ_IDS = np.unique(obj_data) 

def using_where(): 
    max = np.max 
    where = np.where 
    MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS] 
    return MAX_EMISS 

def using_index(): 
    max = np.max 
    MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS] 
    return MAX_EMISS 

def using_max(): 
    MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS] 
    return MAX_EMISS 

def using_loop(): 
    result = collections.defaultdict(list) 
    for val, idx in IT.izip(emiss_data.ravel(), obj_data.ravel()): 
     result[idx].append(val) 
    return [max(result[idx]) for idx in UNIQ_IDS] 

def using_sort(): 
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 
    vals = uind.argsort() 
    count = np.bincount(uind) 
    start = 0 
    end = 0 
    out = np.empty(count.shape[0]) 
    for ind, x in np.ndenumerate(count): 
     end += x 
     out[ind] = np.max(np.take(emiss_data, vals[start:end])) 
     start += x 
    return out 

def using_split(): 
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 
    vals = uind.argsort() 
    count = np.bincount(uind) 
    return [np.take(emiss_data, item).max() 
      for item in np.split(vals, count.cumsum())[:-1]] 

for func in (using_index, using_max, using_loop, using_sort, using_split): 
    assert using_where() == func() 

下面是基準,與shape = (2600,5200)

In [57]: %timeit using_loop() 
1 loops, best of 3: 9.15 s per loop 

In [90]: %timeit using_sort() 
1 loops, best of 3: 9.33 s per loop 

In [91]: %timeit using_split() 
1 loops, best of 3: 9.33 s per loop 

In [61]: %timeit using_index() 
1 loops, best of 3: 63.2 s per loop 

In [62]: %timeit using_max() 
1 loops, best of 3: 64.4 s per loop 

In [58]: %timeit using_where() 
1 loops, best of 3: 112 s per loop 

因此using_loop(純Python)原來是超過11倍比using_where更快。

我不完全確定爲什麼純Python在這裏比NumPy快。我的猜測是,純Python版本通過兩個數組拉一次(是的,雙關)。它利用了這樣一個事實,即儘管所有的花式索引,我們真的只想訪問每個值。因此,它必須確定emiss_data中的每個值究竟屬於哪個組,但是這只是模糊的猜測。我不知道在我進行基準測試之前它會更快。

+0

using_loop中的'list'是什麼? –

+0

['collections.defaultdict(list)'](http://docs.python.org/2/library/collections.html#collections.defaultdict)創建一個類似dict的對象,它返回一個列表作爲默認值。 – unutbu

7

可以使用np.uniquereturn_index

def using_sort(): 
    #UNIQ_IDS,uind=np.unique(obj_data, return_inverse=True) 
    uind= np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 
    vals=uind.argsort() 
    count=np.bincount(uind) 

    start=0 
    end=0 

    out=np.empty(count.shape[0]) 
    for ind,x in np.ndenumerate(count): 
     end+=x 
     out[ind]=np.max(np.take(emiss_data,vals[start:end])) 
     start+=x 
    return out 

使用@ unutbu的答案爲基準shape = (2600,5200)

np.allclose(using_loop(),using_sort()) 
True 

%timeit using_loop() 
1 loops, best of 3: 12.3 s per loop 

#With np.unique inside the definition 
%timeit using_sort() 
1 loops, best of 3: 9.06 s per loop 

#With np.unique outside the definition 
%timeit using_sort() 
1 loops, best of 3: 2.75 s per loop 

#Using @Jamie's suggestion for uind 
%timeit using_sort() 
1 loops, best of 3: 6.74 s per loop 
+2

我想如果'UNIQ_IDS'確實有預先計算好的'obj_data'的唯一條目,那麼你可以在大約一半的時間內調用'np.digitize(obj_data,UNIQ_IDS) - 1'獲得與你的'uind'相同的結果。 – Jaime

+0

您的方法非常聰明,但不幸的是我無法獲得相同的速度增益。 (我在我的機器上運行時在我的機器上運行時添加了'using_sort'的基準測試,對我來說''using_loop''仍然稍微快一些。)或許這是由於Python的版本或OS?我在Ubuntu 11.10上使用Python 2.7。你在用什麼? – unutbu

+0

@unutbu我使用的是OSX和完全更新的anaconda安裝(它的確有加速,我知道它過去搞砸了時間)。我也在OSX盒子上用Python 2.7.4和numpy 1.7.1嘗試過,並且獲得了相同的結果;然而,我嘗試了一款搭載numpy 1.6.1的AMD芯片的Ubuntu機箱,並發現它們的時序相同。我討厭繼續張貼[這](http://stackoverflow.com/questions/18365073/why-is-numpys-einsum-faster-than-numpys-built-in-functions)的問題,但似乎有東西會去我不明白的時間安排。 – Daniel

5

我相信做到這一點的最快方法是使用在該groupby()操作pandas包。到@比較俄菲翁的using_sort()功能,大熊貓是關於10快的一個因素:

import numpy as np 
import pandas as pd 

shape = (2600,5200) 
emiss_data = np.random.random(shape) 
obj_data = np.random.random_integers(1, 800, size=shape) 
UNIQ_IDS = np.unique(obj_data) 

def using_sort(): 
    #UNIQ_IDS,uind=np.unique(obj_data, return_inverse=True) 
    uind= np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 
    vals=uind.argsort() 
    count=np.bincount(uind) 

    start=0 
    end=0 

    out=np.empty(count.shape[0]) 
    for ind,x in np.ndenumerate(count): 
     end+=x 
     out[ind]=np.max(np.take(emiss_data,vals[start:end])) 
     start+=x 
    return out 

def using_pandas(): 
    return pd.Series(emiss_data.ravel()).groupby(obj_data.ravel()).max() 

print('same results:', np.allclose(using_pandas(), using_sort())) 
# same results: True 

%timeit using_sort() 
# 1 loops, best of 3: 3.39 s per loop 

%timeit using_pandas() 
# 1 loops, best of 3: 397 ms per loop 
0

如果obj_data由相對較小的整數,則可以使用numpy.maximum.at(自v1.8.0):

def using_maximumat(): 
    n = np.max(UNIQ_IDS) + 1 
    temp = np.full(n, -np.inf) 
    np.maximum.at(temp, obj_data, emiss_data) 
    return temp[UNIQ_IDS]