我的數據樣本每個都是一個numpy的形狀數組,例如, (100,100,9),我將它們中的10個連接成形狀(10,100,100,9)的單個陣列foo
。在10個數據樣本中,我想找到重複值的索引。例如,如果foo[0, 42, 42, 3] = 0.72
和foo[0, 42, 42, 7] = 0.72
,我想要一個反映這一點的輸出。什麼是有效的方式?我想在形狀(100,100,9)的布爾輸出數組,但有沒有比循環比較每個數據樣本(數據樣本(10)的數量的二次運行時)更好的方法?在numpy nd數組中找到重複值
0
A
回答
0
在下面的代碼段中,dups
是所期望的結果:一個布爾陣列示出了索引重複。還有一個delta
閾值,因此值< =該閾值的任何差異都是重複的。
delta = 0.
dups = np.zeros(foo.shape[1:], dtype=bool)
for i in xrange(foo.shape[0]):
for j in xrange(foo.shape[0]):
if i==j: continue
dups += abs(foo[i] - foo[j]) <= delta
-1
以下是對每個樣品使用argsort
的解決方案。不漂亮,但速度不快,但能完成這項工作。
import numpy as np
from timeit import timeit
def dupl(a, axis=0, make_dict=True):
a = np.moveaxis(a, axis, -1)
i = np.argsort(a, axis=-1, kind='mergesort')
ai = a[tuple(np.ogrid[tuple(map(slice, a.shape))][:-1]) + (i,)]
same = np.zeros(a.shape[:-1] + (a.shape[-1]+1,), bool)
same[..., 1:-1] = np.diff(ai, axis=-1) == 0
uniqs = np.where((same[..., 1:] & ~same[..., :-1]).ravel())[0]
same = (same[...,1:]|same[...,:-1]).ravel()
reps = np.split(i.ravel()[same], np.cumsum(same)[uniqs[1:]-1])
grps = np.searchsorted(uniqs, np.arange(0, same.size, a.shape[-1]))
keys = ai.ravel()[uniqs]
if make_dict:
result = np.empty(a.shape[:-1], object)
result.ravel()[:] = [dict(zip(*p)) for p in np.split(
np.array([keys, reps], object), grps[1:], axis=-1)]
return result
else:
return keys, reps, grps
a = np.random.randint(0,10,(10,100,100,9))
axis = 0
result = dupl(a, axis)
print('shape, axis, time (sec) for 10 trials:',
a.shape, axis, timeit(lambda: dupl(a, axis=axis), number=10))
print('same without creating dict:',
a.shape, axis, timeit(lambda: dupl(a, axis=axis, make_dict=False),
number=10))
#check
print("checking result")
am = np.moveaxis(a, axis, -1)
for af, df in zip(am.reshape(-1, am.shape[-1]), result.ravel()):
assert len(set(af)) + sum(map(len, df.values())) == len(df) + am.shape[-1]
for k, v in df.items():
assert np.all(np.where(af == k)[0] == v)
print("no errors")
打印:
shape, axis, time (sec) for 10 trials: (10, 100, 100, 9) 0 5.328339613042772
same without creating dict: (10, 100, 100, 9) 0 2.568383438978344
checking result
no errors
+0
在這段代碼中有各種各樣的味道,而且有很多雜項排序,就像這樣試圖效率低下。 – BoltzmannBrain
+0
@BoltzmannBrain有點苛刻你不覺得嗎?除了與你的不同之外,這具有合理的複雜度,而不是O(n k^2),而是O(n(log n/k + k log k))。我承認這並不容易,但不要因爲它超出了你而摔跤。 –
相關問題
- 1. numpy nd數組到熊貓列沒有[]
- 2. INSERT Numpy ND數組到MySQL表
- 3. 在numpy數組中查找重複值的索引
- 4. 將nD numpy數組合併到一維數組中
- 5. 在Numpy數組中找到多個值
- 6. 在numpy數組中找到缺失值
- 7. 在numpy數組中找到連續的重複nan
- 8. 在數組中找到重複數組
- 9. 從python計數器字典到nd numpy數組
- 10. 在numpy數組中高效查找值
- 11. 在Numpy數組中查找多個值
- 12. 找到數組中重複的數字
- 13. 如何在PHP數組中找到重複值?
- 14. 算法在數組中找到最重複的值
- 15. 在二維數組中找到重複的值3x3魔方
- 16. 在兩個數組中找到重複的值
- 17. 如何使用php在數組中找到重複值?
- 18. 使用drools在數組中找到重複的值
- 19. 在numpy數組中查找數組?
- 20. 在兩個numpy數組中找到最接近的值
- 21. numpy:替換重新數組中的值
- 22. 如何找到嵌套數組中的最大值nd更新它?
- 23. 如何找到並控制數組中的日誌重複值?
- 24. 爲數組中的圖像找到重複值
- 25. 關聯數組,找到值中的重複項
- 26. 查找重複值多維數組
- 27. 如何找到數組中最重複的值?
- 28. 如何找到重複的和最高值在數組
- 29. 找到重複的值在多維數組
- 30. 使用HASHSET在兩個數組之間找到重複的值
你只是想標記任何具有重複值的值,或者是否想要一個以數據值爲關鍵字的字典,並將索引作爲字典值進行復制? – James
@詹姆斯問題是左通用沒有指定確切的數據返回,以便不約束可能的解決方案,但我想一個布爾數組,通過索引簡單地標記重複項(如上建議)。 – BoltzmannBrain