2017-02-09 100 views
1

利用TensorFlow的HashTable查找實現,我使用提供的默認值返回SparseTensor。我想清除它並得到一個沒有默認值的最終SparseTensor。如何從稀疏張量中獲取非零值

如何清除該默認值?我不介意爲了實現這個目的,默認值是什麼。 0很好,所以是-1。

回答

0

tf.sparse_retain應該工作:

def sparse_remove(sparse_tensor, remove_value=0.): 
    return tf.sparse_retain(sparse_tensor, tf.not_equal(a.values, remove_value)) 

舉個例子:

import tensorflow as tf 

a = tf.SparseTensor(indices=[[1, 2], [2, 2]], values=[0., 1.], shape=[3, 3]) 
with tf.Session() as session: 
    print(session.run([a, sparse_remove(a)])) 

打印(我稍微重新格式化吧):

[SparseTensorValue(indices=array([[1, 2], [2, 2]]), values=array([ 0., 1.], dtype=float32), shape=array([3, 3])), 
SparseTensorValue(indices=array([[2, 2]]), values=array([ 1.], dtype=float32), shape=array([3, 3]))] 
相關問題