2017-03-12 70 views
0

我的數據樣本每個都是一個numpy的形狀數組,例如, (100,100,9),我將它們中的10個連接成形狀(10,100,100,9)的單個陣列foo。在10個數據樣本中,我想找到重複值的索引。例如,如果foo[0, 42, 42, 3] = 0.72foo[0, 42, 42, 7] = 0.72,我想要一個反映這一點的輸出。什麼是有效的方式?我想在形狀(100,100,9)的布爾輸出數組,但有沒有比循環比較每個數據樣本(數據樣本(10)的數量的二次運行時)更好的方法?在numpy nd數組中找到重複值

+0

你只是想標記任何具有重複值的值,或者是否想要一個以數據值爲關鍵字的字典,並將索引作爲字典值進行復制? – James

+0

@詹姆斯問題是左通用沒有指定確切的數據返回,以便不約束可能的解決方案,但我想一個布爾數組,通過索引簡單地標記重複項(如上建議)。 – BoltzmannBrain

回答

0

在下面的代碼段中,dups是所期望的結果:一個布爾陣列示出了索引重複。還有一個delta閾值,因此值< =該閾值的任何差異都是重複的。

delta = 0. 
dups = np.zeros(foo.shape[1:], dtype=bool) 
for i in xrange(foo.shape[0]): 
    for j in xrange(foo.shape[0]): 
     if i==j: continue 
     dups += abs(foo[i] - foo[j]) <= delta 
-1

以下是對每個樣品使用argsort的解決方案。不漂亮,但速度不快,但能完成這項工作。

import numpy as np 
from timeit import timeit 

def dupl(a, axis=0, make_dict=True): 
    a = np.moveaxis(a, axis, -1) 
    i = np.argsort(a, axis=-1, kind='mergesort') 
    ai = a[tuple(np.ogrid[tuple(map(slice, a.shape))][:-1]) + (i,)] 
    same = np.zeros(a.shape[:-1] + (a.shape[-1]+1,), bool) 
    same[..., 1:-1] = np.diff(ai, axis=-1) == 0 
    uniqs = np.where((same[..., 1:] & ~same[..., :-1]).ravel())[0] 
    same = (same[...,1:]|same[...,:-1]).ravel() 
    reps = np.split(i.ravel()[same], np.cumsum(same)[uniqs[1:]-1]) 
    grps = np.searchsorted(uniqs, np.arange(0, same.size, a.shape[-1])) 
    keys = ai.ravel()[uniqs] 
    if make_dict: 
     result = np.empty(a.shape[:-1], object) 
     result.ravel()[:] = [dict(zip(*p)) for p in np.split(
       np.array([keys, reps], object), grps[1:], axis=-1)] 
     return result 
    else: 
     return keys, reps, grps 

a = np.random.randint(0,10,(10,100,100,9)) 
axis = 0 
result = dupl(a, axis) 

print('shape, axis, time (sec) for 10 trials:', 
     a.shape, axis, timeit(lambda: dupl(a, axis=axis), number=10)) 
print('same without creating dict:', 
     a.shape, axis, timeit(lambda: dupl(a, axis=axis, make_dict=False), 
          number=10)) 

#check 
print("checking result") 
am = np.moveaxis(a, axis, -1) 
for af, df in zip(am.reshape(-1, am.shape[-1]), result.ravel()): 
    assert len(set(af)) + sum(map(len, df.values())) == len(df) + am.shape[-1] 
    for k, v in df.items(): 
     assert np.all(np.where(af == k)[0] == v) 
print("no errors") 

打印:

shape, axis, time (sec) for 10 trials: (10, 100, 100, 9) 0 5.328339613042772 
same without creating dict: (10, 100, 100, 9) 0 2.568383438978344 
checking result 
no errors 
+0

在這段代碼中有各種各樣的味道,而且有很多雜項排序,就像這樣試圖效率低下。 – BoltzmannBrain

+0

@BoltzmannBrain有點苛刻你不覺得嗎?除了與你的不同之外,這具有合理的複雜度,而不是O(n k^2),而是O(n(log n/k + k log k))。我承認這並不容易,但不要因爲它超出了你而摔跤。 –