我有3張
X
形狀(1, c, h, w)
,假設(1, 20, 40, 50)
Fx
形狀(num, w, N)
,假設(1000, 50, 10)
Fy
形狀(num, N, h)
,假設(1000, 10, 40)
MATMUL不同等級
我想要做的就是Fy * (X * Fx)
(*
意味着matmul
)
X * Fx
形狀(num, c, h, N)
,假設(1000, 20, 40, 10)
Fy * (X * Fx)
形狀(num, c, N, N)
,假設(1000, 20, 10, 10)
我使用tf.tile
和tf.expand_dims
做
但我認爲它使用了大量的內存(tile
複製數據吧?),並緩慢
試圖找到更好的辦法那速度更快,佔用內存小,以實現
# X: (1, c, h, w)
# Fx: (num, w, N)
# Fy: (num, N, h)
X = tf.tile(X, [tf.shape(Fx)[0], 1, 1, 1]) # (num, c, h, w)
Fx_ex = tf.expand_dims(Fx, axis=1) # (num, 1, w, N)
Fx_ex = tf.tile(Fx_ex, [1, c, 1, 1]) # (num, c, w, N)
tmp = tf.matmul(X, Fxt_ex) # (num, c, h, N)
Fy_ex = tf.expand_dims(Fy, axis=1) # (num, 1, N, h)
Fy_ex = tf.tile(Fy_ex, [1, c, 1, 1]) # (num, c, N, h)
res = tf.matmul(Fy_ex, tmp) # (num, c, N, N)
是的,我從來沒有見過這個,有點難以理解,想了解這個想法 – xxi
awwwwwesome,它巨大的速度提高,非常感謝 – xxi