2016-06-28 133 views
3

我想用大型數據集來訓練CNN。目前我將所有數據加載到tf.constant中,然後在tf.Session()中以小批量大小進行循環。這爲數據集的一小部分工作正常,但是當我加大投入大小我得到的錯誤:在Tensorflow中使用大型數據集

ValueError: Cannot create a tensor proto whose content is larger than 2GB. 

我怎樣才能避免這種情況?

+0

[初始化tensorflow與可變容量大於2GB的數組]的可能的複製(https://開頭計算器的.com /問題/ 35394103 /初始化-tensorflow變量與 - 一個陣列-幅度超出2GB) – Steven

回答

5

不要將數據加載到常量,它將成爲您的計算圖的一部分。

你還是:

  • 創建它加載你的數據流的方式
  • 加載數據在Python部分運算,並使用feed_dict到批量進入圖形
3

對於TensorFlow 1.x和Python 3,有我的簡單解決方案:

X_init = tf.placeholder(tf.float32, shape=(m_input, n_input)) 
X = tf.Variable(X_init) 
sess.run(tf.global_variables_initializer(), feed_dict={X_init: data_for_X}) 

在實踐中,您將主要指定圖形和會話連續計算,這下面的代碼將幫助您:

my_graph = tf.Graph() 
sess = tf.Session(graph=my_graph) 
with my_graph.as_default(): 
    X_init = tf.placeholder(tf.float32, shape=(m_input, n_input)) 
    X = tf.Variable(X_init) 
    sess.run(tf.global_variables_initializer(), feed_dict={X_init: data_for_X}) 
    .... # build your graph with X here 
.... # Do some other things here 
with my_graph.as_default(): 
    output_y = sess.run(your_graph_output, feed_dict={other_placeholder: other_data})