2016-02-16 54 views
2

我知道,爲了添加一個元素到一個集合,它必須是可散列的,而numpy數組看起來並不是。這是造成我一些問題,因爲我有以下的代碼位:可能添加numpy數組到Python集?

fill_set = set() 
for i in list_of_np_1D: 
    vecs = i + np_2D 
    for j in range(N): 
     tup = tuple(vecs[j,:]) 
     fill_set.add(tup) 

# list_of_np_1D is a list of 1D numpy arrays 
# np_2D is a 2D numpy array 
# np_2D could also be converted to a list of 1D arrays if it helped. 

我需要得到這個運行速度更快,近50%的運行時間都花在轉換2D numpy的陣列的切片的元組等等他們可以被添加到集合中。

所以我一直在試圖找出以下

  • 是否有像numpy的陣列功能(具有向量加法)哈希的,使他們能夠加入到組沒有辦法讓numpy的陣列,或一些?
  • 如果不是,有沒有一種方法可以加快進行元組轉換的過程?

感謝您的幫助!

+1

不僅是NumPy的陣列不是哈希的,他們不是就算真的* equatable *。如果'a'或'b'是一個數組,'a'不會產生一個布爾值來表示'a'是否等於'b'','set'不知道如何處理一個元素數組比較結果或如何調用'np.array_equal'。 – user2357112

+2

你真的需要將你的數組轉換成Python集嗎? Numpy本地支持數組上的各種設置操作(參見['numpy.lib.arraysetops'](http://docs.scipy.org/doc/numpy/reference/routines.set.html))。 –

+1

@ali_m我沒有意識到這一點,我現在就去看看。最終,我有兩個大的1D整數數組的集合,我需要能夠添加更多的數組到這些集合,並做一些等同於'.difference_update'操作的設置。 – CBowman

回答

2

首先創建了一些數據:

import numpy as np 
np.random.seed(1) 
list_of_np_1D = np.random.randint(0, 5, size=(500, 6)) 
np_2D = np.random.randint(0, 5, size=(20, 6)) 

運行代碼:

%%time 
fill_set = set() 
for i in list_of_np_1D: 
    vecs = i + np_2D 
    for v in vecs: 
     tup = tuple(v) 
     fill_set.add(tup) 
res1 = np.array(list(fill_set)) 

輸出:

CPU times: user 161 ms, sys: 2 ms, total: 163 ms 
Wall time: 167 ms 

這裏是一個加速版本,它使用廣播,.view()方法來轉換dtype轉換爲字符串後,調用set()將字符串轉換回ar光線:

%%time 
r = list_of_np_1D[:, None, :] + np_2D[None, :, :] 
stype = "S%d" % (r.itemsize * np_2D.shape[1]) 
fill_set2 = set(r.ravel().view(stype).tolist()) 
res2 = np.zeros(len(fill_set2), dtype=stype) 
res2[:] = list(fill_set2) 
res2 = res2.view(r.dtype).reshape(-1, np_2D.shape[1]) 

輸出:

CPU times: user 13 ms, sys: 1 ms, total: 14 ms 
Wall time: 14.6 ms 

檢查結果:

np.all(res1[np.lexsort(res1.T), :] == res2[np.lexsort(res2.T), :]) 

您還可以使用lexsort()刪除重複的數據:

%%time 
r = list_of_np_1D[:, None, :] + np_2D[None, :, :] 
r = r.reshape(-1, r.shape[-1]) 

r = r[np.lexsort(r.T)] 
idx = np.where(np.all(np.diff(r, axis=0) == 0, axis=1))[0] + 1 
res3 = np.delete(r, idx, axis=0) 

輸出:

CPU times: user 13 ms, sys: 3 ms, total: 16 ms 
Wall time: 16.1 ms 

檢查結果:

np.all(res1[np.lexsort(res1.T), :] == res3)