我不太清楚廣播機制在Tensorflow中的工作原理。假設我們有以下代碼:Tensorflow:奇怪的廣播行爲
W1_shape = [5, 5, 1, 32]
b1_shape = [32]
x = tf.placeholder(tf.float32)
initial_W1 = tf.truncated_normal(shape=W1_shape, stddev=0.1)
W1 = tf.Variable(initial_W1)
initial_b1 = tf.constant(0.1, shape=b1_shape)
b1 = tf.Variable(initial_b1)
conv1 = tf.nn.conv2d(x, W1, strides=[1, 1, 1, 1], padding='SAME')
conv1_sum = conv1 + b1
y = tf.placeholder(tf.float32)
z = conv1 + y
sess = tf.Session()
# Run init ops
init = tf.global_variables_initializer()
sess.run(init)
while True:
samples, labels, indices = dataset.get_next_batch(batch_size=1000)
samples = samples.reshape((1000, MnistDataSet.MNIST_SIZE, MnistDataSet.MNIST_SIZE, 1))
y_data = np.ones(shape=(1000, 32))
conv1_res, conv1_sum_res, b1_res, z_res=\
sess.run([conv1, conv1_sum, b1, z], feed_dict={x: samples, y: y_data})
if dataset.isNewEpoch:
break
因此,我加載了MNIST數據集,它由28x28大小的圖像組成。卷積運算符使用32個5x5大小的濾波器。我使用1000的批量大小,因此數據張量x
具有形狀(1000,28,28,1)。 tf.nn.conv2d
操作輸出形狀的張量(1000,28,28,32)。 y
是一個佔位符,我添加了一個變量來檢查Tensorflow的廣播機制,將它添加到(1000,28,28,32)形conv1
張量中。在y_data = np.ones(shape=(1000, 32))
行中,我嘗試了各種張量形狀爲y
。的形狀(28,28),(1000,28)和(1000,32)不會增加conv1
,與該類型的誤差:
InvalidArgumentError(參見上述用於回溯):不兼容的形狀: [1000,28,28,32]對比[28,28]
形狀(28,32)和(28,28,32)正常工作和廣播。但根據https://www.tensorflow.org/performance/xla/broadcasting中解釋的廣播語義,前三種形狀也必須起作用,因爲它們通過將尺寸與4D conv1
張量進行匹配而具有正確的順序。例如,維度1和2中的(28,28)匹配(1000,28,28,32),(1000,32)匹配維度0和3,正如鏈接中所述。我在這裏錯過或誤解了一些東西嗎? Tensorflow在這種情況下的正確廣播行爲是什麼?
是的,它確實看起來像這樣。這些文件在廣播上是不正確的。 –