2014-12-10 42 views
2

是否有一個內置的方法可以幫助我有效地實現以下內容:給定一個數組,我需要一個數組列表,每個列表的索引都指向數組的不同唯一值?如何有效地獲取唯一值的索引列表?

如果f是所需要的功能,

b = f(a) 

u, idxs = unique(a) 

然後

b[i] == where(idxs==i)[0] 

我知道pandas.Series.groupby()可以做到這一點,但它可能不會是有效的當有超過10^5個獨特整數時創建一個字典。

+0

fyi,pandas.Series對象也有一個「獨特」的方法。 – 2014-12-11 00:22:53

回答

2

如果你有numpy的> = 1.9,你可以這樣做:

>>> a = np.random.randint(5, size=10) 
>>> a 
array([0, 2, 4, 4, 2, 4, 4, 3, 2, 1]) 
>>> unq, unq_inv, unq_cnt = np.unique(a, return_inverse=True, return_counts=True) 
>>> np.split(np.argsort(unq_inv), np.cumsum(unq_cnt[:-1])) 
[array([0]), array([9]), array([1, 4, 8]), array([7]), array([2, 3, 5, 6])] 
>>> unq 
array([0, 1, 2, 3, 4]) 

在早期版本中,你可以得到做一個額外的計數:

>>> unq_cnt = np.bincount(unq_inv) 

此外,如果您想確保每個值的索引都已排序,我認爲您需要使用穩定的排序,例如np.argsort(unq_inv, kind='mergesort')


你似乎什麼是後的思考,我認爲這是減少呼叫昂貴的功能,我不認爲你需要做你的要求。說你的函數平方,你可以簡單地做:

>>> unq, unq_inv = np.unique(a, return_inverse=True) 
>>> f_unq = unq**2 
>>> f_a = f_unq[unq_inv] 
>>> a 
array([0, 2, 4, 4, 2, 4, 4, 3, 2, 1]) 
>>> f_a 
array([ 0, 4, 16, 16, 4, 16, 16, 9, 4, 1]) 
0

也許這樣做:

s = argsort(a) 
d = diff(a[s]) 
starts = where(d)[0] 
f = [s[starts[i:i+1]] for i in xrange(len(a))] 

(代碼未被選中)

0
def foo(a): 
    I=np.arange(a.shape[0]) 
    d={} 
    while a.shape[0]: 
    x = a[0] 
    ii = a==x 
    d[x] = I[ii] 
    a = a[~ii] 
    I = I[~ii] 
    return d 

In [767]: a 
Out[767]: array([4, 4, 3, 0, 0, 2, 1, 1, 0, 3]) 

In [768]: foo(a) 
Out[768]: 
{0: array([3, 4, 8]), 
1: array([6, 7]), 
2: array([5]), 
3: array([2, 9]), 
4: array([0, 1])} 

這是不是你想要的那種字典?

對於小型a這工作正常。

等效字典建築功能爲:

def foo1(a): 
    unq = np.unique(a) 
    return {i:np.where(a==i)[0] for i in unq} 

副手我看不出unq_inv有助於構建字典。

foofoo1慢大約30%。我希望通過減少被搜索的數組,每次計算一個值,我可能會獲得一些速度。但它看起來像額外的簿記咀嚼時間。並且where時間可能不會對a的長度敏感。

對於a2=np.random.randint(5000,size=100000)運行時間約爲2-3秒。

np.random.randint(50000,size=1000000)花費時間太長(對於任一版本)。


在進一步的實驗,使用collections.defaultdict一個 '啞' 的方法要快得多(20X):

def food(a): 
    d = defaultdict(list) 
    for i,j in enumerate(a): 
     d[j].append(i) 
    return d 

的 '過大'(1000000)陣列只需要1.1秒;

相關問題