我正在嘗試關於可變長度多變量序列分類問題的RNN。在張量流中,如何迭代存儲在張量中的一系列輸入?
我已經定義下面的函數來獲得該序列的輸出(即RNN細胞的從序列中的最後輸入後的輸出被饋送)
def get_sequence_output(x_sequence, initial_hidden_state):
previous_hidden_state = initial_hidden_state
for x_single in x_sequence:
hidden_state = gru_unit(previous_hidden_state, x_single)
previous_hidden_state = hidden_state
final_hidden_state = hidden_state
return final_hidden_state
這裏x_sequence
是形狀(?, ?, 10)
的張量,其中第一?是批量大小和秒?用於序列長度,每個輸入元素的長度爲10. gru
函數採用先前的隱藏狀態和當前輸入並吐出下一個隱藏狀態(標準門控循環單元)。
我得到一個錯誤:'Tensor' object is not iterable.
如何按順序方式遍歷張量(一次讀取單個元素)?
我的目標是爲序列中的每個輸入應用gru
函數,並獲得最終的隱藏狀態。
如果time_steps不存在(序列的可變長度)這將無法正常工作。 – Cospel