2017-06-02 31 views
0

以下代碼使用tf.while_loop(...)來計算動態長度。tf.while_loop輸出中堆疊tensorArray的未知大小

outputs_tensor_array = tf.TensorArray(tf.float32, 
              size=0, 
              clear_after_read=False, 
              infer_shape=False, 
              dynamic_size = True, 
              element_shape[self.batch_size, self.size]) 

    initial_args = [outputs_tensor_array, 0] 
    outputs, *_ = tf.while_loop(lambda out, idx, *_ : idx < max_len, 
           func, 
           initial_args + additional_args, 
           parallel_iterations = 32, 
           swap_memory = True) 
    outputs = outputs.stack() 

我想知道,如果可能強制執行的大小,或至少使該尺寸爲None,以強制執行大小限制,以及進一步計算下來的圖。目前的形狀是[?, batch, hidden_size]

回答

1

tensor.set_shape將細化靜態形狀的信息,如果它是與當前的靜態形狀信息(在TensorArray.stack()不兼容的情況下拋出一個錯誤,它會讓你在零維的靜態設置的任何值形狀信息)。

tf.reshape對於聲明/填充形狀信息也是有用的,雖然它不完美。只有當圖表執行時張量的大小是錯誤的(否則可能會隱藏下游的形狀誤差),它只會拋出一個錯誤。

更復雜,但您也可以使用set_shape獲取靜態形狀信息,然後使用tf.Asserttf.shape檢查圖形執行時的張量形狀。

相關問題