2013-04-16 134 views
1

我想優化下面的代碼,可能通過在Cython中重寫它:它只需要一個低維但相對較長的numpy數組,查看其值爲0的列,並將它們標記爲-1。代碼是:優化索引和檢索Python中numpy數組中的元素?

import numpy as np 

def get_data(): 
    data = np.array([[1,5,1]] * 5000 + [[1,0,5]] * 5000 + [[0,0,0]] * 5000) 
    return data 

def get_cols(K): 
    cols = np.array([2] * K) 
    return cols 

def test_nonzero(data): 
    K = len(data) 
    result = np.array([1] * K) 
    # Index into columns of data 
    cols = get_cols(K) 
    # Mark zero points with -1 
    idx = np.nonzero(data[np.arange(K), cols] == 0)[0] 
    result[idx] = -1 

import time 
t_start = time.time() 
data = get_data() 
for n in range(5000): 
    test_nonzero(data) 
t_end = time.time() 
print (t_end - t_start) 

data是數據。 cols是查找非零值的數據列數組(爲簡單起見,我將它全部放在同一列中)。我們的目標是計算一個numpy數組,result,其中感興趣列非零的每行的值爲1,並且感興趣的相應列的值爲零的行的值爲-1。

在15,000行3列不太大的數組上運行5000次大約需要20秒。有沒有辦法可以加快速度?看起來大部分工作都是尋找非零元素並用索引檢索它們(調用nonzero並隨後使用它的索引)。這可以優化嗎?或者這是最好的可以完成的嗎? Cython實現如何在這方面獲得更快的速度?

+0

非零是一個很好的嘗試(不知道它是否幫助很大或所有雖然)。如果你絕望並知道cols是有效的,你可以嘗試製作一個線性索引。如果K在循環中不變,則不能每次都重做np.arange ... – seberg

回答

2
cols = np.array([2] * K) 

這會很慢。這是創建一個非常大的Python列表,然後將其轉換爲一個numpy數組。相反,這樣做:

cols = np.ones(K, int)*2 

那將是這樣快

result = np.array([1] * K) 

在這裏,你應該做的:

result = np.ones(K, int) 

這將直接產生numpy的陣列。

idx = np.nonzero(data[np.arange(K), cols] == 0)[0] 
result[idx] = -1 

cols是一個數組,但您可以傳遞一個2.此外,使用非零增加一個額外的步驟。

idx = data[np.arange(K), 2] == 0 
result[idx] = -1 

應該有同樣的效果。

+0

感謝但cols不是計算的一部分,它只是一個示例。正如我所提到的,cols並不總是像2這樣的一個價值 - 我只是爲了簡單而做出來的。它通常是一個向量將不同的列。所以我不認爲這些建議加快了速度 – user248237dfsf

+2

@ user248237dfsf,好吧,我錯過了關於cols不真實的部分。然而,'numpy.array([2] * K)'確實非常慢,並且它代碼的原因很慢。沒有看到你真實的代碼,我真的不知道爲什麼它會很慢。 –