2017-02-03 86 views
1

說我有一個張量:如何在張量流中刪除四維張量中的零點?

import tensorflow as tf 
t = tf.Variable([[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.], [77., 0., 0., 12., 0., 0., 33., 55., 0.]], 
       [[0., 132., 0., 0., 234., 0., 1., 24., 0.], [43., 0., 0., 124., 0., 0., 0., 52., 645]]]]) 

我想省略零和會留下形狀的張量:(1,2,2,4),4爲在非零種元素的數目我張量像

t = tf.Variable([[[[235., 1006., 23., 42], [77., 12., 33., 55.]], 
       [[132., 234., 1., 24.], [43., 124., 52., 645]]]]) 

我已經使用布爾掩模來做到這一點的一維張量。我怎樣才能在4-D張量中省略零。它可以推廣到更高級別嗎?

回答

2

使用TensorFlow 0.12.1:

import tensorflow as tf 

def batch_of_vectors_nonzero_entries(batch_of_vectors): 
    """Removes non-zero entries from batched vectors. 

    Requires that each vector have the same number of non-zero entries. 

    Args: 
    batch_of_vectors: A Tensor with length-N vectors, having shape [..., N]. 
    Returns: 
    A Tensor with shape [..., M] where M is the number of non-zero entries in 
    each vector. 
    """ 
    nonzero_indices = tf.where(tf.not_equal(
     batch_of_vectors, tf.zeros_like(batch_of_vectors))) 
    # gather_nd gives us a vector containing the non-zero entries of the 
    # original Tensor 
    nonzero_values = tf.gather_nd(batch_of_vectors, nonzero_indices) 
    # Next, reshape so that all but the last dimension is the same as the input 
    # Tensor. Note that this will fail unless each vector has the same number of 
    # non-zero values. 
    reshaped_nonzero_values = tf.reshape(
     nonzero_values, 
     tf.concat(0, [tf.shape(batch_of_vectors)[:-1], [-1]])) 
    return reshaped_nonzero_values 

t = tf.Variable(
    [[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.], 
     [77., 0., 0., 12., 0., 0., 33., 55., 0.]], 
     [[0., 132., 0., 0., 234., 0., 1., 24., 0.], 
     [43., 0., 0., 124., 0., 0., 0., 52., 645]]]]) 
nonzero_t = batch_of_vectors_nonzero_entries(t) 

with tf.Session(): 
    tf.global_variables_initializer().run() 
    result_evaled = nonzero_t.eval() 
    print(result_evaled.shape, result_evaled) 

打印:

(1, 2, 2, 4) [[[[ 2.35000000e+02 1.00600000e+03 2.30000000e+01 4.20000000e+01] 
    [ 7.70000000e+01 1.20000000e+01 3.30000000e+01 5.50000000e+01]] 

    [[ 1.32000000e+02 2.34000000e+02 1.00000000e+00 2.40000000e+01] 
    [ 4.30000000e+01 1.24000000e+02 5.20000000e+01 6.45000000e+02]]]] 

可能尋找到SparseTensor■如果結果永遠結束是破爛有用的。

相關問題