2013-10-21 128 views
1

我想分析很多讓我的程序運行緩慢的數據。 我正在讀取從.txt文件到單元格數組的數據集。 我正在使用單元格數組來分類我的數據,這是兩個屬性的形式,我需要這些類是字符。提高嵌套循環的性能MATLAB

我想使用最近的均值分類器找到重新排列錯誤。 我有一個主要的外部循環,它遍歷我的數據集的每一行(數以萬計)。依次移除每一行,每次迭代一行。在刪除線條的每次迭代中重新計算兩個屬性的平均值。主掛點似乎成爲下一個部分,在那裏我需要計算在我的數據集的每一行:

  • 的數據之間上線(2個屬性值)和 的歐幾里得距離各自的均值我類。
  • 然後我想記錄其屬性平均值最接近的類,這將是它的分配類。
  • 最後,我想檢查這個分配的類是否是正確的 類。

目前這個循環看起來像這樣。

errorCount = 0; 
for l = 1:20000 
    closest = 100; 
    class = 0; 
    attribute1 = d{2}(l); 
    attribute2 = d{3}(l); 
    for m = 1:numel(classes) 
     dist = sqrt((attribute1-meansattr1(m))*(attribute1-meansattr1(m)) + (attribute2-meansattr2(m))*(attribute2-meansattr2(m))); 
     if dist < closest 
      closest = dist; 
      class = m; 
     end 
    end 

    if strcmp(d{1}(l),classes(class)) 
     %correct 
    else 
     errorCount = errorCount + 1; 
    end 
end 

d是我的細胞陣列,其中d{2}是保持我的屬性1值的列。我通過d{1}(1)獲取了該列中第一行的這些值。

classes是我的數據集中的獨特類,所以對於我的每個類,我計算它的歐幾里得距離。

meansattr1meansattr2是包含我的每個屬性的平均值的數組。當線被移除時,這些更新在外部循環的每次迭代中。

希望能幫助您理解我擁有的代碼。非常感謝在優化和加速這些計算方面的任何幫助。

+2

最簡單的速度改進是刪除'sqrt'調用。查找最近距離的平方與最近距離完全相同。 – paddy

回答

1

您基本上正在優化k-means算法的迭代部分,因此您可以參考my previous solution來獲得向量化的方法。但是,這裏是你如何處理你的問題和數據格式。

採取隨機數據集,像下面,

numClasses = 5; 
numPoints = 20e3; 
numDims = 2; 

classes = strsplit(num2str(1:numClasses)); 

% generate random data (expected error rate of (numClasses-1)/numClasses) 
d{1} = classes(randi(numClasses,numPoints,1)); 
d{2} = rand(numPoints,1); 
d{3} = rand(numPoints,1); 

% random initial class centers 
meansattr1 = rand(5,1); 
meansattr2 = rand(5,1); 

您的代碼,壓縮和存儲每個點的最接近的類ID及該類的距離:

closestDistance = zeros(numPoints,1); nearestCluster = zeros(numPoints,1); 
errorCount = 0; 
for l = 1:numPoints 
    closest = 100; iclass = 0; 
    attribute1 = d{2}(l); attribute2 = d{3}(l); 

    for m = 1:numel(classes) 
     dist = sqrt((attribute1-meansattr1(m))*(attribute1-meansattr1(m)) + ... 
      (attribute2-meansattr2(m))*(attribute2-meansattr2(m))); 
     if dist < closest 
      closest = dist; closestDistance(l) = closest; 
      iclass = m; nearestCluster(l) = iclass; 
     end 
    end 

    if ~strcmp(d{1}(l),classes(iclass)) 
     errorCount = errorCount + 1; 
    end 
end 

的矢量版本那麼上面是:

data = [d{2}(:) d{3}(:)]; 
meansattr = [meansattr1(:) meansattr2(:)]; 

kdiffs = bsxfun(@minus,data,permute(meansattr,[3 2 1])); 

allDistances = sqrt(sum(kdiffs.^2,2)); % no need to do sqrt 
allDistances = squeeze(allDistances); % Nx1xK => NxK 

[closestDistance,nearestCluster] = min(allDistances,[],2); % Nx1 

correctClassIds = str2num(char(d{1}(:))); 
errorCount = nnz(nearestCluster ~= correctClassIds); 

結果在errorCountclosestDistancenearestCluster等同於以前的解決方案。如代碼註釋所示,您可以刪除sqrt並在errorCountnearestCluster中獲得相同的結果。

說你想要做的更新meansattr1meansattr2下一步:

% Calculate the NEW cluster centers (mean the data) 
meansattr_new = zeros(numClasses,numDims); 
clustersizes = zeros(numClasses,1); 
for ii=1:numClasses, 
    indk = nearestCluster==ii; 
    clustersizes(ii) = nnz(indk); 
    meansattr_new(ii,:) = mean(data(indk,:))'; 
end 

meansattr1_next = meansattr_new(:,1); 
meansattr2_next = meansattr_new(:,2); 

把這一切都在while errorCount>THRESHfor jj = 1:MAXITER,你應該有你所追求的。

+1

謝謝,這樣做會有明顯的性能提升。我將不得不看看我的程序的其他方面,看看我可以做些類似的改進。 –

2

最簡單的速度改進是刪除sqrt呼叫。查找最近距離的平方與最近距離完全相同。

接下來,您可以矢量化內部循環。自從我做了任何MatLab以來,這已經很長時間了,所以我可能會弄錯下面的代碼,但是想法是將這兩個屬性變成長度爲numel(classes)的向量。然後,您可以直接計算差異並對其進行平方。

事情是這樣的:

d1 = attribute1 - meansattr1; 
d2 = attribute2 - meansattr2; 
[closest, class] = min(d1 .* d1 + d2 .* d2); 

順便說一句,這不是用class作爲變量一個偉大的想法(如果你甚至可以)。這是一個保留字。

+0

'距離最近的= strt(最近的)'在使用距離時丟失,變量包含平方距離。 – Daniel

+0

當然,但原始代碼並未顯示使用的地方。在所有循環之後很容易取得'sqrt' *。 – paddy

0

我開始用稻穀的解決方案,簡單的更換變量名稱:

[closest, cl] = min((d{2}(m) - meansattr1).^2 +(d{3}(m) - meansattr2).^ 
2); 

因此,我們有一個線環,共同的策略:做它的一個功能,並把它變成arrayfun:

[email protected](x)min((d{2}(x) - meansattr1).^2 +(d{3}(x) - meansattr2).^2) 
[sqclosest,cl]=arrayfun(f,1:numel(d{2})); 

%If necessary real distances could be calculated: 
%closest=sqrt(sqclosest) 

errorCount=sum(arrayfun(@(x,c)(1-strcmp(x,classes(c))),d{1},cl)) 

注意:請勿將「class」或任何其他保留字用於其他目的。