2017-08-30 64 views

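Memory usage of tensorflow conv2d with large filters

I have a tensorflow model with some relatively large convolution filters. I've found that tf.nn.conv2d becomes unusable for filters this large - it tries to use over 60GB of memory, at which point I have to kill it. Here is a minimal script that reproduces the problem: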

import tensorflow as tf 
import numpy as np 

frames, height, width, channels = 200, 321, 481, 1 
filter_h, filter_w, filter_out = 5, 5, 3 # With this, output has shape (200, 317, 477, 3) 
# filter_h, filter_w, filter_out = 7, 7, 3 # With this, output has shape (200, 315, 475, 3) 
# filter_h, filter_w, filter_out = 135, 135, 3 # With this, output will be smaller than the above with shape (200, 187, 347, 3), but memory usage explodes 

images = np.random.randn(frames, height, width, channels).astype(np.float32) 

filters = tf.Variable(np.random.randn(filter_h, filter_w, channels, filter_out).astype(np.float32)) 
images_input = tf.placeholder(tf.float32) 
conv = tf.nn.conv2d(images_input, filters, strides=[1, 1, 1, 1], padding="VALID") 

with tf.Session() as sess: 
    result = sess.run(conv, feed_dict={images_input: images}) 

print(result.shape)

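First, can anyone explain this behavior? Why does memory usage blow up with filter size? (Note: I also tried rearranging my dimensions to use a single conv3d instead of a batch of conv2d calls, but I hit the same problem.)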



  1. Flattens the filter to a 2-D matrix with shape [filter_height * filter_width * in_channels, output_channels].
  2. Extracts image patches from the input tensor to form a virtual tensor of shape [batch, out_height, out_width, filter_height * filter_width * in_channels].
  3. For each patch, right-multiplies the filter matrix and the image patch vector.
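
To make the three steps above concrete, here is a minimal NumPy sketch of the same idea (often called im2col). It is an illustration only, not TensorFlow's actual kernel: the helper names `conv2d_valid_im2col` and `conv2d_valid_direct` are hypothetical, and it assumes stride 1 and VALID padding as in the script. The second function accumulates the same result without building the patch tensor, purely as a correctness check.

```python
import numpy as np

def conv2d_valid_im2col(images, filters):
    """images: [batch, h, w, cin]; filters: [fh, fw, cin, cout]; stride 1, VALID."""
    batch, h, w, cin = images.shape
    fh, fw, _, cout = filters.shape
    out_h, out_w = h - fh + 1, w - fw + 1

    # Step 1: flatten the filter to [fh * fw * cin, cout].
    flat_filter = filters.reshape(fh * fw * cin, cout)

    # Step 2: copy every patch into [batch, out_h, out_w, fh * fw * cin].
    # This is the tensor whose size grows with the filter area.
    patches = np.empty((batch, out_h, out_w, fh * fw * cin), dtype=images.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patches[:, i, j, :] = images[:, i:i + fh, j:j + fw, :].reshape(batch, -1)

    # Step 3: multiply each patch vector by the filter matrix.
    return np.matmul(patches, flat_filter), patches.nbytes

def conv2d_valid_direct(images, filters):
    """Same result, accumulated one filter tap at a time (no patch tensor)."""
    batch, h, w, cin = images.shape
    fh, fw, _, cout = filters.shape
    out_h, out_w = h - fh + 1, w - fw + 1
    out = np.zeros((batch, out_h, out_w, cout), dtype=images.dtype)
    for di in range(fh):
        for dj in range(fw):
            # [batch, out_h, out_w, cin] times [cin, cout]
            out += np.matmul(images[:, di:di + out_h, dj:dj + out_w, :], filters[di, dj])
    return out

rng = np.random.RandomState(0)
small_images = rng.randn(2, 6, 7, 1).astype(np.float32)
small_filters = rng.randn(3, 3, 1, 3).astype(np.float32)

result, patch_bytes = conv2d_valid_im2col(small_images, small_filters)
reference = conv2d_valid_direct(small_images, small_filters)
print(result.shape, patch_bytes, small_images.nbytes)
```

Even on this toy input the patch tensor takes several times the memory of the input itself, and the ratio grows with filter_height * filter_width.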



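It actually does this when you are debugging on a CPU - it converts the convolution into a matrix multiplication so that it can be implemented with a linear algebra library. Running it on a GPU may need less memory. – etarion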



I had originally taken this simply as a description of the process, but if tensorflow is actually extracting and storing separate filter-sized 'patches' from the image under the hood, then a back-of-the-envelope calculation shows that the intermediate computation involved requires ~130GB in my case, well over the limit that I could test.
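
A rough way to reproduce this kind of back-of-the-envelope estimate for the shapes in the script is sketched below. It assumes the whole patch tensor is materialized at once in float32, with stride 1 and VALID padding; that may not match what TensorFlow really allocates, and `patch_tensor_bytes` is a hypothetical helper, not a TensorFlow function. For the script's 135x135 case these assumptions give a figure well above the ~130GB quoted above, so that number presumably reflects different shapes or a different accounting; treat both as rough.

```python
def patch_tensor_bytes(batch, height, width, in_channels,
                       filter_h, filter_w, bytes_per_element=4):
    """Bytes needed to hold every extracted patch at once (stride 1, VALID)."""
    out_h = height - filter_h + 1
    out_w = width - filter_w + 1
    return (batch * out_h * out_w
            * filter_h * filter_w * in_channels * bytes_per_element)

# Shapes taken from the script above.
frames, height, width, channels = 200, 321, 481, 1

estimates = {}
for size in (5, 7, 135):
    estimates[size] = patch_tensor_bytes(frames, height, width, channels, size, size)
    print("%3dx%-3d filter: %.2f GB" % (size, size, estimates[size] / 1e9))
```

The output gets smaller as the filter grows, but the per-patch vector grows with the filter area, which dominates for large filters.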

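As you figured out yourself, this is the reason for the large memory consumption. Tensorflow does this because filters are usually small, and computing a matrix multiplication is much faster than computing a convolution directly.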

Can anyone explain why TF would do this when I'm still only debugging on a CPU?

