2017-06-20 161 views

回答

0

這不是最漂亮的一段代碼,但這是我查閱 PyTorch 論壇和文檔之後整理出來供自己使用的。處理「排序 - 還原」這部分肯定有更好的方法,但我選擇把它放在網絡本身之中

class Encoder(nn.Module):
    """Embed padded token-id sequences and encode them with a GRU/LSTM.

    The forward pass sorts the batch by length (as required for packing),
    runs the recurrent layer on the packed batch, picks the output at the
    last valid timestep of every sequence, and finally restores the rows to
    the caller's original batch order.
    """

    def __init__(self, vocab_size, embedding_size, embedding_vectors=None, tune_embeddings=True, use_gru=True,
                 hidden_size=128, num_layers=1, bidrectional=True, dropout=0.6):
        """Build the embedding table and the recurrent layer.

        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_size: Dimensionality of each embedding vector.
            embedding_vectors: Optional pretrained matrix of shape
                ``(vocab_size, embedding_size)`` used to initialise the table.
            tune_embeddings: Whether the embedding weights receive gradients.
            use_gru: Use ``nn.GRU`` when true, ``nn.LSTM`` otherwise.
            hidden_size: Hidden size of the recurrent layer (per direction).
            num_layers: Number of stacked recurrent layers.
            bidrectional: Whether the recurrent layer is bidirectional.
                (The misspelling is kept so existing keyword callers work.)
            dropout: Dropout value forwarded to the recurrent layer.
        """
        super(Encoder, self).__init__()
        # Index 0 is reserved for padding tokens.
        self.embed = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        if embedding_vectors is not None:
            assert embedding_vectors.shape[0] == vocab_size and embedding_vectors.shape[1] == embedding_size
            self.embed.weight = nn.Parameter(torch.FloatTensor(embedding_vectors))
        # Must happen AFTER the weight is (possibly) replaced: a freshly
        # created nn.Parameter defaults to requires_grad=True, which would
        # otherwise silently discard tune_embeddings=False.
        self.embed.weight.requires_grad = tune_embeddings
        cell = nn.GRU if use_gru else nn.LSTM
        # Honour the caller's direction choice instead of hard-coding True.
        self.rnn = cell(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                        batch_first=True, bidirectional=bidrectional, dropout=dropout)

    def forward(self, x, x_lengths):
        """Encode a padded batch and return one vector per sequence.

        Args:
            x: Integer tensor of token ids, indexed as ``(batch, time)``.
            x_lengths: Valid (unpadded) length of every row of ``x``.

        Returns:
            Tensor of shape ``(batch, num_directions * hidden_size)`` holding
            the recurrent output at each sequence's last valid timestep, in
            the same row order as ``x``.
        """
        # Lengths stay on the CPU: they are only used for sorting/packing.
        lengths = torch.as_tensor(x_lengths, dtype=torch.long, device="cpu")
        sorted_lengths, sort_order = torch.sort(lengths, dim=0, descending=True)
        # Index tensors must live on the same device as the data they index.
        sort_order = sort_order.to(x.device)

        embedded = self.embed(x[sort_order])
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, sorted_lengths.tolist(), batch_first=True)
        packed_out, _ = self.rnn(packed)
        unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Build a (batch, 1, features) index selecting timestep len-1 per row.
        # NOTE: for a bidirectional layer this position holds the final
        # forward state but only the first step of the backward direction.
        last_step = torch.as_tensor(unpacked_len, dtype=torch.long, device=unpacked.device) - 1
        gather_index = last_step.view(-1, 1, 1).expand(-1, 1, unpacked.size(2))
        last_encoded_states = unpacked.gather(dim=1, index=gather_index).squeeze(dim=1)

        # Undo the length sort: sorting the permutation yields its inverse,
        # so row k of the result is the encoding of the caller's row k.
        _, inverse_order = torch.sort(sort_order, dim=0)
        return last_encoded_states[inverse_order]
+0

剛剛意識到我回答了一個 6 個月前的問題 =( – chiragjn