2017-02-23 155 views
2

我想用python構建和訓練多層LSTM模型(stateIsTuple = True),然後在C++中加載和使用它。但我很難弄清楚如何在C++中提供和獲取狀態,主要是因爲我沒有可以引用的字符串名稱。如何在tensorflow中輸入和檢索LSTM的狀態C/C++

E.g.我把初始狀態作爲

with tf.name_scope('rnn_input_state'): 
     self.initial_state = cell.zero_state(args.batch_size, tf.float32) 

並且這如以下將顯示在圖命名範圍,例如,但如何供給到這些在C++?

enter image description here

另外,我怎麼能取C++中的現狀如何?我在python中嘗試了下面的圖形構建代碼,但我不確定它是否是正確的做法,因爲last_state應該是一個張量元組,而不是一個張量(儘管我可以看到tensorboard中的last_state節點是2x2x50x128,這聽起來像它只是連接狀態,因爲我有2層,128個大小,50個小批量大小和lstm單元格 - 帶有2個狀態向量)。

with tf.name_scope('outputs'): 
     outputs, last_state = legacy_seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None) 
     output = tf.reshape(tf.concat(outputs, 1), [-1, args.rnn_size], name='output') 

,這就是它看起來像在tensorboard

enter image description here

我應該Concat的,分裂國家的張量因此只有過一個狀態張量進進出出?或者,還有更好的方法?

P.S.理想情況下,解決方案不涉及硬編碼層數(或尺寸)。所以我可以只有四個字符串input_node_name,output_node_name,input_state_name,output_state_name,剩下的就是從那裏派生出來的。

回答

3

我設法做到這一點,通過手動將狀態連接成單張量。我不確定這是否明智,因爲這是張量流如何使用來處理狀態,但現在不贊成這一點,並切換到元組狀態。我沒有設置state_is_tuple = False,而且冒着冒着代碼被淘汰的風險,我添加了額外的操作來手動堆棧和從單個張量中解散狀態。說這個,它在python和C++中都能正常工作。

關鍵代碼:

# setting up 
zero_state = cell.zero_state(batch_size, tf.float32) 
state_in = tf.identity(zero_state, name='state_in')   

# based on https://medium.com/@erikhallstrm/using-the-tensorflow-multilayered-lstm-api-f6e7da7bbe40#.zhg4zwteg 
state_per_layer_list = tf.unstack(state_in, axis=0) 
state_in_tuple = tuple(
    # TODO make this not hard-coded to LSTM 
    [tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1]) 
    for idx in range(num_layers)] 
) 

outputs, state_out_tuple = legacy_seq2seq.rnn_decoder(inputs, state_in_tuple, cell, loop_function=loop if infer else None) 
state_out = tf.identity(state_out_tuple, name='state_out') 

# running (training or inference) 
state = sess.run('state_in:0') # zero state 

loop: 
    feed = {'data_in:0': x, 'state_in:0': state} 
    [y, state] = sess.run(['data_out:0', 'state_out:0'], feed) 

以下是完整的代碼,如果有人需要它 https://github.com/memo/char-rnn-tensorflow