2016-11-16 73 views
1

我有一個大型的numpy 1-d,包含大約700,000個類。另外,我還有另一個類似大小的數組,它包含這些類的新值。矢量化numpy 1-d重新分類

例陣列

original_classes = np.array([0,1,2,3,4,5,6,7,8,9,10,10]) 
new_classes = np.array([1,0,1,2,2,10,1,6,6,9,5,12]) 

希望的輸出

>>> reclassify_function(original_classes, new_classes) 
array([ 1, 1, 1, 1, 1, 12, 1, 1, 9, 12, 12]) 

的困難是,有多個階級關係。

原始類別1應該得到一個新的值0,這意味着0和1是相等的類別,並且這些值的所有出現都應該被分配到相同的新類別號碼。原始類別2應歸類爲1,這意味着類別2等於類別0和1.因此原始類別0-2應分配到相同的新類別編號等等。

由於我在工作與巨大的數組我想重新分類函數被矢量化。

+1

70,000個物品的數組?它並不是很大。嘗試使用循環的實現。如果你有這個權利,並且如果你對錶現不滿意,請在此發佈。 – Balzola

+0

對不起,應該是70萬。已經在循環方法上工作了! –

+1

這將佔用大約2.7MB的內存。仍然不是那麼大。 – Balzola

回答

1

您可以使用scipy.sparse.csgraph.connected_components重新標記您的類。爲了您的數據。例如:

from scipy.sparse import csr_matrix 
from scipy.sparse.csgraph import connected_components 

A = np.array([0,1,2,3,4,5, 6,7,8,9,10,10]) 
B = np.array([1,0,1,2,2,10,1,6,6,9,5 ,12]) 

N = max(A.max(), B.max()) + 1 
weights = np.ones(len(A), int) 
graph = csr_matrix((weights, (A, B)), shape=(N, N)) 
n_remaining, mapping = connected_components(graph, directed=False) 
print mapping[A] 

給出:

[0 0 0 0 0 1 0 0 0 2 1 1] 

這些都是重新標記的類。我相信你可以找出如何用輸入數據來表達這些信息。注意爲了獲得最佳性能,「原始」和「新」類應該是沒有間隙的單個範圍的連續整數。

+0

不錯!即使使用非連續值,該方法也會正確重新標記類。你確定他們需要連續? –

+1

@Wilmar - 這是最好的表現。在我的代碼中,「映射」將是一個長度爲「N」的數組,因此如果輸入具有數十億未使用的分類,則它將是比必要的更大的數組,因此效率低下。一些小差距不會對性能造成太大影響。但是,那麼'n_remaining'也包含了未使用的類,'映射[A]'最不可能是連續的。 – user7138814

0

這不是一個矢量化的解決方案,並且在我的筆記本上花了大約一個小時。這將創建一個列表集呼叫class_sets;每個集合都是等價類的集合。

original_classes = np.random.randint(0,20000,700000) 
new_classes = np.random.randint(0,20000,700000) 
pairs = zip(original_classes, new_classes) 
class_sets = [set(next(pairs))] 

for i,p in enumerate(pairs): 
    ps = set(p) 
    intsect = [ps.intersection(cs) for cs in class_sets] 
    if any([ps.intersection(cs) for cs in class_sets]): 
     index = np.argmax(intsect) 
     class_sets[index] = class_sets[index].union(ps) 
    else: 
     class_sets.append(ps) 
+0

謝謝,明天我會測試我的方法並將其與您的結果進行比較! –

+0

你的方法給出了一個錯誤,所以我沒有測試它 –