I have two tensors x and s with the following shapes:
> x.shape
TensorShape([Dimension(None), Dimension(3), Dimension(5), Dimension(5)])
> s.shape
TensorShape([Dimension(None), Dimension(12), Dimension(5), Dimension(5)])
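For context, shapes like these would come from TF 1.x placeholders roughly as follows (a hypothetical setup; the actual graph definition is not shown in the question):

import tensorflow as tf  # TF 1.x API

# Hypothetical placeholders matching the shapes above.
x = tf.placeholder(tf.float32, [None, 3, 5, 5])   # [batch, 3, 5, 5]
s = tf.placeholder(tf.float32, [None, 12, 5, 5])  # [batch, 12, 5, 5]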
I want to take a dot product between x and s, broadcast over dimension 1, as follows:
> x_s.shape
TensorShape([Dimension(None), Dimension(4), Dimension(5), Dimension(5)])
where
x_s[i, 0, k, l] = sum([x[i, j, k, l] * s[i, j, k, l] for j in range(3)])
x_s[i, 1, k, l] = sum([x[i, j-3, k, l] * s[i, j, k, l] for j in range(3, 6)])
x_s[i, 2, k, l] = sum([x[i, j-6, k, l] * s[i, j, k, l] for j in range(6, 9)])
x_s[i, 3, k, l] = sum([x[i, j-9, k, l] * s[i, j, k, l] for j in range(9, 12)])
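In other words, the 12 channels of s split into 4 contiguous groups of 3, and each group is dotted with x's 3 channels. A minimal NumPy loop reference for checking (not part of the original code, just an illustration):

import numpy as np

def reference(x, s):
    # Naive loop version of the four formulas above.
    n = x.shape[0]
    x_s = np.zeros((n, 4, 5, 5), dtype=x.dtype)
    for m in range(4):       # output group
        for j in range(3):   # channel within the group
            x_s[:, m] += x[:, j] * s[:, 3 * m + j]
    return x_s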
I have implemented it like this:
s_t = tf.transpose(s, [0, 2, 3, 1]) # [None, 5, 5, 12]
x_t = tf.transpose(x, [0, 2, 3, 1]) # [None, 5, 5, 3]
x_t = tf.tile(x_t, [1, 1, 1, 4]) # [None, 5, 5, 12]
x_s = x_t * s_t # [None, 5, 5, 12]
x_s = tf.reshape(x_s, [tf.shape(x_s)[0], 5, 5, 4, 3]) # [None, 5, 5, 4, 3]
x_s = tf.reduce_sum(x_s, axis=-1) # [None, 5, 5, 4]
x_s = tf.transpose(x_s, [0, 3, 1, 2]) # [None, 4, 5, 5]
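The graph above can be sanity-checked against the loop reference (assuming the hypothetical placeholders and the reference helper sketched earlier):

import numpy as np

x_np = np.random.rand(2, 3, 5, 5).astype(np.float32)
s_np = np.random.rand(2, 12, 5, 5).astype(np.float32)

with tf.Session() as sess:
    out = sess.run(x_s, feed_dict={x: x_np, s: s_np})

print(np.allclose(out, reference(x_np, s_np)))  # expect True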
I understand this is not memory-efficient because of the tile. In addition, the reshape, transpose, element-wise multiply, and reduce_sum operations may hurt performance on larger tensors. Is there another way to do this more cleanly?
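For reference, one tile-free formulation is to split s's channel axis into a [4, 3] block and let broadcasting pair it with x; a sketch under that assumption (the x_s2 in the comment below is presumably along these lines, though the answer itself is not shown here):

s_r = tf.reshape(s, [-1, 4, 3, 5, 5])    # [None, 4, 3, 5, 5]
x_e = tf.expand_dims(x, axis=1)          # [None, 1, 3, 5, 5]
x_s2 = tf.reduce_sum(x_e * s_r, axis=2)  # [None, 4, 5, 5]

This avoids copying x four times and drops both transposes, which is consistent with the timing reported in the comment below.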
Thanks for your help. Indeed, your impl (x_s2) is faster than the OP's (x_s): %timeit sess.run(x_s2) gives 1000 loops, best of 3: 243 µs per loop, versus %timeit sess.run(x_s) at 1000 loops, best of 3: 365 µs per loop. – John