2017-11-04 85 views
0

我打算用 C++ 實作一個 BasicLSTMCell,因此需要先確認它能正確運作。我用 tf.nn.rnn_cell.BasicLSTMCell 建立了一個含 4 個單元(num_units=4)的 LSTM,並把 forget_bias 設為 1,然後用下面的代碼檢查 LSTM 的偏置(bias)(參見「Issue to forget_bias」):

////////////////////////////////////////////////////////////// 

    # Build a 4-unit BasicLSTMCell with forget_bias=1, run it for two time
    # steps on np.arange(8) reshaped to (1, 2, 4), save a checkpoint, then
    # read back every trainable variable so its values can be inspected.
    # NOTE(review): `Feed`, `tf` and `np` come from outside this snippet;
    # `Feed` is presumably a tf.placeholder of shape (1, 2, 4) — confirm.
    # NOTE(review): as pasted, `Cell=...` is not indented under the `with`
    # below, so this snippet does not parse as-is. The cell's variables are
    # created later, inside dynamic_rnn under the "Ut_def" scope (see the
    # "Ut_def/RNN/..." lookup keys at the end), not under "LSTM".
    with tf.variable_scope("LSTM"): 
    Cell=tf.nn.rnn_cell.BasicLSTMCell(4,forget_bias=1,state_is_tuple=True) 
Sessin=tf.Session() 
# All-zero initial state for a batch of 1 (state_is_tuple=True above, so it
# is a (c, h) tuple).
state=Cell.zero_state(1,dtype=tf.float32) 
with tf.variable_scope("Ut_def"): 
    # time_major=False: inputs are batch-major, so the (1, 2, 4) feed below
    # is read as 1 sequence of 2 time steps with 4 features each.
    out,D=tf.nn.dynamic_rnn(
      cell=Cell,inputs=Feed, 
      initial_state=state, 
      time_major=False) 
Sessin.run(tf.global_variables_initializer()) 
#Saver.save(Sessin,"./123/Var",global_step=1) 
# Rebinds `out`/`D` from graph tensors to their evaluated NumPy values; the
# tensors are no longer reachable through these names afterwards.
out,D=Sessin.run([out,D],feed_dict={Feed:np.arange(8).reshape(1,2,4)}) 
tf.train.Saver().save(Sessin,"./123/Var",global_step=1) 
# Map each trainable variable's name to its current value.
trainable_vars_dict = {} 
for key in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): 
    trainable_vars_dict[key.name] = Sessin.run(key) 
    # Checking the names of the keys 
    print(key.name) 
# Kernel (weight matrix) and bias of the cell. These names depend on the TF
# version; other versions use different names (e.g. kernel/bias) — TODO
# confirm against the names printed above.
lstm_weight_vals = trainable_vars_dict["Ut_def/RNN/BasicLSTMCell/Linear/Matrix:0"] 
B=trainable_vars_dict["Ut_def/RNN/BasicLSTMCell/Linear/Bias:0"] 
# Reported to print all zeros whatever forget_bias is set to. Presumably
# BasicLSTMCell applies forget_bias inside the cell rather than storing it in
# this variable — confirm in the TF docs for the version in use.
print(B) 
///////////////////////////////////////////////////////////// 

但我發現,不論我把 forget_bias 改成多少,這些偏置(bias)值都是零。

有人知道這是怎麼回事嗎?

爲了搞清楚 LSTM 是如何運作的,我直接取出 TensorFlow 中的權重(weights)和偏置(bias),嘗試手動算出相同的結果,但兩者完全不一致。

# Manual re-computation of the two LSTM steps with the kernel read back from
# TensorFlow, to compare against the `out`/`D` values from the session run.
# Split the kernel column-wise into four equal gate blocks, unpacked in the
# order (i, C, f, o) — presumably matching BasicLSTMCell's internal gate
# order; confirm for the TF version in use.
w_i, w_C, w_f, w_o = np.split(lstm_weight_vals, 4, axis=1) 
# Within each gate block, the first 4 rows are labelled the x (input)
# weights and the remaining rows the h (hidden-state) weights.
w_xi = w_i[:4, :] 
w_hi = w_i[4:, :] 
w_xC = w_C[:4, :] 
w_hC = w_C[4:, :] 
w_xf = w_f[:4, :] 
w_hf = w_f[4:, :] 
# NOTE(review): w_xo takes rows [4:], the same slice as w_ho, while every
# other gate uses [:4] for its x weights — looks like a typo for w_o[:4, :].
w_xo = w_o[4:, :] 
w_ho = w_o[4:, :] 
# Step 1 input: [[0., 1., 2., 3.]], i.e. the first time step of the
# np.arange(8).reshape(1, 2, 4) feed used in the session run.
Input=tf.range(4,dtype=tf.float32) 
Input=tf.reshape(Input,shape=[1,4]) 
# Step 1 gates, starting from an all-zero previous state.
# NOTE(review): the zero state is multiplied by the w_x* blocks and the input
# by the w_h* blocks. If the kernel's rows are ordered [inputs; h] (first 4
# rows = input weights, as the w_x* names suggest), these pairings are
# swapped — confirm.
# NOTE(review): the bias B is never added, which is harmless only while B is
# all zeros. The forget gate also omits forget_bias (set to 1 for the cell),
# so f will not match the cell's forget gate if the cell adds it internally.
i=tf.sigmoid(tf.matmul(tf.zeros(shape=[1,4]),w_xi)+tf.matmul(Input,w_hi)) 
o=tf.sigmoid(tf.matmul(tf.zeros(shape=[1,4]),w_xo)+tf.matmul(Input,w_ho)) 
g=tf.tanh(tf.matmul(tf.zeros(shape=[1,4]),w_xC)+tf.matmul(Input,w_hC)) 
f=tf.sigmoid(tf.matmul(tf.zeros(shape=[1,4]),w_xf)+tf.matmul(Input,w_hf)) 
# Step 1 cell state (the previous cell state is zero) and hidden output.
Cstate=tf.zeros(shape=[1,4])*f+i*g 
Hstate=tf.tanh(Cstate)*o 
# Step 2 input: [[4., 5., 6., 7.]], the second time step of the feed.
Input=Input+4 
# Step 2 gates.
# NOTE(review): the recurrent term here is Cstate (the cell state); an LSTM's
# gates take the previous hidden output h, i.e. Hstate.
i=tf.sigmoid(tf.matmul(Cstate,w_xi)+tf.matmul(Input,w_hi)) 
o=tf.sigmoid(tf.matmul(Cstate,w_xo)+tf.matmul(Input,w_ho)) 
g=tf.tanh(tf.matmul(Cstate,w_xC)+tf.matmul(Input,w_hC)) 
f=tf.sigmoid(tf.matmul(Cstate,w_xf)+tf.matmul(Input,w_hf)) 
# Step 2 cell state and hidden output.
Cstate=Cstate*f+i*g 

Hstate=tf.tanh(Cstate)*o 

回答

0

我找到了代碼中的錯誤,正確的代碼應該是這樣的:

# Corrected step-2 gates. The recurrent term uses Hstate (the previous hidden
# output) instead of Cstate, and the forget gate adds 1 to match
# forget_bias=1. Presumably BasicLSTMCell adds forget_bias inside the cell
# instead of storing it in its Bias variable, which is consistent with that
# variable reading all zeros.
# NOTE(review): the other pairings are unchanged from the question: Hstate
# goes with the w_x* blocks, Input with the w_h* blocks, and w_xo is sliced
# like w_ho. If the kernel's rows are ordered [inputs; h], re-check them.
# The hard-coded 1 must track forget_bias if that setting changes.
i=tf.sigmoid(tf.matmul(Hstate,w_xi)+tf.matmul(Input,w_hi)) 
o=tf.sigmoid(tf.matmul(Hstate,w_xo)+tf.matmul(Input,w_ho)) 
g=tf.tanh(tf.matmul(Hstate,w_xC)+tf.matmul(Input,w_hC)) 
f=tf.sigmoid(tf.matmul(Hstate,w_xf)+tf.matmul(Input,w_hf)+1) 

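也就是說,第二步計算各個門時應該使用前一步的隱藏狀態 Hstate,而不是細胞狀態 Cstate。另外,forget_bias 並不會寫進 Bias 變數(所以不論怎麼設定,讀出來的 bias 都是零);BasicLSTMCell 是在計算遺忘門時直接把 forget_bias 加到 f 上,因此手動計算時,f 的 sigmoid 內也要加上 1(即 forget_bias 的值)。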