2014-03-13 94 views
1

我有兩個類型的字典:的Python:計算兩個類型的字典的餘弦相似度更快

d1 = {1234: 4, 125: 7, ...} 
d2 = {1234: 8, 1288: 5, ...} 

http://stardict.sourceforge.net/Dictionaries.php下載的長度爲10至40000。變化要計算我使用此功能的餘弦相似性:

from scipy.linalg import norm 
def simple_cosine_sim(a, b): 
    if len(b) < len(a): 
     a, b = b, a 

    res = 0 
    for key, a_value in a.iteritems(): 
     res += a_value * b.get(key, 0) 
    if res == 0: 
     return 0 

    try: 
     res = res/norm(a.values())/norm(b.values()) 
    except ZeroDivisionError: 
     res = 0 
    return res 

可以更快地計算相似度嗎?

UPD:使用Cython +重寫代碼+速度提高15%。感謝@Davidmh

from scipy.linalg import norm 

def fast_cosine_sim(a, b): 
    if len(b) < len(a): 
     a, b = b, a 

    cdef long up, key 
    cdef int a_value, b_value 

    up = 0 
    for key, a_value in a.iteritems(): 
     b_value = b.get(key, 0) 
     up += a_value * b_value 
    if up == 0: 
     return 0 
    return up/norm(a.values())/norm(b.values()) 
+0

我已經評論了你用Cython代碼,增加了一種替代方法。我希望這有幫助。 – Davidmh

回答

1

如果索引不是太高,可以將每個字典轉換爲數組。如果它們非常大,則可以使用稀疏數組。那麼,餘弦相似性只會使它們兩者相乘。如果您需要重複使用同一個字典進行多次計算,則此方法的性能最佳。

如果這不是一個選項,只要您註釋a_value和b_value,Cython應該是非常快的。

編輯: 看看你的Cython重寫,我看到了一些改進。第一件事是做一個cython -a來生成彙編的HTML報告,看看哪些事情已經加速,哪些沒有。首先,你定義「up」爲止,但是你總結了整數。另外,在你的例子中,鍵是整數,但是你將它們聲明爲double。另一個簡單的事情是將輸入鍵入爲字符串。

此外,檢查C代碼,似乎有一些沒有檢查,您可以通過使用@ cython.nonechecks(False)禁用。

實際上,字典的實現是非常有效的,所以在一般情況下,你可能不會比這更好。如果您需要擠壓最出你的代碼,也許是值得的C API替換一些電話:http://docs.python.org/2/c-api/dict.html

cpython.PyDict_GetItem(a, key) 

但是,你將負責引用計數和的PyObject *鑄造爲int的一個可疑的表現收益。

任何方式,代碼的開頭是這樣的:

cimport cython 

@cython.nonecheck(False) 
@cython.cdivision(True) 
def fast_cosine_sim(dict a, dict b): 
    if len(b) < len(a): 
     a, b = b, a 

    cdef int up, key 
    cdef int a_value, b_value 

還有另一個問題:是你dicionaries大?因爲如果它們不是,規範的計算實際上可能是一個重要的開銷。

編輯2: 另一種可能的方法是隻查看必要的鍵。說:

from scipy.linalg import norm 
cimport cython 

@cython.nonecheck(False) 
@cython.cdivision(True) 
def fast_cosine_sim(dict a, dict b): 
    cdef int up, key 
    cdef int a_value, b_value 

    up = 0 
    for key in set(a.keys()).intersection(b.keys()): 
     a_value = a[key] 
     b_value = b[key] 
     up += a_value * b_value 
    if up == 0: 
     return 0 
    return up/norm(a.values())/norm(b.values()) 

這在Cython中非常高效。實際的表現可能取決於鍵之間有多少重疊。

+0

該詞典可以包含40000多個項目。因此,將它們轉換爲一個集合並找到交點不會很快。鑰匙是「長」型。並且'a_value * b_value'的總和可以大於int值。我認爲Cython不能自動轉換類型(比如Python),這就是爲什麼我把'up'定義爲'long'的原因。 –

1

從算法的角度來看,沒有。你已經處於複雜的O(N)。雖然有一些計算技巧可以使用。

您可以使用多處理模塊將a_value * b.get(key, 0)乘法調度給幾個工人,從而利用您擁有的所有機器核心。請注意,您將不會使用線程獲得此效果,因爲Python具有全局解釋器鎖定。

最簡單的方法是使用池對象的multiproccess.Poolmap方法。

我強烈建議使用Python內置的cProfiler來檢查代碼中的熱點。這很容易。只要運行:

python -m cProfile myscript.py

+0

問題是''simple_cosine_sim'從'multiproccess.Pool'中的map中運行的函數調用:) –

+1

去核心?剖析代碼並檢查乘法是否是熱點。如果確實如此,你可以使用PyCUDA。 – rafgoncalves

+0

我會朝這個方向看 –