2016-05-03 54 views
4

張量流中的embedding_rnn_seq2seq函數seq2seq模塊提供feed_previous參數,這意味着在解碼過程中它只使用第一個解碼器輸入,然後在後續解碼器輸入中使用先前的解碼器輸出。有沒有簡單的方法從basic_rnn_seq2seq函數中獲取此行爲?張量流的沒有feed_previous參數basic_rnn_seq2seq函數

回答

0

此API現在已過時,但如果有人仍然在尋找我建議在看這個GitHub庫的解決方案:raindeer/seq2seq_experiments

總之,創建自己的解碼器,筆者採用是loop_function以下(重要組成部分):

def _init_seq2seq(self, encoder_inputs, decoder_inputs, cell, feed_previous): 

    def inference_loop_function(prev, _): 
     prev = tf.nn.xw_plus_b(prev, self.w_softmax, self.b_softmax) 
     return tf.to_float(tf.equal(prev, tf.reduce_max(prev, reduction_indices=[1], keep_dims=True))) 

    loop_function = inference_loop_function if feed_previous else None 

    with variable_scope.variable_scope('seq2seq'): 
     _, final_enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtypes.float32) 
     return seq2seq.rnn_decoder(decoder_inputs, final_enc_state, cell, loop_function=loop_function) 
相關問題