2017-08-18

Tensorflow hidden state doesn't seem to change

I've been trying to follow this RNN tutorial on medium, refactoring it as I go. When I run my code it seems to work, but when I try to print the current state variable to see what's happening inside the neural network, I get all 1s. Is that the expected behavior? Is the state not being updated for some reason? As far as I know, the current state should hold the latest hidden-layer values for all batches, so it definitely shouldn't be all 1s. Any help would be greatly appreciated.
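For context, one mechanism that can produce exactly this symptom, sketched with numpy (this is not the asker's actual graph; the weights and input are exaggerated assumptions): the tutorial's cell computes `h_t = tanh([x_t, h_{t-1}] @ W + b)`, and if the pre-activations are consistently large and positive, tanh pins every unit at 1.0, which is what printing the state would show.

```python
import numpy as np

# Hypothetical sketch: an Elman-style update h_t = tanh([x_t, h_{t-1}] @ W + b).
# The large positive weights below are an assumption chosen to force saturation.
hidden_size = 4
W = np.full((1 + hidden_size, hidden_size), 5.0)  # exaggerated weights (assumption)
b = np.zeros(hidden_size)

h = np.zeros(hidden_size)
for _ in range(10):
    x_t = np.array([1.0])                          # constant positive input (assumption)
    h = np.tanh(np.concatenate([x_t, h]) @ W + b)  # tanh saturates toward 1.0

print(np.round(h, 4))  # [1. 1. 1. 1.]
```

If something similar is happening in the real graph (e.g. unscaled inputs or exploding weights), every printed state would look like all 1s even though the state is technically being updated.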

def __train_minibatch__(self, batch_num, sess, current_state):
    """
    Trains one minibatch.

    :type batch_num: int
    :param batch_num: the current batch number.

    :type sess: tensorflow Session
    :param sess: the session during which training occurs.

    :type current_state: numpy matrix (array of arrays)
    :param current_state: the current hidden state.

    :rtype: (float, numpy matrix, list)
    :return: (the calculated loss for this minibatch, the updated hidden
             state, the predictions series)
    """
    start_index = batch_num * self.settings.truncate
    end_index = start_index + self.settings.truncate

    batch_x = self.x_train_batches[:, start_index:end_index]
    batch_y = self.y_train_batches[:, start_index:end_index]
    # The train-step result is not needed, so it is discarded.
    total_loss, _, current_state, predictions_series = sess.run(
        [self.total_loss_fun, self.train_step_fun, self.current_state, self.predictions_series],
        feed_dict={
            self.batch_x_placeholder: batch_x,
            self.batch_y_placeholder: batch_y,
            self.hidden_state: current_state
        })
    return total_loss, current_state, predictions_series
# End of __train_minibatch__()

def __train_epoch__(self, epoch_num, sess, current_state, loss_list):
    """
    Trains one full epoch.

    :type epoch_num: int
    :param epoch_num: the number of the current epoch.

    :type sess: tensorflow Session
    :param sess: the session during which training occurs.

    :type current_state: numpy matrix
    :param current_state: the current hidden state.

    :type loss_list: list of floats
    :param loss_list: holds the losses incurred during training.

    :rtype: (float, numpy matrix, list)
    :return: (the latest incurred loss, the latest hidden state, the
             predictions series)
    """
    self.logger.info("Starting epoch: %d" % (epoch_num))

    for batch_num in range(self.num_batches):
        # Debug log outside of function to reduce number of arguments.
        self.logger.debug("Training minibatch : %d | epoch : %d" % (batch_num, epoch_num + 1))
        total_loss, current_state, predictions_series = self.__train_minibatch__(batch_num, sess, current_state)
        loss_list.append(total_loss)
    # End of batch training

    self.logger.info("Finished epoch: %d | loss: %f" % (epoch_num, total_loss))
    return total_loss, current_state, predictions_series
# End of __train_epoch__()

def train(self):
    """
    Trains the given model on the given dataset, and saves the losses incurred
    at the end of each epoch to a plot image.
    """
    self.logger.info("Started training the model.")
    self.__unstack_variables__()
    self.__create_functions__()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        loss_list = []

        current_state = np.zeros((self.settings.batch_size, self.settings.hidden_size), dtype=float)
        for epoch_idx in range(1, self.settings.epochs + 1):
            total_loss, current_state, predictions_series = self.__train_epoch__(epoch_idx, sess, current_state, loss_list)
            print("Shape: ", current_state.shape, " | Current output: ", current_state)
        # End of epoch training

    self.logger.info("Finished training the model. Final loss: %f" % total_loss)
    self.__plot__(loss_list)
    self.generate_output()
# End of train()

Update

After finishing the second part of the tutorial and switching to the built-in RNN API, the problem went away. That means either something was wrong with how I used my current_state variable, or the change in the tensorflow API caused some weirdness (though I'm fairly sure it's the former). The question is still open if anyone has a definitive answer.

Answer

First, you should make sure that "it seems to work" is actually true, i.e., that your test error really goes down.

One hypothesis of mine is that the last batch gets corrupted by zeros at the end, because the length of the data, total_series_length/batch_size, is not a multiple of truncated_backprop_length. (I haven't checked whether it actually gets filled with zeros; the code in the tutorial is too old to run on my version of tf, and we don't have your code either.) A final minibatch that ends in nothing but zeros could cause the final current_state to converge to all ones. In any other minibatch, current_state would not be all ones.
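To illustrate the divisibility issue, here is a small numpy sketch with hypothetical sizes (the constants below are assumptions in the spirit of the tutorial, not the asker's actual settings). When the per-batch length is not a multiple of the truncation window and the data is padded up to a full window, the last minibatch ends in a run of zeros:

```python
import numpy as np

# Hypothetical sizes (assumptions for illustration):
total_series_length = 50000
batch_size = 5
truncate = 15  # truncated_backprop_length

per_batch = total_series_length // batch_size   # 10000 columns per batch row
num_full = per_batch // truncate                # 666 full truncation windows
leftover = per_batch - num_full * truncate      # 10 real columns in the last window

# Pad each row up to a whole number of windows with zeros.
padded_len = (num_full + 1) * truncate
x = np.arange(total_series_length, dtype=float)
x_batches = np.zeros((batch_size, padded_len))
x_batches[:, :per_batch] = x.reshape(batch_size, per_batch)

# The final minibatch window is mostly zeros past the leftover columns.
last_window = x_batches[:, num_full * truncate:]
print(leftover)                                 # 10
print((last_window[:, leftover:] == 0).all())   # True
```

Training on that zero tail every epoch is the kind of input that could drive the state to a degenerate value.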

You could try printing current_state inside __train_minibatch__ on every sess.run call, or perhaps only once every 1000 minibatches.


測試錯誤的確會降低 - 從6.7到3.4。我嘗試了在minibatch級別查看'current_state'的建議,並且我在每個minibatch中都獲得了1。我的代碼是在[github回購](https://github.com/ffrankies/tf-terry),如果它有幫助(我想讓教程在基於文本的數據集上運行,所以我已經必須稍微調整教程代碼)。另外,我確實在矩陣行的末尾有「填充」數據,但即使將數據轉換爲一個長陣列並將其重新整形爲小型塊,問題仍然存在。 – frankie