我有一個複雜的網絡的一個人爲的版本:爲什麼這個tensorflow循環需要這麼多的內存?
import tensorflow as tf
a = tf.ones([1000])
b = tf.ones([1000])
for i in range(int(1e6)):
a = a * b
我的直覺是,這應該需要很少的內存。只是初始數組分配的空間和一系列利用節點的命令並在每一步中覆蓋存儲在張量'a'中的內存。但內存使用增長相當迅速。
這裏發生了什麼,當我計算張量並覆蓋很多次時,如何減少內存使用量?
編輯:
由於雅羅斯拉夫的建議解決方案橫空出世使用while_loop以儘量減少對圖中的節點的數量要。這樣做效果很好,速度更快,所需內存更少,並且全部包含在圖表中。
import tensorflow as tf
a = tf.ones([1000])
b = tf.ones([1000])
cond = lambda _i, _1, _2: tf.less(_i, int(1e6))
body = lambda _i, _a, _b: [tf.add(_i, 1), _a * _b, _b]
i = tf.constant(0)
output = tf.while_loop(cond, body, [i, a, b])
with tf.Session() as sess:
result = sess.run(output)
print(result)
這很有道理。我只想調用sess.run一次,並且一步處理從輸入到輸出的所有計算,因爲我正在反向傳播大圖。是否可以在不添加額外節點的情況下執行此循環? – jstaker7
如果您在'tf.Variable'對象中將輸入保存到* b中,您將隔離它的依賴關係,因此您可以在該節點上執行'sess.run一百萬次,而無需評估圖中的任何其他內容 –
不確定我很瞭解。假設'b'是一個可訓練的權重矩陣,在back-prop期間得到更新。如果我多次調用sess.run,不但會招致多次sess.run調用的額外開銷,而且會將該計算與梯度計算斷開連接,並需要做一些繁瑣的工作以確保其得到正確更新。這些假設是否正確?我想我希望有更好的方法來處理這個圖表。 – jstaker7