2013-11-20 55 views
3

我在sklearn.cluster模塊(here are the docs)中運行一個名爲MeanShift()的聚類算法。我正在處理的對象在三維空間中分佈有310,057個點。我使用的計算機總共有128Gb的內存,所以當我遇到以下錯誤時,我很難相信我實際上正在使用它。Python MeanShift內存錯誤

[[email protected] ~]$ python meanshifttest.py 
Traceback (most recent call last): 
    File "meanshifttest.py", line 13, in <module> 
    ms = MeanShift().fit(X) 
    File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 280, in fit 
    cluster_all=self.cluster_all) 
    File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 99, in mean_shift 
bandwidth = estimate_bandwidth(X) 
    File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 45, in estimate_bandwidth 
d, _ = nbrs.kneighbors(X, return_distance=True) 
    File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/neighbors/base.py", line 313, in kneighbors 
return_distance=return_distance) 
    File "binary_tree.pxi", line 1313, in sklearn.neighbors.kd_tree.BinaryTree.query (sklearn/neighbors/kd_tree.c:10007) 
    File "binary_tree.pxi", line 595, in sklearn.neighbors.kd_tree.NeighborsHeap.__init__ (sklearn/neighbors/kd_tree.c:4709) 
MemoryError 

我跑看起來像這樣的代碼:

from sklearn.cluster import MeanShift 
import asciitable 
import numpy as np 
import time 

data = asciitable.read('./multidark_MDR1_FOFID85000000000_ParticlePos.csv',delimiter=',') 
x = [data[i][2] for i in range(len(data))] 
y = [data[i][3] for i in range(len(data))] 
z = [data[i][4] for i in range(len(data))] 
X = np.array(zip(x,y,z)) 

t0 = time.time() 
ms = MeanShift().fit(X) 
t1 = time.time() 
print str(t1-t0) + " seconds." 
labels = ms.labels_ 
print set(labels) 

會有人有發生的事情什麼想法?不幸的是我不能改變聚類算法,因爲這是我發現的唯一一個除了不接受鏈接長度/ k個數量的聚類/先驗信息外還做得很好的聚類算法。

在此先感謝!

**更新: 我看着文檔的詳細一點,它說以下內容:

可擴展性:

由於這種實現使用平面內核和
一球樹查找每個內核的成員,複雜度將爲
到更低維度的O(T * n * log(n)),其中n爲樣本數
和T的點數。在更高維度中,複雜度將趨向於O(T * n^2)。

通過使用較少的種子可以提高可伸縮性,例如,通過在get_bin_seeds函數中使用較高的min_bin_freq值,可以使用

請注意,estimate_bandwidth函數的可擴展性遠低於均值偏移算法,並且如果使用它將會是瓶頸。

這似乎使一些感覺,因爲如果你詳細看一下錯誤時被抱怨estimate_bandwidth。這是否表明我只是爲算法使用了太多的粒子?

+0

什麼是內存監視器,如「top」或「free」顯示你? (在頂部,按常駐內存排序:按S然後按Q.) – 9000

+0

是的,所以我這樣做了,而且我只使用了總內存(即128Gb)的0.2%。它也幾乎是瞬間失敗 - 這表明它是別的東西。我看不出它如此快速地使用那麼多的RAM。 – astromax

+0

你有沒有嘗試以某種方式減少問題的大小?是否有一個已知的情況,一切正常,您可以測量使用的內存量? – 9000

回答

3

從錯誤信息判斷,我懷疑它試圖計算點之間的所有成對距離,這意味着它需要310057²浮點數或716GB的RAM。

您可以通過爲MeanShift構造函數提供一個明確的bandwidth參數來禁用此行爲。

這可以說是一個錯誤;考慮爲它提交一份錯誤報告。 (該scikit學習機組人員,其中包括我自己,最近一直在努力擺脫在不同的地方,這些過於昂貴的距離計算的,但顯然沒有人看着均值漂移)。

編輯:上述的計算是3倍,但內存使用量確實是二次的。我剛剛在scikit-learn的開發版本中解決了這個問題。

+0

非常感謝這篇文章。這真的是真的需要這麼多的RAM來運行這個算法嗎?我已經做了大約30,000個粒子,我可以發誓說它在4GB內存的不同工作計算機上工作。我會用sklearn提交bug。 – astromax

+1

@astromax:這個計算按照當前實現的方式創建一個n2 64位浮點數的數組。 30k²= 9e8,次數8使得6.7GB,所以可能是可行的,但二次空間增長*快*。 –

+0

哎呀。是否有非n平方的方法會進入代碼? – astromax