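I am trying to implement multinomial logistic regression using gradient descent, but my cost function starts returning NaN values. Can someone tell me what I am doing wrong? This is the cost function that gives me NaN: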
function [ cost ] = costFunctionMultiNominal(inputX,resultY,weights)
%UNTITLED8 Calculates the cost for gradient descent, assumes inputX has one
%additional feature for constant and weights is a classes X features matrix
[rows,cols] = size(inputX);
numOfClasses = size(weights,1);
summation = 0;
for i=1:rows
    classLevelSummation = 0;
    for j=1:numOfClasses
        if resultY(i)==j
            denominatorSum = 0;
            for l=1:numOfClasses
                denominatorSum = denominatorSum + exp((inputX(i,:)*weights(l,:)')-4444);
            end
            % <-- the debugger stops on the next line (NaN appears here)
            classLevelSummation = classLevelSummation + log(exp(inputX(i,:)*weights(j,:)'-4444)/denominatorSum);
        end
    end
    summation = summation + classLevelSummation;
end
cost = summation/(-rows);
end
Here is the weight update function:
function [ Weights ] = getWeightsUsingGradientDescentMultiNominal(trainingX,resultY,iterMax,Alpha,weight0,lambda)
%Returns updated weights through gradient descent, weight0 are the initial randomized weights
% Detailed explanation goes here
rows = size(trainingX,1);
cols = size(trainingX,2)+1;
Weights = weight0;
numOfClasses = size(Weights,1);
%Adding ones to the input data for the constant terms
a = ones(rows,1);
X = [a trainingX];
%Each column corresponds to one weight, updating weights column wise:
%Also plot cost function simultaneously
tempCost = 0;
display(costFunctionMultiNominal(X,resultY,Weights));
plot(1,costFunctionMultiNominal(X,resultY,Weights),'r');
hold on;
for n=1:iterMax
    %Have to do this for all classes, i.e. rows in Weights
    for j = 1:numOfClasses
        %First calculating the sigma over rows for all X
        summation = zeros(1,cols);
        for i=1:rows
            p = -1 * calculatePofJMultiNominal(X(i,:),Weights,j);
            if resultY(i) == j
                p = 1 + p;
            end
            summation = summation + X(i,:)*p;
        end
        Weights(j,:) = Weights(j,:) - (Alpha)*(summation/(-rows) + lambda*Weights(j,:));
    end
    cost = costFunctionMultiNominal(X,resultY,Weights);
    display(cost);
    costDiff = tempCost - cost;
    %Use the iteration counter n here; i is left over from the inner loop
    if n~=1 && abs(costDiff)/cost <= 0.0001
        display('Breaking because of cost very less!');
        break;
    end
    tempCost = cost;
    plot(n,cost,'r');
end
hold off;
end
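The helper calculatePofJMultiNominal is not shown in the post. If it computes the class probability from raw exponentials, it could overflow in the same way as the cost function. Below is a minimal sketch of an overflow-safe version, assuming the helper takes one row of X, the weight matrix, and a class index, and returns the softmax probability of that class. The name softmaxProbOfClass and the sample values are hypothetical, not from the original code.

```matlab
% Hypothetical stand-in for the helper that the post does not show.
% x: 1 x features row, W: classes x features matrix, j: class index.
% Subtracting the largest score keeps every exponent at or below zero.
softmaxProbOfClass = @(x, W, j) ...
    exp(x * W(j, :)' - max(x * W')) / sum(exp(x * W' - max(x * W')));

% Illustrative inputs: a constant term plus two features, three classes.
x = [1, 800, -300];
W = [0.5, 1.0, 0.0;
     0.1, 0.9, 0.2;
     0.0, 0.0, 0.0];

% Probabilities for every class; they stay finite even with scores near 800.
probs = arrayfun(@(j) softmaxProbOfClass(x, W, j), 1:size(W, 1));
disp(probs);
disp(sum(probs));
```

The shift by the maximum score cancels between numerator and denominator, so the probabilities are mathematically unchanged.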
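As far as I understand, the NaN comes from very large numbers appearing in the exponent. I tried subtracting a large constant (4444) from the exponent, but it did not help.
I tried dbstop if naninf, and it tells me that execution stops at this line of the cost function (the line marked in the code above):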
classLevelSummation = classLevelSummation + log(exp(inputX(i,:)*weights(j,:)'-4444)/denominatorSum);
classLevelSummation becomes NaN, even after removing the large constant -4444.
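One possible explanation, offered as a sketch rather than a confirmed diagnosis of this data set: with a fixed shift such as 4444, every exp(...) underflows to 0 when the scores are moderate, so the ratio is 0/0 = NaN. Without the shift, large scores overflow to Inf, and Inf/Inf is also NaN. A common remedy is the log-sum-exp trick, which subtracts the largest score of each sample instead of a fixed constant and stays in log space. The variable names and sample values below are illustrative and do not come from the original code.

```matlab
% Hypothetical scores (inputX(i,:)*weights') for one sample, three classes.
scores = [2.0, 1.0, 0.5];
trueClass = 1;

% Fixed shift as in the post: every exp underflows to 0, giving 0/0 = NaN.
shifted = exp(scores - 4444);
naiveLogProb = log(shifted(trueClass) / sum(shifted));

% No shift with large scores: exp overflows to Inf, giving Inf/Inf = NaN.
bigScores = [1000, 999, 998];
overflowLogProb = log(exp(bigScores(trueClass)) / sum(exp(bigScores)));

% Log-sum-exp: shift by the maximum score and never form the ratio itself.
logSoftmax = @(s, k) (s(k) - max(s)) - log(sum(exp(s - max(s))));
stableLogProb = logSoftmax(scores, trueClass);
stableBigLogProb = logSoftmax(bigScores, trueClass);

disp([naiveLogProb, overflowLogProb, stableLogProb, stableBigLogProb]);
```

In the cost function, the same idea would mean computing the scores for row i once, taking their maximum, and accumulating the log-probability directly rather than log(exp(...)/denominatorSum).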
'DBSTOP IF NANINF' can help you find out exactly where the NaN comes from. – jez 2014-10-08 19:59:44
Tried it, please check the update. – Sudh 2014-10-08 20:28:43