2016-09-04 40 views
1

我需要基於具有類成員關係信息的另一個數組(labels)對1D numpy數組(data)中的元素進行總結。我在下面的代碼中使用numba來加速它。但是,如果我斑點沒有明確的線ret[int(find(labels, g))] += y投與int(),我reveice的錯誤消息:使用numba對numpy數組進行索引時的TypeError

TypeError: unsupported array index type ?int64

有沒有更好的解決方法是顯式轉換?

import numpy as np 
from numba import jit 

labels = np.array([45, 85, 99, 89, 45, 86, 348, 764]) 
n = int(1e3) 
data = np.random.random(n) 
groups = np.random.choice(a=labels, size=n, replace=True) 

@jit(nopython=True) 
def find(seq, value): 
    for ct, x in enumerate(seq): 
     if x == value: 
      return ct 

@jit(nopython=True) 
def subsumNumba(data, groups, labels): 
    ret = np.zeros(len(labels)) 
    for y, g in zip(data, groups): 
     # not working without casting with int() 
     ret[int(find(labels, g))] += y 
    return ret 
+0

此代碼與我的機器上的Numba 0.28.1一起使用時沒有錯誤。你使用的是哪個版本的Numba。另外作爲一個附註,你可能想要避免使用'zip'和'enumerate'並明確使用索引計數器出於性能原因。你必須測試一下,看看它是否對你的用例產生了影響,但在過去,根據我的經驗,它確實如此。 – JoshAdel

+0

@JoshAdel我有版本0.26.0(將嘗試更新現在)。你的意思是代碼在你的機器上沒有* int()強制轉換? – NoBackingDown

+0

@JoshAdel它沒有枚舉測試函數'find',性能增益最小。進一步優化代碼時,我會牢記它。 – NoBackingDown

回答

1

的問題是,find可以返回一個intNone如果它沒有發現任何東西,所以我覺得?int64錯誤。爲了避免投射,當find退出時,您需要提供int返回值,但不會找到所需的值,然後在調用者中處理它。

+0

就是這樣!我沒有想到我,因爲'find'保證找到我的問題的結構。現在我只是返回一個虛擬整數的理論情況下,沒有擊中,它的作品! – NoBackingDown