代碼來自輸入LogSoftMax代碼:https://github.com/torch/nn/blob/master/lib/THNN/generic/LogSoftMax.c試圖瞭解其計算漸變WRT在火炬
我沒有看到這個代碼是如何計算梯度w.r.t到輸入模塊LogSoftMax。我感到困惑的是兩個for循環正在做什麼。
for (t = 0; t < nframe; t++)
{
sum = 0;
gradInput_data = gradInput_data0 + dim*t;
output_data = output_data0 + dim*t;
gradOutput_data = gradOutput_data0 + dim*t;
for (d = 0; d < dim; d++)
sum += gradOutput_data[d];
for (d = 0; d < dim; d++)
gradInput_data[d] = gradOutput_data[d] - exp(output_data[d])*sum;
}
}