2017-10-10 79 views
12

以下代碼將for循環並行化。如何在python joblib中寫入共享變量

import networkx as nx; 
import numpy as np; 
from joblib import Parallel, delayed; 
import multiprocessing; 

def core_func(repeat_index, G, numpy_arrary_2D): 
    for u in G.nodes(): 
    numpy_arrary_2D[repeat_index][u] = 2; 
    return; 

if __name__ == "__main__": 
    G = nx.erdos_renyi_graph(100000,0.99); 
    nRepeat = 5000; 
    numpy_array = np.zeros([nRepeat,G.number_of_nodes()]); 
    Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat)); 
    print(np.mean(numpy_array)); 

如可以看到的,要打印的預期值是2。然而,當我運行上的簇(多核,共享存儲器)碼,它返回0.0。

我認爲問題在於每個工作人員都創建自己的numpy_array對象副本,並且在主函數中創建的副本不會更新。我怎樣才能修改代碼,使numpy數組numpy_array可以更新?

+0

那麼,你有沒有決定答案? ;-) –

回答

3

joblib使用過程默認的多處理池,its manual說:

引擎蓋下,並行對象創建一個多處理池 叉在多個進程Python解釋執行每個的 列表中的項目。延遲函數是一個簡單的技巧,它可以通過函數調用 語法創建一個元組(函數,參數,kwargs)。

這意味着,每個進程都繼承了數組的原始狀態,但無論它寫入內存中的哪個進程都會在進程退出時丟失。只有函數結果被傳遞迴調用(主)進程。但是你不返回任何東西,所以返回None

要使共享數組變得更加可修改,您有兩種方法:使用線程和使用共享內存。


線程與進程不同,共享內存。所以你可以寫入數組,每個工作都會看到這個變化。當你運行它

Parallel(n_jobs=4, backend="threading")(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat)); 

$ python r1.py 
2.0 

然而,當你會寫複雜的事情到陣列中,請務必妥善處理按照joblib手冊,它是做這樣鎖定數據或數據片段,或者您將遇到競爭條件(谷歌它)。

還仔細閱讀有關GIL,因爲Python中的計算多線程是有限的(不像I/O多線程)。


如果您仍然需要進程(例如因爲GIL),您可以將該數組放入共享內存中。

這是一個比較複雜的話題,但joblib + numpy shared memory example也顯示在joblib手冊中。

0

正如Sergey在他的回答中寫道的,流程不共享狀態和內存。這就是爲什麼你沒有看到預期的答案。

線程共享狀態和內存空間,因爲它們在同一進程下運行。如果您有很多I/O操作,這很有用。它不會讓你更多的處理能力(更多CPU),因爲進程間通信的GIL

一種技術的使用是經理代理對象。您創建一個管理器對象,用於在進程之間同步資源。

Manager()返回的管理器對象控制一個服務器進程,該進程持有Python對象並允許其他進程使用代理來操縱它們。

我沒有測試此代碼(我沒有你使用的模塊),它可能需要更多的修改代碼,但使用Manager對象就應該是這樣的

if __name__ == "__main__": 
    G = nx.erdos_renyi_graph(100000,0.99); 
    nRepeat = 5000; 

    manager = multiprocessing.Manager() 
    numpys = manager.list(np.zeros([nRepeat, G.number_of_nodes()]) 

    Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpys, que) for repeat_index in range(nRepeat)); 
    print(np.mean(numpys)); 
+0

這裏的數據結構在語義上是浮點數列表(矩陣/表),但實際上是numpy.float64值的numpy.array的'numpy.array'的一個實例。在通過默認管理器同步這些自定義數據類型時會遇到許多麻煩,默認管理器只支持少量標量值,本地列表和字典。 –