2017-03-08 79 views
4

我有3張
X形狀(1, c, h, w),假設(1, 20, 40, 50)
Fx形狀(num, w, N),假設(1000, 50, 10)
Fy形狀(num, N, h),假設(1000, 10, 40)MATMUL不同等級

我想要做的就是Fy * (X * Fx)*意味着matmul
X * Fx形狀(num, c, h, N),假設(1000, 20, 40, 10)
Fy * (X * Fx)形狀(num, c, N, N),假設(1000, 20, 10, 10)

我使用tf.tiletf.expand_dims
但我認爲它使用了大量的內存(tile複製數據吧?),並緩慢
試圖找到更好的辦法那速度更快,佔用內存小,以實現

# X: (1, c, h, w) 
# Fx: (num, w, N) 
# Fy: (num, N, h) 

X = tf.tile(X, [tf.shape(Fx)[0], 1, 1, 1]) # (num, c, h, w) 
Fx_ex = tf.expand_dims(Fx, axis=1) # (num, 1, w, N) 
Fx_ex = tf.tile(Fx_ex, [1, c, 1, 1]) # (num, c, w, N) 
tmp = tf.matmul(X, Fxt_ex) # (num, c, h, N) 

Fy_ex = tf.expand_dims(Fy, axis=1) # (num, 1, N, h) 
Fy_ex = tf.tile(Fy_ex, [1, c, 1, 1]) # (num, c, N, h) 
res = tf.matmul(Fy_ex, tmp) # (num, c, N, N) 

回答

2

的情況爲mythical einsum,我想:

>>> import numpy as np 
>>> X = np.random.rand(1, 20, 40, 50) 
>>> Fx = np.random.rand(100, 50, 10) 
>>> Fy = np.random.rand(100, 10, 40) 
>>> np.einsum('nMh,uchw,nwN->ncMN', Fy, X, Fx).shape 
(100, 20, 10, 10) 

它應該在tfnumpy幾乎相同(在一些tf版本中,我看到使用大寫索引是不允許的)。儘管如此,如果你以前從未見過符號,這肯定會超過一個不可讀的正則表達式。

+0

是的,我從來沒有見過這個,有點難以理解,想了解這個想法 – xxi

+0

awwwwwesome,它巨大的速度提高,非常感謝 – xxi

0

對於otherone可能感興趣
我覺得@phg的答案也許工作
但在我的情況下numhw是動態的,即None
所以tf.einsum在tensorflow R1.0會引發錯誤,因爲有更多的比一個張量

幸好一個None形狀,有一個issuepull request
似乎可以處理的情況,有一個以上的None SHA PE
需要從源(主分支)建立
後,我重新構建tensorflow

BTW我將報告結果,在tf.einsum只接受小寫

報告
是,最新版本張量流(主枝)接受動態形狀爲tf.einsum
而且使用後速度大幅提升tf.einsum,真是太棒了