2017-10-04 72 views
2

我有一個排序numpy數組的列表。計算這些數組的排序交集的最有效方法是什麼?在我的應用程序中,我期望數組的數量少於10^4,我期望單個數組的長度小於10^7,並且我期望交集的長度接近於p * N,其中N是最大陣列的長度,其中0.99 < p < = 1.0。數組從磁盤加載,如果它們不能一次全部裝入內存,可以批量加載。排序numpy數組的交集

一個快速和骯髒的方法是反覆調用numpy.intersect1d()。這似乎效率低下,但因爲intersect1d()沒有利用數組排序的事實。

+0

數字的範圍是什麼?它們只是整數嗎?他們只是積極的嗎? – Divakar

+0

他們是積極的64位整數。 – dshin

+0

他們的範圍是什麼?事先知道嗎?另外,它們在每個陣列中都是唯一的嗎? – Divakar

回答

1

由於intersect1d每次都對數組進行排序,因此效率很低。

在這裏,您必須將交叉點和每個樣本一起掃過以構建新的交叉點,這可以在線性時間內完成,從而維護順序。

這樣的任務通常必須用低級別的例程手動調整。

下面的方式來做到這一點與numba

from numba import njit 
import numpy as np 

@njit 
def drop_missing(intersect,sample): 
    i=j=k=0 
    new_intersect=np.empty_like(intersect) 
    while i< intersect.size and j < sample.size: 
      if intersect[i]==sample[j]: # the 99% case 
       new_intersect[k]=intersect[i] 
       k+=1 
       i+=1 
       j+=1 
      elif intersect[i]<sample[j]: 
       i+=1 
      else : 
       j+=1 
    return new_intersect[:k] 

現在樣本:

n=10**7 
ref=np.random.randint(0,n,n) 
ref.sort() 

def perturbation(sample,k): 
    rands=np.random.randint(0,n,k-1) 
    rands.sort() 
    l=np.split(sample,rands) 
    return np.concatenate([a[:-1] for a in l]) 

samples=[perturbation(ref,100) for _ in range(10)] #similar samples 

而對於10個樣品

def find_intersect(samples): 
    intersect=samples[0] 
    for sample in samples[1:]: 
     intersect=drop_missing(intersect,sample) 
    return intersect     

In [18]: %time u=find_intersect(samples) 
Wall time: 307 ms 

In [19]: len(u) 
Out[19]: 9999009  

這種方式來看,它似乎工作可以在5分鐘內完成,超出加載時間。