Hidden units saturate in a seq2seq model in PyTorch

I am trying to write a very simple machine-translation toy example in PyTorch. To keep the problem simple, I reduce the machine-translation task to this one: given a random sequence ([4, 8, 9, ...]), predict the sequence whose elements are each incremented by 1 ([5, 9, 10, ...]). Ids 0, 1, 2 are used as pad, bos, eos, respectively.
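For example, a single training pair looks roughly like this (the names are only for illustration and mirror the load_data function in the code below):

src     = [4, 8, 9]       # random input sequence
tgt_in  = [1, 5, 9, 10]   # decoder input:  bos + (src + 1)
tgt_out = [5, 9, 10, 2]   # decoder target: (src + 1) + eos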
I hit the same problem in this toy task as in my real machine-translation task. For debugging, I used a very small data size, n_data = 50, and found that the model cannot even overfit these data. Looking into the model, I found that the hidden states of the encoder/decoder saturate very quickly, i.e. because of tanh all the hidden states become very close to 1/-1:
-0.8987 0.9634 0.9993 ... -0.8930 -0.4822 -0.9960
-0.9673 1.0000 -0.8007 ... 0.9929 -0.9992 0.9990
-0.9457 0.9290 -0.9260 ... -0.9932 0.9851 0.9980
... ⋱ ...
-0.9995 0.9997 -0.9350 ... -0.9820 -0.9942 -0.9913
-0.9951 0.9488 -0.8894 ... -0.9842 -0.9895 -0.9116
-0.9991 0.9769 -0.5871 ... 0.7557 0.9049 0.9881
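One simple way to see this, for example (reusing enc, x and lenx from the training loop below; the 0.99 cut-off is arbitrary):

_, h_enc = enc(x, enc.init_hidden(x.size()[1]), lenx)    # encoder's final hidden state
saturated = (h_enc.abs() > 0.99).float().mean().data[0]  # fraction of units near +/-1
print('saturated fraction: {:.2%}'.format(saturated))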
In addition, no matter how I tune the learning rate or switch the cell between RNN, LSTM and GRU (a one-line change via the unit_type argument, see the short snippet after the log below), the loss seems to bottom out even with only 50 training samples. With more data, the model does not seem to converge at all.
step: 0, loss: 2.313938
step: 10, loss: 1.435780
step: 20, loss: 0.779704
step: 30, loss: 0.395590
step: 40, loss: 0.281261
...
step: 480, loss: 0.231419
step: 490, loss: 0.231410
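Switching the cell type only means changing the unit_type argument when constructing the encoder and decoder in the code below, e.g. for a plain RNN:

enc = Encoder(hidden_size, num_layers, unit_type='rnn', embedding=src_embed)
dec = Decoder(hidden_size, num_layers, unit_type='rnn', embedding=tgt_embed)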
When I use TensorFlow, I can easily overfit such a dataset with a seq2seq model and reach a very small loss value.
Here is what I have already tried:

- manually initializing the embeddings to very small numbers;
- clipping the gradients to a fixed norm, e.g. 1e-2, 2, 3, 5, 10;
- excluding the padding index when computing the loss (by passing ignore_index to NLLLoss); a rough sketch of the last two attempts is shown right after this list.
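This is roughly what the last two attempts looked like against the code below (clip stands for each of the norm values listed above):

criterion = torch.nn.NLLLoss(ignore_index=0)           # pad id 0 excluded from the loss
torch.nn.utils.clip_grad_norm(net.parameters(), clip)  # called between backward() and step()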
Nothing I tried helped solve the problem.

How can I get past this? Any help would be greatly appreciated.
Below is the code; for a better reading experience, it is also available as a gist.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
np.random.seed(0)
torch.manual_seed(0)
_RECURRENT_FN_MAPPING = {
    'rnn': torch.nn.RNN,
    'gru': torch.nn.GRU,
    'lstm': torch.nn.LSTM,
}


def get_recurrent_cell(n_inputs,
                       num_units,
                       num_layers,
                       type_,
                       dropout=0.0,
                       bidirectional=False):
    cls = _RECURRENT_FN_MAPPING.get(type_)
    return cls(
        n_inputs,
        num_units,
        num_layers,
        dropout=dropout,
        bidirectional=bidirectional)
class Recurrent(nn.Module):
    def __init__(self,
                 num_units,
                 num_layers=1,
                 unit_type='gru',
                 bidirectional=False,
                 dropout=0.0,
                 embedding=None,
                 attn_type='general'):
        super(Recurrent, self).__init__()

        num_inputs = embedding.weight.size(1)
        self._num_inputs = num_inputs
        self._num_units = num_units
        self._num_layers = num_layers
        self._unit_type = unit_type
        self._bidirectional = bidirectional
        self._dropout = dropout
        self._embedding = embedding
        self._attn_type = attn_type

        self._cell_fn = get_recurrent_cell(num_inputs, num_units, num_layers,
                                           unit_type, dropout, bidirectional)
    def init_hidden(self, batch_size):
        direction = 1 if not self._bidirectional else 2
        h = Variable(
            torch.zeros(direction * self._num_layers, batch_size,
                        self._num_units))
        # LSTM expects a (hidden state, cell state) tuple.
        if self._unit_type == 'lstm':
            return (h, h.clone())
        else:
            return h
    def forward(self, x, h, len_x):
        # Sort by sequence lengths (required by pack_padded_sequence).
        sorted_indices = np.argsort(-len_x).tolist()
        unsorted_indices = np.argsort(sorted_indices).tolist()

        x = x[:, sorted_indices]
        h = h[:, sorted_indices, :]
        len_x = len_x[sorted_indices].tolist()

        embedded = self._embedding(x)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, len_x)

        if self._unit_type == 'lstm':
            o, (h, c) = self._cell_fn(packed, h)
            o, _ = torch.nn.utils.rnn.pad_packed_sequence(o)
            # Restore the original batch order before returning.
            return (o[:, unsorted_indices, :], (h[:, unsorted_indices, :],
                                                c[:, unsorted_indices, :]))
        else:
            o, hh = self._cell_fn(packed, h)
            o, _ = torch.nn.utils.rnn.pad_packed_sequence(o)
            return (o[:, unsorted_indices, :], hh[:, unsorted_indices, :])
class Encoder(Recurrent):
    pass


class Decoder(Recurrent):
    pass
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, num_outputs):
        super(Seq2Seq, self).__init__()
        self._encoder = encoder
        self._decoder = decoder
        self._out = nn.Linear(decoder._num_units, num_outputs)

    def forward(self, x, y, h, len_x, len_y):
        # Encode
        _, h = self._encoder(x, h, len_x)
        # Decode
        o, h = self._decoder(y, h, len_y)
        # Project
        o = self._out(o)
        return F.log_softmax(o)
def load_data(size,
              min_len=5,
              max_len=15,
              min_word=3,
              max_word=100,
              epoch=10,
              batch_size=64,
              pad=0,
              bos=1,
              eos=2):
    src = [
        np.random.randint(min_word, max_word - 1,
                          np.random.randint(min_len, max_len)).tolist()
        for _ in range(size)
    ]
    # Decoder input: bos + (src + 1); decoder target: (src + 1) + eos.
    tgt_in = [[bos] + [xi + 1 for xi in x] for x in src]
    tgt_out = [[xi + 1 for xi in x] + [eos] for x in src]

    def _pad(batch):
        max_len = max(len(x) for x in batch)
        return np.asarray(
            [
                np.pad(
                    x, (0, max_len - len(x)),
                    mode='constant',
                    constant_values=pad) for x in batch
            ],
            dtype=np.int64)

    def _len(batch):
        return np.asarray([len(x) for x in batch], dtype=np.int64)

    for e in range(epoch):
        batch_start = 0
        while batch_start < size:
            batch_end = batch_start + batch_size
            s, ti, to = (src[batch_start:batch_end],
                         tgt_in[batch_start:batch_end],
                         tgt_out[batch_start:batch_end])
            lens, lent = _len(s), _len(ti)
            s, ti, to = _pad(s).T, _pad(ti).T, _pad(to).T

            yield (Variable(torch.LongTensor(s)),
                   Variable(torch.LongTensor(ti)),
                   Variable(torch.LongTensor(to)), lens, lent)

            batch_start += batch_size
def print_sample(x, y, yy):
    x = x.data.numpy().T
    y = y.data.numpy().T
    yy = yy.data.numpy().T
    for u, v, w in zip(x, y, yy):
        print('--------')
        print('S: ', u)
        print('T: ', v)
        print('P: ', w)
n_data = 50
min_len = 5
max_len = 10
vocab_size = 101
n_samples = 5
epoch = 100000
batch_size = 32
lr = 1e-2
clip = 3
emb_size = 50
hidden_size = 50
num_layers = 1
max_length = 15
src_embed = torch.nn.Embedding(vocab_size, emb_size)
tgt_embed = torch.nn.Embedding(vocab_size, emb_size)
eps = 1e-3
src_embed.weight.data.uniform_(-eps, eps)
tgt_embed.weight.data.uniform_(-eps, eps)
enc = Encoder(hidden_size, num_layers, embedding=src_embed)
dec = Decoder(hidden_size, num_layers, embedding=tgt_embed)
net = Seq2Seq(enc, dec, vocab_size)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = torch.nn.NLLLoss()
loader = load_data(
    n_data,
    min_len=min_len,
    max_len=max_len,
    max_word=vocab_size,
    epoch=epoch,
    batch_size=batch_size)
for i, (x, yin, yout, lenx, leny) in enumerate(loader):
    net.train()

    optimizer.zero_grad()
    logits = net(x, yin, enc.init_hidden(x.size()[1]), lenx, leny)
    loss = criterion(logits.view(-1, vocab_size), yout.contiguous().view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm(net.parameters(), clip)
    optimizer.step()

    if i % 10 == 0:
        print('step: {}, loss: {:.6f}'.format(i, loss.data[0]))

    if i % 200 == 0 and i > 0:
        net.eval()
        x, yin, yout, lenx, leny = (x[:, :n_samples], yin[:, :n_samples],
                                    yout[:, :n_samples], lenx[:n_samples],
                                    leny[:n_samples])
        outputs = net(x, yin, enc.init_hidden(x.size()[1]), lenx, leny)
        _, preds = torch.max(outputs, 2)
        print_sample(x, yout, preds)
I tried initializing the input embeddings to very small numbers (around 1e-4), but nothing changed. Data normalization..? – Edityouprofile
I don't think you have made the data closer to 0; I mean normalizing the data to the range -1 to 1, which you can do with min-max normalization. – Shehroz
I tried that, but I didn't update the question. Sorry. – Edityouprofile