2017-06-21 63 views
1

我正在訓練tensorflow中的卷積模型。在訓練了大約70個時期的模型之後,花了近1.5個小時,我無法保存模型。它給了我ValueError: GraphDef cannot be larger than 2GB。我發現隨着訓練的進行,圖形中的節點數量不斷增加。Tensorflow:隨着訓練的進行,圖中的節點數量不斷增加

在時代0,3,6,9處,圖中節點的數量分別是7214,7238,7262,7286。當我使用with tf.Session() as sess:時,不是將會話作爲sess = tf.Session()傳遞,而是分別在時期0,3,6,9處的節點數爲3982,4006,4030,4054。

this答案,據說隨着節點被添加到圖中,它可以超過其最大尺寸。我需要幫助瞭解節點數量如何在我的圖表中繼續上升。

def runModel(data): 
    ''' 
    Defines cost, optimizer functions, and runs the graph 
    ''' 
    X, y,keep_prob = modelInputs((755, 567, 1),4) 
    logits = cnnModel(X,keep_prob) 
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y), name="cost") 
    optimizer = tf.train.AdamOptimizer(.0001).minimize(cost) 
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1), name="correct_pred") 
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy') 

    sess = tf.Session() 
    sess.run(tf.global_variables_initializer()) 
    saver = tf.train.Saver() 
    for e in range(12): 
     batch_x, batch_y = data.next_batch(30) 
     x = tf.reshape(batch_x, [30, 755, 567, 1]).eval(session=sess) 
     batch_y = tf.one_hot(batch_y,4).eval(session=sess) 
     sess.run(optimizer, feed_dict={X: x, y: batch_y,keep_prob:0.5}) 
     if e%3==0: 
      n = len([n.name for n in tf.get_default_graph().as_graph_def().node]) 
      print("No.of nodes: ",n,"\n") 
      current_cost = sess.run(cost, feed_dict={X: x, y: batch_y,keep_prob:1.0}) 
      acc = sess.run(accuracy, feed_dict={X: x, y: batch_y,keep_prob:1.0}) 
      print("At epoch {epoch:>3d}, cost is {a:>10.4f}, accuracy is {b:>8.5f}".format(epoch=e, a=current_cost, b=acc)) 

什麼原因導致節點的數量增加:

我用下面的代碼訓練我的模型?

+0

也許你可以在每一步獲得新節點的名稱,並查看它們是哪個節點?也許這只是每次被複制的輸入節點,我不知道......你使用的是什麼版本的tf? – gdelab

+0

@gdelab我正在使用'1.0.1',每個時代的節點數似乎都增加了8! – dpk

+0

是的,但是你可以在每一步獲得八個新的節點名稱嗎?也許他們可以幫助理解新節點的創建地點...... – gdelab

回答

2

您正在訓練循環中創建新節點。特別是,您打電話tf.reshapetf.one_hot,其中每個創建一個(或多個)節點。您可以:

  • 使用佔位符作爲輸入在圖的外部創建這些節點,然後僅在循環中對它們進行評估。
  • 對這些操作不使用TensorFlow,而是使用NumPy或等效操作。

我會推薦第二個,因爲在使用TensorFlow進行數據準備時似乎沒有任何好處。你可以有這樣的事情:

import numpy as np 
# ... 
    x = np.reshape(batch_x, [30, 755, 567, 1]) 
    # ... 
    # One way of doing one-hot encoding with NumPy 
    classes_arr = np.arange(4).reshape([1] * batch_y.ndims + [-1]) 
    batch_y = (np.expand_dims(batch_y, -1) == classes_arr).astype(batch_y.dtype) 
    # ... 

PD:我也建議在withcontext manager使用tf.Session(),以確保其close()方法在最後被調用(除非您想以後使用同一個會話保持)。