How can I extract all the weights from an LSTM cell in vanilla TensorFlow? I train an LSTM network like this:
import tensorflow as tf  # TensorFlow 1.x, which still ships tf.contrib

cell_fw = tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE)
cell_bw = tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE)
rnn_outputs, final_state_fw, final_state_bw = tf.contrib.rnn.static_bidirectional_rnn(
    cell_fw=cell_fw,
    cell_bw=cell_bw,
    inputs=rnn_inputs,
    dtype=tf.float32
)
I then tried to save the coefficients:
d = {}
with tf.Session() as sess:
    # train code ...
    variables_names = [v.name for v in tf.global_variables()]
    values = sess.run(variables_names)
    for k, v in zip(variables_names, values):
        d[k] = v
The dictionary d ends up with only 2 objects from each LSTM cell:
[(k,v.shape) for (k,v) in sorted(d.items(), key=lambda x:x[0])]
[('bidirectional_rnn/bw/basic_lstm_cell/biases:0', (1024,)),
('bidirectional_rnn/bw/basic_lstm_cell/weights:0', (272, 1024)),
('bidirectional_rnn/fw/basic_lstm_cell/biases:0', (1024,)),
('bidirectional_rnn/fw/basic_lstm_cell/weights:0', (272, 1024)),
('char_embedding:0', (70, 16)),
('softmax_biases:0', (5068,)),
('softmax_weights:0', (5068, 512))]
I am confused. Shouldn't each LSTM cell contain 4 trainable layers? If so, how do I get all of the weights from the LSTM cell?
Oh, that's true. Thanks, I can relax. – Roosh
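One reading of the shapes printed in the question is that nothing is missing: BasicLSTMCell keeps the four gate layers fused in a single matrix and a single bias vector per cell. With an embedding size of 16 and an assumed HIDDEN_SIZE of 256 (inferred from the shapes, not stated in the question), 272 = 16 + 256 rows cover the input and the previous hidden state, and 1024 = 4 * 256 columns cover the four gates. The sketch below splits such arrays with NumPy. The helper name `split_lstm_weights` and the zero-filled stand-in arrays are made up for illustration; the gate order i, j, f, o follows how BasicLSTMCell splits its fused pre-activations.

```python
import numpy as np

# Assumed sizes, inferred from the printed shapes in the question:
# 272 = 16 (embedding size) + 256 (hidden size), 1024 = 4 * 256.
INPUT_SIZE = 16
HIDDEN_SIZE = 256


def split_lstm_weights(weights, biases, input_size, hidden_size):
    """Split a fused BasicLSTMCell weight matrix and bias into per-gate parts.

    Hypothetical helper, not part of TensorFlow.
    """
    assert weights.shape == (input_size + hidden_size, 4 * hidden_size)
    assert biases.shape == (4 * hidden_size,)

    # Column blocks, in the order BasicLSTMCell uses:
    # i = input gate, j = candidate values, f = forget gate, o = output gate.
    names = ["i", "j", "f", "o"]
    w_parts = np.split(weights, 4, axis=1)
    b_parts = np.split(biases, 4)

    gates = {}
    for name, w, b in zip(names, w_parts, b_parts):
        gates[name] = {
            "W_x": w[:input_size],  # rows that multiply the current input
            "W_h": w[input_size:],  # rows that multiply the previous hidden state
            "b": b,
        }
    return gates


# Stand-in arrays with the same shapes as
# d['bidirectional_rnn/fw/basic_lstm_cell/weights:0'] and the matching biases.
weights = np.zeros((272, 1024), dtype=np.float32)
biases = np.zeros((1024,), dtype=np.float32)

gates = split_lstm_weights(weights, biases, INPUT_SIZE, HIDDEN_SIZE)
for name, parts in gates.items():
    print(name, parts["W_x"].shape, parts["W_h"].shape, parts["b"].shape)
```

To use it on the real values, pass the arrays stored in `d` for the fw and bw cells instead of the zero-filled stand-ins. Note that the cell's `forget_bias` is added at run time and is not stored in the saved bias vector.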