2015-12-21 86 views
2
def rnn_seq2seq(encoder_inputs, decoder_inputs, cell, output_projection=None,feed_previous=False, dtype=tf.float32, scope=None): 
    with tf.variable_scope(scope or "rnn_seq2seq"): 
    _, enc_states = rnn.rnn(cell, encoder_inputs,dtype=dtype) 


    def extract_argmax(prev, i): 
    if output_projection is not None: 
     prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1]) 
    return tf.to_float(tf.equal(prev,tf.reduce_max(prev,reduction_indices=[1],keep_dims=True))) 

    loop_function = None 
    if feed_previous: 
    loop_function = extract_argmax 

     #seq2seq.rnn_decoder is provided in tensorflow/models/rnn/seq2seq.py 
    return seq2seq.rnn_decoder(decoder_inputs, enc_states[-1], cell, loop_function=loop_function) 

我想創建兩個RNN模型,一個用於訓練,另一個用於測試。爲此,我可以調用兩次將feed_previous傳遞給True或False的函數。Tensorflow seq2seq權重共享

train_op,train_states = rnn_seq2seq(enc_inp,dec_inp,cell,output_projection=op,feed_previous=False) 
test_op,_ = rnn_seq2seq(enc_inp,dec_inp,cell,output_projection=op,feed_previous=True) 

但是,如果我調用上述函數兩次,它不會創建兩個不同的RNN?我想知道他們是否能夠分享權重。

回答

1

兩個功能在同一個默認的圖形操作等可重複使用的變量,檢查出variable scopes tutorial,看看你的變量與reuse=True參數

作爲一個全面的檢查創建,嘗試下面的代碼片段列出的所有變量默認圖:

[v.name for v in tf.get_default_graph().as_graph_def().node if v.op=='Variable'] 
+0

我確實已經閱讀了教程,但我的理解可能有差距。 我將兩個參數傳遞給需要共享的rnn_seq2seq - 「cell」和「output_projection」。因爲我傳遞的是同樣的東西,我認爲他們會被分享。 但我懷疑在tensorflow中提供的seq2seq_decoder函數。 –

+0

我已經縮小了我的疑問: 創建GRUCell時,rnn_decoder調用 「output,new_state = cell(inp,states [-1])」 這是否意味着它具有指向狀態的指針,還是創建一個新的狀態一共只複製價值? –