首先創建了一些數據:
import numpy as np
np.random.seed(1)
list_of_np_1D = np.random.randint(0, 5, size=(500, 6))
np_2D = np.random.randint(0, 5, size=(20, 6))
運行代碼:
%%time
fill_set = set()
for i in list_of_np_1D:
vecs = i + np_2D
for v in vecs:
tup = tuple(v)
fill_set.add(tup)
res1 = np.array(list(fill_set))
輸出:
CPU times: user 161 ms, sys: 2 ms, total: 163 ms
Wall time: 167 ms
這裏是一個加速版本,它使用廣播,.view()
方法來轉換dtype轉換爲字符串後,調用set()
將字符串轉換回ar光線:
%%time
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
stype = "S%d" % (r.itemsize * np_2D.shape[1])
fill_set2 = set(r.ravel().view(stype).tolist())
res2 = np.zeros(len(fill_set2), dtype=stype)
res2[:] = list(fill_set2)
res2 = res2.view(r.dtype).reshape(-1, np_2D.shape[1])
輸出:
CPU times: user 13 ms, sys: 1 ms, total: 14 ms
Wall time: 14.6 ms
檢查結果:
np.all(res1[np.lexsort(res1.T), :] == res2[np.lexsort(res2.T), :])
您還可以使用lexsort()
刪除重複的數據:
%%time
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
r = r.reshape(-1, r.shape[-1])
r = r[np.lexsort(r.T)]
idx = np.where(np.all(np.diff(r, axis=0) == 0, axis=1))[0] + 1
res3 = np.delete(r, idx, axis=0)
輸出:
CPU times: user 13 ms, sys: 3 ms, total: 16 ms
Wall time: 16.1 ms
檢查結果:
np.all(res1[np.lexsort(res1.T), :] == res3)
不僅是NumPy的陣列不是哈希的,他們不是就算真的* equatable *。如果'a'或'b'是一個數組,'a'不會產生一個布爾值來表示'a'是否等於'b'','set'不知道如何處理一個元素數組比較結果或如何調用'np.array_equal'。 – user2357112
你真的需要將你的數組轉換成Python集嗎? Numpy本地支持數組上的各種設置操作(參見['numpy.lib.arraysetops'](http://docs.scipy.org/doc/numpy/reference/routines.set.html))。 –
@ali_m我沒有意識到這一點,我現在就去看看。最終,我有兩個大的1D整數數組的集合,我需要能夠添加更多的數組到這些集合,並做一些等同於'.difference_update'操作的設置。 – CBowman