2017-08-08 29 views
1

我想用numpy爲一些統計對象製作集合字典,簡化狀態如下。使用numpy爲某些統計對象製作集合字典

有分別標量陣列記爲 a = np.array([n1,n2,n3...]) 和2D陣列作爲 b = np.array([[q1_1,q1_2],[q2_1,q2_2],[q3_1,q3_2]...])

對於每個元素nia,我要挑選出所有包含nib元素qi([qi_1,qi_2])和使用key作爲ni來收集它們的dict

我記錄了笨拙方法用於此目的(假設ab被確定)爲以下代碼爲:

import numpy as np 

a = np.array([i+1 for i in range(100)]) 
b = np.array([[2*i+1,2*(i+1)] for i in range(50)]) 
dict = {} 
for i in a: dict[i] = [j for j in b if i in j] 

毫無疑問,當ab都很大,這將是非常慢。 有沒有其他有效的方法來取代上面的一個? 尋求你的幫助!

回答

0

numpy的陣列允許的elementwise比較:

equal = b[:,:,np.newaxis]==a #np.newaxis to broadcast 
# if one of the two is equal, we will include this element 
index = np.logical_or(equal[:,0], equal[:,1]) 
# indexing by a boolean array to get the result 
dictionary = {i: b[index[:,i]] for i in range(len(a))} 

作爲最後的一句話:你確定要使用字典?通過這樣你就失去了很多numpy的優點

編輯,回答你的評論: 用a和b這個大,等於10^10的大小,這就使得8 * 10^10個字節,大約是72 G.這就是你得到這個錯誤的原因。 你應該問的主要問題是:我真的需要這個大數組嗎?如果是的話,你確定,詞典不會太大?

這個問題可以通過一次不計算everythin來解決,但在n次,n應該是大約72/16(在內存中的比例)在你的情況下。然而具有n點點較大可能會加快這一進程:

stride = int(len(b)/n) 
dictionary = {} 
for i in range(n): 
    #splitting b into several parts 
    equal = b[n*stride:(n+1)*stride,:,np.newaxis]==a 
    index = np.logical_or(equal[:,0], equal[:,1]) 
    dictionary.update({i: b[index[:,i]] for i in range(len(a))}) 
+0

感謝這個建議,你的代碼比我的速度快得多,但當a和b很大時(例如a = 100000和b = 50000,RAM爲16G),仍然會給'記憶錯誤',而我的這種情況下的代碼不再有效。如何提高你的代碼來阻止那個錯誤? – zgfu1985

+0

@ zgfu1985我調整了代碼,這應該解決問題,想想你是否真的需要這麼大的數組 –

0

感謝您的想法。它可以完全解決我的問題。你的核心概念是對a和b進行比較,並得到布爾數組。因此,使用此布爾值索引爲數組b創建字典非常快。按照這個思路,我重寫你的代碼在我自己的方式爲

dict = {} 
for item in a: 
    index_left, index_right = (b[:,0]==item), (b[:,1]==item) 
    index = np.logical_or(index_left, index_right) 
    dict[item] = dict[index] 

這些代碼仍然不是比你快,但能避免,即使在大的a和b的「存儲器的差錯」(如:A = 100000, b = 200000)