2017-05-09 96 views

PyTorch network validation code not working

I want to train a simple MLP to approximate y = f(a, b, c). My code is below.

import math  # needed for math.sqrt in the gradient-norm computation
import torch
import torch.nn as nn
from torch.autograd import Variable
# hyper parameters 
input_size = 3 
output_size = 1 
num_epochs = 50 
learning_rate = 0.001 

# Network definition
class FeedForwardNet(nn.Module):
    def __init__(self, l1_size, l2_size):
        super(FeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(input_size, l1_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(l1_size, l2_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(l2_size, output_size)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        return out

model = FeedForwardNet(5 , 3) 

# sgd optimizer 
optimizer = torch.optim.SGD(model.parameters(), learning_rate, momentum=0.9) 

# trainX_light / trainY_light are NumPy arrays loaded elsewhere (not shown)
for epoch in range(11):
    print('Epoch ', epoch)
    for i in range(trainX_light.shape[0]):
        X = Variable(torch.from_numpy(trainX_light[i]).view(-1, 3))
        Y = Variable(torch.from_numpy(trainY_light[i]).view(-1, 1))
        # forward
        optimizer.zero_grad()
        output = model(X)

        loss = (Y - output).pow(2).sum()
        print(output.data[0,0])
        loss.backward()
        optimizer.step()
        # total L2 norm of the gradients over all parameters
        totalnorm = 0
        for p in model.parameters():
            modulenorm = p.grad.data.norm()
            totalnorm += modulenorm ** 2
        totalnorm = math.sqrt(totalnorm)

        print(totalnorm)

    # validation code (runs every 5th epoch)
    if (epoch + 1) % 5 == 0:
        print(' test points', testX_light.shape[0])
        total_loss = 0
        for t in range(testX_light.shape[0]):
            X = Variable(torch.from_numpy(testX_light[t]).view(-1, 3))
            Y = Variable(torch.from_numpy(testY_light[t]).view(-1, 1))
            output = model(X)
            loss = (Y - output).pow(2).sum()
            print(output.data[0,0])
            total_loss += loss
        print('epoch ', epoch, 'avg_loss ', total_loss.data[0]/testX_light.shape[0])

print ('Done') 

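The problem I have now is in the validation code:

output = model(X)

It always produces exactly the same output value (I assume this value is some kind of garbage). I am not sure what I did wrong in this part. Can someone help me find the error in my code?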

Does your input have only 3 dimensions? – Kashyap

@Kashyap Yes, only 3 dimensions. – Arul

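Answer

Answering my own question. The network produced random values (and later inf) because of the exploding gradient problem. Clipping the gradients (torch.nn.utils.clip_grad_norm(model.parameters(), 0.1)) helped.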
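For anyone wondering what the clipping step does: below is a minimal pure-Python sketch of gradient clipping by global L2 norm, the operation that clip_grad_norm is meant to perform. It is an illustration and not PyTorch's actual code; the helper name clip_by_global_norm, the eps value, and the sample gradient values are all assumptions made up for this example. In the training loop from the question, the real call would go between loss.backward() and optimizer.step().

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale gradients so their joint L2 norm is at most max_norm.

    grads: list of lists of floats, one flattened list per parameter.
    Returns (clipped_grads, total_norm_before_clipping).
    Hypothetical helper for illustration only.
    """
    # Joint norm over every gradient entry of every parameter.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = max_norm / (total_norm + eps)  # eps guards against division by zero
    if scale >= 1.0:
        # Already within the limit: leave the gradients unchanged.
        return [list(grad) for grad in grads], total_norm
    # Shrink every entry by the same factor, which preserves the direction.
    return [[g * scale for g in grad] for grad in grads], total_norm

# Made-up "exploding" gradients for two parameters.
grads = [[300.0, -400.0], [0.0, 1200.0]]
clipped, norm_before = clip_by_global_norm(grads, max_norm=0.1)
norm_after = math.sqrt(sum(g * g for grad in clipped for g in grad))
print(norm_before, norm_after)

# Small gradients pass through untouched.
small, small_norm = clip_by_global_norm([[0.03, 0.04]], max_norm=0.1)
print(small, small_norm)
```

Because every gradient is scaled by the same factor, the update direction stays the same and only the step size is bounded, which is why this can stop the weights from blowing up to inf.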

Did you try a smaller learning rate before clipping the gradients? – Kashyap

Yes, I did. I lowered it from 0.001 to 0.000. Maybe I can try lowering it even further. – Arul