2017-01-23 24 views

回答

2

據我所知,你不能像一個像NumPy這樣的更高級的庫那樣在一個命令中做到這一點。 如果你真的想使用TF功能,我可以建議一些想:

x = tf.Variable([ 
    [1,2,3,1], 
    [0,0,0,0], 
    [1,3,5,7], 
    [0,0,0,0], 
    [3,5,7,8]]) 

y = tf.Variable([0,0,0,0]) 
condition = tf.equal(x, y) 
indices = tf.where(condition) 

這將導致如下:

[[1 0] 
[1 1] 
[1 2] 
[1 3] 
[3 0] 
[3 1] 
[3 2] 
[3 3]] 

或者你可以使用下面的,如果你只是想只零線:

row_wise_sum = tf.reduce_sum(tf.abs(x),1) 
select_zero_sum = tf.where(tf.equal(row_wise_sum,0)) 

with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    print(sess.run(select_zero_sum)) 

其結果是:

[[1] 
[3]] 
0

它也可以以更簡單的方式完成:

g = tf.Graph() 
    with g.as_default(): 
     a = tf.placeholder(dtype=tf.float32, shape=[3, 4]) 
     b = tf.placeholder(dtype=tf.float32, shape=[1, 4]) 

     res = tf.not_equal(a, b) 
     res = tf.reduce_sum(tf.cast(res, tf.float32), 1) 
     res = tf.where(tf.equal(res1, [0.0]))[0] 

    with tf.Session(graph=g) as sess: 
    sess.run(tf.global_variables_initializer()) 
    dict_ = { 
     a:np.array([[2.0,6.0,3.0,2.0], 
        [1.0,8.0,32.0,1.0], 
        [1.0,8.0,3.0,11.0]]), 
     b:np.array([[1.0,8.0,3.0,11.0]]) 
    } 

    print(sess.run(res, feed_dict=dict_)) 
    :[2]