2017-07-25 39 views
0

背景:Tensorflow - 處理輸入到圖表

我試圖通過尺寸M的另一張量我想要的結果到一個MXN張量

例如乘以大小爲m的張量:

1.0    
     0.2 0.2 0.5 
0.0 X  = 0.0 0.0 
     0.5 0.2 0.5 
1.0    

我可以用numpy的做到這一點:

x_vals = np.array([[1.0, 0.0, 1.0],[0.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) 
deltas = np.array([0.2, 0.5]) 

def Mult(x): 
    return x*deltas 

#I can do this... 
for x in x_vals: 
    print Mult(x.reshape(3,1)) 

我不能用tensorflow做到這一點?

x_vals = np.array([[1.0, 0.0, 1.0],[0.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) 
deltas = np.array([0.2, 0.5]) 

def Mult(x): 
    return x*deltas 
x = tf.placeholder('float', (None,3)) 
delta = tf.constant(deltas) 
result = Mult(tf.reshape(x, shape=(3,1))) 
init = tf.global_variables_initializer() 

# create session and run the graph 
with tf.Session() as sess: 
    sess.run(init) 
    arr = sess.run(result, feed_dict={x: x_vals}) 

看來,它通過對整個x_val陣列中,而我不知道循環在session.run每個條目是它應該如何工作的。任何人都可以給我一個指針?

回答

0

numpy它可以簡化爲:x_vals[:,:,np.newaxis]*deltas[np.newaxis,:]

這應該在tensorflow工作,以及:

x_vals = np.array([[1.0, 0.0, 1.0],[0.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) 
deltas = np.array([0.2, 0.5]) 

def Mult(x): 
    return x[:,:,tf.newaxis]*deltas[tf.newaxis,:] 
x = tf.placeholder('float', (None,3)) 
delta = tf.constant(deltas) 
result = Mult(x) 
init = tf.global_variables_initializer() 

# create session and run the graph 
with tf.Session() as sess: 
sess.run(init) 
arr = sess.run(result, feed_dict={x: x_vals}) 
+0

我明白了,你擴大了張量的尺寸。謝謝 – Kevin