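Timing Python code

I wrote the code below to understand multiprocessing (MP) and the speed gain it might give over a non-MP version. Apart from the highlighted (marked) places, the two functions are almost identical (sorry, I don't know a better way to highlight regions of code).

The code tries to identify the indices of redundant entries in a list of arrays (1-D here). The id lists returned by the two functions are identical, but my question is about the difference in timing. As you can see, I have timed a) the map call, b) the list extend, and c) the whole while loop in both cases. Compared with the non-MP version, MP gives a better speed-up in the map region, but its redun_ids.extend(...) is slower. This is what drags down the overall speed gain of the MP version.

Is there any way to rewrite the redun_ids.extend(...) part of the MP version so that it takes the same time as in the non-MP version?
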
#!/usr/bin/python
import multiprocessing as mproc
import sys
import numpy as np
import random
import time

def matdist(mats):
    # mats is a (mat1, mat2) pair; True if the two arrays are numerically equal
    mat1 = mats[0]
    mat2 = mats[1]
    return np.allclose(mat1, mat2, rtol=1e-08, atol=1e-12)

def mp_remove_redundancy(larrays):
    """
    remove_redundancy : identify arrays that are redundant in the
    input list of arrays
    """
    llen = len(larrays)
    redun_ids = list()
    templist = list()
    i = 0
    pool = mproc.Pool(processes=10)  # highlighted: MP only
    st1=time.time()
    while 1:
        currarray = larrays[i]
        if i not in redun_ids:
            templist.append(currarray)
            #replication to create list of arrays
            templist = templist*(llen-i-1)
            # highlighted: MP only (next three lines)
            chunksize = len(templist)/10
            if chunksize == 0:
                chunksize = 1
            #clslist is a result object here
            st=time.time()
            # highlighted: MP only
            clslist = pool.map_async(matdist, zip(larrays[i+1:],
                                     templist), chunksize)
            print 'map time:', time.time()-st
            outlist = clslist.get()[:]  # highlighted: MP only
            #j+1+i gives the actual id num w.r.t to whole list
            st=time.time()
            redun_ids.extend([j+1+i for j, x in
                              enumerate(outlist) if x])
            print 'Redun ids extend time:', time.time()-st
        i = i + 1
        del templist[:]
        del outlist[:]
        if i == (llen - 1):
            break
    print 'Time elapsed in MP:', time.time()-st1
    pool.close()
    pool.join()
    del clslist
    del templist
    return redun_ids[:]

#######################################################
def remove_redundancy(larrays):
    llen = len(larrays)
    redun_ids = list()
    clslist = list()
    templist = list()
    i = 0
    st1=time.time()
    while 1:
        currarray = larrays[i]
        if i not in redun_ids:
            templist.append(currarray)
            templist = templist*(llen-i-1)
            st = time.time()
            clslist = map(matdist, zip(larrays[i+1:],
                          templist))
            print 'map time:', time.time()-st
            #j+1+i gives the actual id num w.r.t to whole list
            st=time.time()
            redun_ids.extend([j+1+i for j, x in
                              enumerate(clslist) if x])
            print 'Redun ids extend time:', time.time()-st
        i = i + 1
        #clear temp vars
        del clslist[:]
        del templist[:]
        if i == (llen - 1):
            break
    print 'Tot non MP time:', time.time()-st1
    del clslist
    del templist
    return redun_ids[:]

###################################################################
if __name__=='__main__':
    if len(sys.argv) != 2:
        sys.exit('# entries')
    llen = int(sys.argv[1])
    #generate random numbers between 1 and 10
    mylist=[np.array([round(random.random()*9+1)]) for i in range(llen)]
    print 'The input list'
    print 'no MP'
    rrlist = remove_redundancy(mylist)
    print 'MP'
    rrmplist = mp_remove_redundancy(mylist)
    print 'Two lists match : {0}'.format(rrlist==rrmplist)

I think this explains the doubt I expressed in my own answer. :) – 2012-08-17 08:59:06
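
One thing worth checking when comparing the timings: in the MP function the timer around pool.map_async(...) stops before clslist.get() is called. map_async returns a result object straight away and get() is the call that blocks until the workers finish, so the printed 'map time' for the MP version may cover only task submission, and the wait itself is not timed at all. Below is a minimal sketch of timing the two steps separately. It does not settle why the extend step is slower. The names timed_map_async and slow_square are hypothetical, introduced only for illustration, and the sketch uses the thread-backed ThreadPool (same map_async interface) so that it stays self-contained; for the real case, pass the mproc.Pool and matdist from the question instead.

```python
import time
from multiprocessing.pool import ThreadPool


def timed_map_async(pool, func, items, chunksize=1):
    """Hypothetical helper: time task submission and result retrieval separately."""
    start = time.time()
    # map_async returns an AsyncResult immediately; the work runs in the background.
    result = pool.map_async(func, items, chunksize)
    dispatch_time = time.time() - start

    start = time.time()
    # get() blocks until every task has finished and returns the list of results.
    outlist = result.get()
    wait_time = time.time() - start
    return outlist, dispatch_time, wait_time


def slow_square(x):
    # Stand-in for matdist (an assumption): sleeps so the work takes measurable time.
    time.sleep(0.05)
    return x * x


if __name__ == "__main__":
    pool = ThreadPool(processes=2)
    outlist, dispatch_time, wait_time = timed_map_async(pool, slow_square, range(4))
    pool.close()
    pool.join()
    print("dispatch: %.4f s, get: %.4f s" % (dispatch_time, wait_time))
```

For a like-for-like comparison with the non-MP map(...), the sum of the two times (submission plus get) is probably the number to look at.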