import numpy as np
def ismember_rows(a, b):
'''Equivalent of 'ismember' from Matlab
a.shape = (nRows_a, nCol)
b.shape = (nRows_b, nCol)
return the idx where b[idx] == a
'''
return np.nonzero(np.all(b == a[:,np.newaxis], axis=2))[1]
a = np.array([[4, 6],[2, 6],[5, 2]])
b = np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])
idx = ismember_rows(a, b)
print idx
print np.all(b[idx] == a)
打印
array([5, 2, 8])
True
è...我用廣播
------------------------ - [更新] ------------------------------
def ismember(a, b):
return np.flatnonzero(np.in1d(b[:,0], a[:,0]) & np.in1d(b[:,1], a[:,1]))
a = np.array([[4, 6],[2, 6],[5, 2]])
b = np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])
a2 = np.tile(a,(2000,1))
b2 = np.tile(b,(2000,1))
%timeit timeit in1d_index(a2, b2)
# 100 loops, best of 3: 8.74 ms per loop
%timeit ismember(a2, b2)
# 100 loops, best of 3: 8.5 ms per loop
np.all(in1d_index(a2, b2) == ismember(a2, b2))
# True
正如unutbu所說,指數是按遞增順序返回
什麼是你的數組的dtype?是'a'和'b'的長度〜1M嗎? 「索引」中的值的順序對您來說很重要嗎? – unutbu
它們都是dtype('int64')。 b是長度〜1M,並且a是長度〜750k。 a中的每個條目都將在b中,但不是相反的。理想情況下,索引輸出將與b的值相同,其值顯示a中的索引。 – claudiaann1
對不起...我的意思是輸出應該和a一樣長。相反,這是沒有意義的。 – claudiaann1