我需要基於具有類成員關係信息的另一個數組(labels
)對1D numpy
數組(data
)中的元素進行總結。我在下面的代碼中使用numba
來加速它。但是,如果我斑點沒有明確的線ret[int(find(labels, g))] += y
投與int()
,我reveice的錯誤消息:使用numba對numpy數組進行索引時的TypeError
TypeError: unsupported array index type ?int64
有沒有更好的解決方法是顯式轉換?
import numpy as np
from numba import jit
labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)
@jit(nopython=True)
def find(seq, value):
for ct, x in enumerate(seq):
if x == value:
return ct
@jit(nopython=True)
def subsumNumba(data, groups, labels):
ret = np.zeros(len(labels))
for y, g in zip(data, groups):
# not working without casting with int()
ret[int(find(labels, g))] += y
return ret
此代碼與我的機器上的Numba 0.28.1一起使用時沒有錯誤。你使用的是哪個版本的Numba。另外作爲一個附註,你可能想要避免使用'zip'和'enumerate'並明確使用索引計數器出於性能原因。你必須測試一下,看看它是否對你的用例產生了影響,但在過去,根據我的經驗,它確實如此。 – JoshAdel
@JoshAdel我有版本0.26.0(將嘗試更新現在)。你的意思是代碼在你的機器上沒有* int()強制轉換? – NoBackingDown
@JoshAdel它沒有枚舉測試函數'find',性能增益最小。進一步優化代碼時,我會牢記它。 – NoBackingDown