
TensorFlow 1.1 MultiRNNCell shape error (init_state related)

Update: I am fairly convinced that the error is related to init_state, which is created and passed to tf.nn.dynamic_rnn(...) as an argument. So the question becomes: what is the correct shape, or the correct way to construct, the initial state of a stacked RNN?

I am trying to get a MultiRNNCell definition working in TensorFlow 1.1.

The graph definition (with a helper function that defines the GRU cell) is shown below. The basic idea is that the placeholder x takes in a lengthy string of numeric samples. These data are split into equal-length frames by reshaping, and one frame is presented at each time step. I then want to process this through two (for now) layers of GRU cells.

def gru_cell(state_size): 
    cell = tf.contrib.rnn.GRUCell(state_size) 
    return cell 

graph = tf.Graph() 
with graph.as_default(): 

    x = tf.placeholder(tf.float32, [batch_size, num_samples], name="Input_Placeholder") 
    y = tf.placeholder(tf.int32, [batch_size, num_frames], name="Labels_Placeholder") 

    init_state = tf.zeros([batch_size, state_size], name="Initial_State_Placeholder") 

    rnn_inputs = tf.reshape(x, (batch_size, num_frames, frame_length)) 
    cell = tf.contrib.rnn.MultiRNNCell([gru_cell(state_size) for _ in range(2)], state_is_tuple=False) 
    rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, initial_state=init_state) 

The graph definition continues from there with a loss function, optimizer, and so on, but this is where it breaks, with the lengthy error below.

It will become relevant that batch_size is 10, and frame_length and state_size are both 80.

ValueError        Traceback (most recent call last) 
<ipython-input-30-4c48b596e055> in <module>() 
    14  print(rnn_inputs) 
    15  cell = tf.contrib.rnn.MultiRNNCell([gru_cell(state_size) for _ in range(2)], state_is_tuple=False) 
---> 16  rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, initial_state=init_state) 
    17 
    18  with tf.variable_scope('softmax'): 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.pyc in dynamic_rnn(cell, inputs, sequence_length, initial_state, dtype, parallel_iterations, swap_memory, time_major, scope) 
    551   swap_memory=swap_memory, 
    552   sequence_length=sequence_length, 
--> 553   dtype=dtype) 
    554 
    555  # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.pyc in _dynamic_rnn_loop(cell, inputs, initial_state, parallel_iterations, swap_memory, sequence_length, dtype) 
    718  loop_vars=(time, output_ta, state), 
    719  parallel_iterations=parallel_iterations, 
--> 720  swap_memory=swap_memory) 
    721 
    722 # Unpack final output if not using output tuples. 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.pyc in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name) 
    2621  context = WhileContext(parallel_iterations, back_prop, swap_memory, name) 
    2622  ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context) 
-> 2623  result = context.BuildLoop(cond, body, loop_vars, shape_invariants) 
    2624  return result 
    2625 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.pyc in BuildLoop(self, pred, body, loop_vars, shape_invariants) 
    2454  self.Enter() 
    2455  original_body_result, exit_vars = self._BuildLoop(
-> 2456   pred, body, original_loop_vars, loop_vars, shape_invariants) 
    2457  finally: 
    2458  self.Exit() 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.pyc in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants) 
    2435  for m_var, n_var in zip(merge_vars, next_vars): 
    2436  if isinstance(m_var, ops.Tensor): 
-> 2437   _EnforceShapeInvariant(m_var, n_var) 
    2438 
    2439  # Exit the loop. 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.pyc in _EnforceShapeInvariant(merge_var, next_var) 
    565   "Provide shape invariants using either the `shape_invariants` " 
    566   "argument of tf.while_loop or set_shape() on the loop variables." 
--> 567   % (merge_var.name, m_shape, n_shape)) 
    568 else: 
    569  if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)): 

ValueError: The shape for rnn/while/Merge_2:0 is not an invariant for the loop. It enters the loop with shape (10, 80), but has shape (10, 160) after one iteration. Provide shape invariants using either the `shape_invariants` argument of tf.while_loop or set_shape() on the loop variables. 

The last part of the error almost makes it look like the network starts out as a 2-layer stack of size-80 cells and is somehow being converted into a stack of size 160. Any help sorting this out? Am I misunderstanding how MultiRNNCell is used?


Shouldn't it be 'init_state = tf.zeros([batch_size, 2 * state_size], ...)'? – Allen Lavoie
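This matches the traceback: with state_is_tuple=False, the stacked cell concatenates the per-layer states along the last axis, so it expects a state of width 2 * state_size = 160 rather than 80. A minimal check of this (a sketch, assuming the same state_size of 80 as in the question):

import tensorflow as tf

state_size = 80  # value from the question

# With state_is_tuple=False the stacked cell reports a single concatenated
# state, so its state_size is the sum over the layers: 80 + 80 = 160.
cell = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.GRUCell(state_size) for _ in range(2)],
    state_is_tuple=False)
print(cell.state_size)  # prints 160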

Answer


Based on Allen Lavoie's comment above, the corrected code is:

def gru_cell(state_size): 
    cell = tf.contrib.rnn.GRUCell(state_size) 
    return cell 

num_layers = 2 # <--------- 
graph = tf.Graph() 
with graph.as_default(): 

    x = tf.placeholder(tf.float32, [batch_size, num_samples], name="Input_Placeholder") 
    y = tf.placeholder(tf.int32, [batch_size, num_frames], name="Labels_Placeholder") 

    init_state = tf.zeros([batch_size, num_layers * state_size], name="Initial_State_Placeholder") # <--------- 

    rnn_inputs = tf.reshape(x, (batch_size, num_frames, frame_length)) 
    cell = tf.contrib.rnn.MultiRNNCell([gru_cell(state_size) for _ in range(num_layers)], state_is_tuple=False) # <--------- 
    rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, initial_state=init_state) 

Note the three changes marked above. Also note that these changes must ripple through to everywhere init_state flows, in particular if you feed it via feed_dict.
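These widths also matter anywhere a concrete state array is supplied at run time, for example when carrying final_state over between batches. A rough sketch of such a feed, assuming the graph above and hypothetical batch arrays batch_x and batch_y:

import numpy as np
import tensorflow as tf

batch_size, num_layers, state_size = 10, 2, 80  # values from the question

# Any state array fed into the graph must now have the concatenated width
# num_layers * state_size, i.e. shape (10, 160).
carried_state = np.zeros((batch_size, num_layers * state_size), dtype=np.float32)

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    # batch_x and batch_y are hypothetical NumPy arrays shaped like x and y.
    _, carried_state = sess.run(
        [rnn_outputs, final_state],
        feed_dict={x: batch_x, y: batch_y, init_state: carried_state})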