2016-09-06 76 views
0

我正在使用鋼筋學習,並希望在訓練期間減少通過sess.run()饋送的數據量,以加快學習速度。TensorFlow:LSTM狀態保存/圖內更新

我一直在尋找進入LSTM並與需要向前看,重新找到正確的Q值,我精心設計了一個解決方案,如本與tf.case():

CurrentStateOption = tf.Variable(0, trainable=False, name='SavedState') 
with tf.name_scope("LSTMLayer") as scope: 
     initializer = tf.random_uniform_initializer(-.1, .1) 
     lstm_cell_L1 = tf.nn.rnn_cell.LSTMCell(self.input_sizes, forget_bias=1.0, initializer=initializer, state_is_tuple=True) 
     self.cell_L1 = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_L1] *self.NumberLSTMLayers, state_is_tuple=True) 
     self.state = self.cell_L1.zero_state(1,tf.float64) 

     self.SavedState = self.cell_L1.zero_state(1,tf.float64) #tf.Variable(state, trainable=False, name='SavedState') 

     #SaveCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(1)), self.SaveState, self.SameState) 
     #RestoreCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(-1)), self.RestoreState, self.SameState) 
     #ZeroCond = tf.cond(tf.less(CurrentStateOption,tf.constant(-1)), self.ZeroState, self.SameState) 

     self.state = tf.case({tf.equal(CurrentStateOption,tf.constant(1)): self.SaveState, tf.equal(CurrentStateOption,tf.constant(-1)): self.RestoreState, 
      tf.less(CurrentStateOption,tf.constant(-1)): self.ZeroState}, default=self.SameState, exclusive=True) 

     RunConditions = tf.group([SaveCond, RestoreCond, ZeroCond]) 

     self.Xinputs = [tf.concat(1,[Xinputs])] 

     outputs, stateFINAL_L1 = rnn.rnn(self.cell_L1,self.Xinputs, initial_state=self.state, dtype=tf.float32) 
def RestoreState(self): 
    #self.state = self.state.assign(self.SavedState) 
    self.state = self.SavedState 
    return self.state 
def ZeroState(self): 
    self.state = self.cell_L1.zero_state(1,tf.float64) 
    return self.state 
def SaveState(self): 
    #self.SavedState = self.SavedState.assign(self.state) 
    self.SavedState = self.state 
    return self.SavedState 
def SameState(self): 
    return self.state 

這似乎在概念現在好工作,我可以養活一個INT指示LSTM圖形做什麼w^ith狀態。如果我傳遞「1」,它會在執行前保存狀態,如果我傳遞「-1」,它會恢復上次保存的狀態,如果我通過「< -1」,它將使狀態歸零。如果「0」它將使用上次運行(推理)中的LSTM中的內容。我嘗試了幾種不同的方法,包括一個更簡單的tf.cond()方法。

我認爲這個問題源於tf.case()Op需要張量,但LSTM狀態是一個元組(並且非元組將被折舊)。當我嘗試tf.assign()該值到圖變量時,這變得很清楚。

我的最終目標是在圖表中留下「狀態」,但通過一個INT指示如何處理狀態。在未來,我希望有多個「商店」的地點爲各種回顧。

任何想法如何處理tf.case()類型的結構與元組與張量?

回答

0

我相信有狀態元組中的每個元素應該有一個tf.case(),因爲元組只是一個python元組。