由於intersect1d
每次都對數組進行排序,因此效率很低。
在這裏,您必須將交叉點和每個樣本一起掃過以構建新的交叉點,這可以在線性時間內完成,從而維護順序。
這樣的任務通常必須用低級別的例程手動調整。
下面的方式來做到這一點與numba
:
from numba import njit
import numpy as np
@njit
def drop_missing(intersect,sample):
i=j=k=0
new_intersect=np.empty_like(intersect)
while i< intersect.size and j < sample.size:
if intersect[i]==sample[j]: # the 99% case
new_intersect[k]=intersect[i]
k+=1
i+=1
j+=1
elif intersect[i]<sample[j]:
i+=1
else :
j+=1
return new_intersect[:k]
現在樣本:
n=10**7
ref=np.random.randint(0,n,n)
ref.sort()
def perturbation(sample,k):
rands=np.random.randint(0,n,k-1)
rands.sort()
l=np.split(sample,rands)
return np.concatenate([a[:-1] for a in l])
samples=[perturbation(ref,100) for _ in range(10)] #similar samples
而對於10個樣品
def find_intersect(samples):
intersect=samples[0]
for sample in samples[1:]:
intersect=drop_missing(intersect,sample)
return intersect
In [18]: %time u=find_intersect(samples)
Wall time: 307 ms
In [19]: len(u)
Out[19]: 9999009
這種方式來看,它似乎工作可以在5分鐘內完成,超出加載時間。
數字的範圍是什麼?它們只是整數嗎?他們只是積極的嗎? – Divakar
他們是積極的64位整數。 – dshin
他們的範圍是什麼?事先知道嗎?另外,它們在每個陣列中都是唯一的嗎? – Divakar