2017-07-07 35 views
0

我不太清楚廣播機制在Tensorflow中的工作原理。假設我們有以下代碼:Tensorflow:奇怪的廣播行爲

W1_shape = [5, 5, 1, 32] 
b1_shape = [32] 
x = tf.placeholder(tf.float32) 
initial_W1 = tf.truncated_normal(shape=W1_shape, stddev=0.1) 
W1 = tf.Variable(initial_W1) 
initial_b1 = tf.constant(0.1, shape=b1_shape) 
b1 = tf.Variable(initial_b1) 
conv1 = tf.nn.conv2d(x, W1, strides=[1, 1, 1, 1], padding='SAME') 
conv1_sum = conv1 + b1 
y = tf.placeholder(tf.float32) 
z = conv1 + y 

sess = tf.Session() 
# Run init ops 
init = tf.global_variables_initializer() 
sess.run(init) 

while True: 
    samples, labels, indices = dataset.get_next_batch(batch_size=1000) 
    samples = samples.reshape((1000, MnistDataSet.MNIST_SIZE, MnistDataSet.MNIST_SIZE, 1)) 
    y_data = np.ones(shape=(1000, 32)) 
    conv1_res, conv1_sum_res, b1_res, z_res=\ 
    sess.run([conv1, conv1_sum, b1, z], feed_dict={x: samples, y: y_data}) 

if dataset.isNewEpoch: 
    break 

因此,我加載了MNIST數據集,它由28x28大小的圖像組成。卷積運算符使用32個5x5大小的濾波器。我使用1000的批量大小,因此數據張量x具有形狀(1000,28,28,1)。 tf.nn.conv2d操作輸出形狀的張量(1000,28,28,32)。 y是一個佔位符,我添加了一個變量來檢查Tensorflow的廣播機制,將它添加到(1000,28,28,32)形conv1張量中。在y_data = np.ones(shape=(1000, 32))行中,我嘗試了各種張量形狀爲y。的形狀(28,28),(1000,28)和(1000,32)不會增加conv1,與該類型的誤差:

InvalidArgumentError(參見上述用於回溯):不兼容的形狀: [1000,28,28,32]對比[28,28]

形狀(28,32)和(28,28,32)正常工作和廣播。但根據https://www.tensorflow.org/performance/xla/broadcasting中解釋的廣播語義,前三種形狀也必須起作用,因爲它們通過將尺寸與4D conv1張量進行匹配而具有正確的順序。例如,維度1和2中的(28,28)匹配(1000,28,28,32),(1000,32)匹配維度0和3,正如鏈接中所述。我在這裏錯過或誤解了一些東西嗎? Tensorflow在這種情況下的正確廣播行爲是什麼?

回答

1

其確實的文件似乎是建議你說什麼。但它看起來像它遵循numpy broadcsting rules

當兩個數組操作,NumPy的比較它們的形狀 逐元素。它以尾隨尺寸開始,並且向前推進。兩個維度是兼容時:

  1. 它們相等,或
  2. 其中
  3. 一個是1

因此,通過上述定義:

  • (28,28)不能廣播到(1000,28,28,32),但(28,28,1)可以。
  • (1000,28)不能不(1000,1,28,1)或(1000,28,1,1)可以

  • (28,32)工作,因爲該後尺寸匹配。

+0

是的,它確實看起來像這樣。這些文件在廣播上是不正確的。 –