2017-08-23 79 views
1

tf.reduce_meantf.reduce_prod這樣的函數執行單元操作以減少沿軸的張量。我有一個張量R形狀(1000, 3, 3),一個3x3矩陣的列表。我想要做的是矩陣乘以它們,所以我仍然與一個單一的3x3矩陣。如果這是numpy我可以使用減少張量流中的矩陣陣列

np.linalg.multi_dot(R) 

我如何在tensorflow中做到這一點?

回答

4

您可以使用tf.scantf.scan(lambda a, b: tf.matmul(a, b), R)[-1]

import tensorflow as tf 
import numpy as np 

R = np.random.rand(10, 3, 3) 
R_reduced = np.linalg.multi_dot(R) 

R_reduced_t = tf.scan(lambda a, b: tf.matmul(a, b), R)[-1] 

with tf.Session() as sess: 
    R_reduced_val = sess.run(R_reduced_t) 
    diff = R_reduced_val - R_reduced 
    print(diff) 

此打印:

[[ -3.55271368e-15 0.00000000e+00 0.00000000e+00] 
[ 1.77635684e-15 0.00000000e+00 3.55271368e-15] 
[ -1.77635684e-15 3.55271368e-15 0.00000000e+00]]