2017-08-19 61 views
0

我試圖在Pytorch中創建一個基本的二進制分類器,該分類器可以分類我的玩家在遊戲Pong中的右側還是左側。輸入是一個1x42x42的圖像,標籤是我玩家的一面(右= 1或左= 2)。代碼:PyTorch RuntimeError:斷言`cur_target> = 0 && cur_target <n_classes'失敗

class Net(nn.Module): 
    def __init__(self, input_size, hidden_size, num_classes): 
     super(Net, self).__init__() 
     self.fc1 = nn.Linear(input_size, hidden_size) 
     self.relu = nn.ReLU() 
     self.fc2 = nn.Linear(hidden_size, num_classes) 

    def forward(self, x): 
     out = self.fc1(x) 
     out = self.relu(out) 
     out = self.fc2(out) 
     return out 

net = Net(42 * 42, 100, 2) 

# Loss and Optimizer 
criterion = nn.CrossEntropyLoss() 
optimizer_net = torch.optim.Adam(net.parameters(), 0.001) 
net.train() 

while True: 
    state = get_game_img() 
    state = torch.from_numpy(state) 

    # right = 1, left = 2 
    current_side = get_player_side() 
    target = torch.LongTensor(current_side) 
    x = Variable(state.view(-1, 42 * 42)) 
    y = Variable(target) 
    optimizer_net.zero_grad() 
    y_pred = net(x) 
    loss = criterion(y_pred, y) 
    loss.backward() 
    optimizer.step() 

的錯誤,我得到:

File "train.py", line 109, in train 
    loss = criterion(y_pred, y) 
    File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 206, in __call__ 
    result = self.forward(*input, **kwargs) 
    File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/loss.py", line 321, in forward 
    self.weight, self.size_average) 
    File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 533, in cross_entropy 
    return nll_loss(log_softmax(input), target, weight, size_average) 
    File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 501, in nll_loss 
    return f(input, target) 
    File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forward 
    output, *self.additional_args) 
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at /py/conda-bld/pytorch_1493676237139/work/torch/lib/THNN/generic/ClassNLLCriterion.c:57 

回答

0

它看起來像PyTorch預計(在你的情況下0/1)得到從零開始的標籤,你可能有給它一個基於標籤(1/2)

+0

是的,我知道但是當我嘗試它時,它不起作用。無論如何,我發現了問題並將其改回。現在,由於某種原因,即使標籤始終爲1,模型也會返回值<0.5。 –

+0

解決了問題:) –

+0

Shani,您是如何最終解決的?我已經檢查過標籤是否爲0,並確保網絡輸出與我的標籤尺寸相匹配,但仍然出現此錯誤。 –

0

對於大多數深度學習庫,目標(或標籤)的應該開始從0

這意味着你的目標應該是在[0,N)用正類的範圍內。