我正在使用鋼筋學習,並希望在訓練期間減少通過sess.run()饋送的數據量,以加快學習速度。TensorFlow:LSTM狀態保存/圖內更新
我一直在尋找進入LSTM並與需要向前看,重新找到正確的Q值,我精心設計了一個解決方案,如本與tf.case():
CurrentStateOption = tf.Variable(0, trainable=False, name='SavedState')
with tf.name_scope("LSTMLayer") as scope:
initializer = tf.random_uniform_initializer(-.1, .1)
lstm_cell_L1 = tf.nn.rnn_cell.LSTMCell(self.input_sizes, forget_bias=1.0, initializer=initializer, state_is_tuple=True)
self.cell_L1 = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_L1] *self.NumberLSTMLayers, state_is_tuple=True)
self.state = self.cell_L1.zero_state(1,tf.float64)
self.SavedState = self.cell_L1.zero_state(1,tf.float64) #tf.Variable(state, trainable=False, name='SavedState')
#SaveCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(1)), self.SaveState, self.SameState)
#RestoreCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(-1)), self.RestoreState, self.SameState)
#ZeroCond = tf.cond(tf.less(CurrentStateOption,tf.constant(-1)), self.ZeroState, self.SameState)
self.state = tf.case({tf.equal(CurrentStateOption,tf.constant(1)): self.SaveState, tf.equal(CurrentStateOption,tf.constant(-1)): self.RestoreState,
tf.less(CurrentStateOption,tf.constant(-1)): self.ZeroState}, default=self.SameState, exclusive=True)
RunConditions = tf.group([SaveCond, RestoreCond, ZeroCond])
self.Xinputs = [tf.concat(1,[Xinputs])]
outputs, stateFINAL_L1 = rnn.rnn(self.cell_L1,self.Xinputs, initial_state=self.state, dtype=tf.float32)
def RestoreState(self): #self.state = self.state.assign(self.SavedState) self.state = self.SavedState return self.state def ZeroState(self): self.state = self.cell_L1.zero_state(1,tf.float64) return self.state def SaveState(self): #self.SavedState = self.SavedState.assign(self.state) self.SavedState = self.state return self.SavedState def SameState(self): return self.state
這似乎在概念現在好工作,我可以養活一個INT指示LSTM圖形做什麼w^ith狀態。如果我傳遞「1」,它會在執行前保存狀態,如果我傳遞「-1」,它會恢復上次保存的狀態,如果我通過「< -1」,它將使狀態歸零。如果「0」它將使用上次運行(推理)中的LSTM中的內容。我嘗試了幾種不同的方法,包括一個更簡單的tf.cond()方法。
我認爲這個問題源於tf.case()Op需要張量,但LSTM狀態是一個元組(並且非元組將被折舊)。當我嘗試tf.assign()該值到圖變量時,這變得很清楚。
我的最終目標是在圖表中留下「狀態」,但通過一個INT指示如何處理狀態。在未來,我希望有多個「商店」的地點爲各種回顧。
任何想法如何處理tf.case()類型的結構與元組與張量?