2016-08-11 23 views
3

MNIST集合包含60000個用於訓練集的圖像。在訓練我的Tensorflow時,我想運行訓練步驟來訓練整個訓練集的模型。 Tensorflow網站上的深度學習示例使用20000次迭代,批量大小爲50(總計爲1,000,000批次)。當我嘗試超過30,000次迭代時,我的數字預測失敗(預測所有手寫數字爲0)。我的問題是,我應該使用多少次迭代,批量大小爲50來訓練整個MNIST集的張量流模型?用整個MNIST數據集(60000圖像)訓練張量流所需的迭代次數?

self.mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 
for i in range(FLAGS.training_steps): 
    batch = self.mnist.train.next_batch(50) 
    self.train_step.run(feed_dict={self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5}) 
    if (i+1)%1000 == 0: 
     saver.save(self.sess, FLAGS.checkpoint_dir + 'model.ckpt', global_step = i) 

回答

2

我認爲這取決於您的停止標準。如果損失沒有改善,您可以停止培訓,或者您可以擁有驗證數據集,並在驗證準確性無法再提高時停止培訓。

+0

我想我會做到這一點。可能在每1000次迭代中,我會盡量準確。如果在某個時候,積分下降到0,我應該停止在那裏下雨。 – Swapnil

1

隨着機器學習,你往往會有嚴重的收益遞減情況。例如這裏是從我的細胞神經網絡的一個準確的列表:

Epoch 0 current test set accuracy : 0.5399 
Epoch 1 current test set accuracy : 0.7298 
Epoch 2 current test set accuracy : 0.7987 
Epoch 3 current test set accuracy : 0.8331 
Epoch 4 current test set accuracy : 0.8544 
Epoch 5 current test set accuracy : 0.8711 
Epoch 6 current test set accuracy : 0.888 
Epoch 7 current test set accuracy : 0.8969 
Epoch 8 current test set accuracy : 0.9064 
Epoch 9 current test set accuracy : 0.9148 
Epoch 10 current test set accuracy : 0.9203 
Epoch 11 current test set accuracy : 0.9233 
Epoch 12 current test set accuracy : 0.929 
Epoch 13 current test set accuracy : 0.9334 
Epoch 14 current test set accuracy : 0.9358 
Epoch 15 current test set accuracy : 0.9395 
Epoch 16 current test set accuracy : 0.942 
Epoch 17 current test set accuracy : 0.9436 
Epoch 18 current test set accuracy : 0.9458 

正如你所看到的收益開始下降後〜10個曆元*,但是這可能會因您的網絡和學習速度上。基於多少時間你有多少時間有好處做的不盡相同,但我發現20是一個合理的數字

*我一直使用時代這個詞來表示一個整個運行通過一個數據集但我不知道是該定義的準確性,這裏每個時代爲〜帶的大小批量429個訓練步128

0

您可以使用類似no_improve_epoch並將其設置爲假設3.什麼便索性意味着如果在3次迭代中沒有> 1%的改善,則停止迭代。

no_improve_epoch= 0 
     with tf.Session() as sess: 
      sess.run(cls.init) 
      if cls.config.reload=='True': 
       print(cls.config.reload) 
       cls.logger.info("Reloading the latest trained model...") 
       saver.restore(sess, cls.config.model_output) 
      cls.add_summary(sess) 
      for epoch in range(cls.config.nepochs): 
       cls.logger.info("Epoch {:} out of {:}".format(epoch + 1, cls.config.nepochs)) 
       dev = train 
       acc, f1 = cls.run_epoch(sess, train, dev, tags, epoch) 

       cls.config.lr *= cls.config.lr_decay 

       if f1 >= best_score: 
        nepoch_no_imprv = 0 
        if not os.path.exists(cls.config.model_output): 
         os.makedirs(cls.config.model_output) 
        saver.save(sess, cls.config.model_output) 
        best_score = f1 
        cls.logger.info("- new best score!") 

       else: 
        no_improve_epoch+= 1 
        if nepoch_no_imprv >= cls.config.nepoch_no_imprv: 
         cls.logger.info("- early stopping {} Iterations without improvement".format(
          nepoch_no_imprv)) 
         break 

Sequence Tagging GITHUB

相關問題