2017-09-05 65 views

TensorFlow: how to get a sub-array from each row of a tensor

I have the following code:

import numpy as np 
import tensorflow as tf 

series = tf.placeholder(tf.float32, shape=[None, 5]) 
series_length = tf.placeholder(tf.int32, shape=[None]) 
useful_series = tf.magic_slice_function(series, series_length) 

with tf.Session() as sess: 
    input_x = np.array([[1, 2, 3, 0, 0], 
                        [2, 3, 0, 0, 0],
                        [1, 0, 0, 0, 0]])
    input_y = np.array([[3], [2], [1]]) 
    print(sess.run(useful_series, feed_dict={series: input_x, series_length: input_y})) 

The expected output is:

[[1, 2, 3], [2, 3], [1]]

I have tried several functions, such as tf.gather and tf.slice, but none of them worked. What should magic_slice_function be?


Perhaps you need to do this outside TensorFlow, because what you want to get back is not a tensor: its rows have different lengths.
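To illustrate the comment's suggestion, here is a minimal sketch that does the trimming in plain NumPy, outside the TensorFlow graph. The helper name `trim_rows` is hypothetical (it does not appear in the thread), and the sketch assumes the padded values always sit at the end of each row.

```python
import numpy as np

def trim_rows(padded, lengths):
    """Return a list holding the first lengths[i] entries of row i."""
    # Flatten lengths so both shape (n,) and shape (n, 1) are accepted.
    lengths = np.asarray(lengths).reshape(-1)
    return [row[:n].tolist() for row, n in zip(padded, lengths)]

input_x = np.array([[1, 2, 3, 0, 0],
                    [2, 3, 0, 0, 0],
                    [1, 0, 0, 0, 0]])
input_y = np.array([[3], [2], [1]])

result = trim_rows(input_x, input_y)
print(result)  # [[1, 2, 3], [2, 3], [1]]
```

The result is an ordinary Python list of lists, which sidesteps the fact that a single tensor cannot hold rows of different lengths.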

Answer


This is a bit tricky:

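import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (placeholders, sessions, tf.py_func)

series = tf.placeholder(tf.float32, shape=[None, 5])
series_length = tf.placeholder(tf.int64)

def magic_slice_function(input_x, input_y):
    # Called through tf.py_func, so both arguments arrive as NumPy arrays.
    # Keep the first input_y[i] values of row i and append them to one flat list.
    array = []
    for i in range(len(input_x)):
        temp = [input_x[i][j] for j in range(input_y[i])]
        array.extend(temp)
    return [array]

with tf.Session() as sess:
    input_x = np.array([[1, 2, 3, 0, 0],
                        [2, 3, 0, 0, 0],
                        [1, 0, 0, 0, 0]])

    input_y = np.array([3, 2, 1], dtype=np.int64)

    # For this input the result is one flat float32 tensor: [1, 2, 3, 2, 3, 1].
    merged_series = tf.py_func(magic_slice_function, [series, series_length], tf.float32, name='slice_func')

    # Split the flat tensor back into one piece per row. The split sizes come from
    # the NumPy array input_y, not from the placeholder, so they are known at graph-build time.
    out = tf.split(merged_series, input_y)
    print(sess.run(out, feed_dict={series: input_x, series_length: input_y}))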

The output will be:

[array([ 1., 2., 3.], dtype=float32), array([ 2., 3.], dtype=float32), array([ 1.], dtype=float32)]
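The two stages of this answer (flatten the valid entries, then split by the row lengths) can be sketched in plain NumPy without a Python loop. This is only an illustration under the assumption that padding is trailing; the variable names are hypothetical and it is not the answer's own code. In TensorFlow 1.x a comparable mask could likely be built with `tf.sequence_mask` and applied with `tf.boolean_mask`, though that variant is not tested here.

```python
import numpy as np

input_x = np.array([[1, 2, 3, 0, 0],
                    [2, 3, 0, 0, 0],
                    [1, 0, 0, 0, 0]], dtype=np.float32)
input_y = np.array([3, 2, 1], dtype=np.int64)

# Stage 1: boolean mask that is True for the first input_y[i] columns of row i.
mask = np.arange(input_x.shape[1]) < input_y[:, None]
# Boolean indexing keeps the selected values in row-major order as a flat array.
flat = input_x[mask]

# Stage 2: cut the flat array at the cumulative row lengths (dropping the last one).
pieces = np.split(flat, np.cumsum(input_y)[:-1])
print([p.tolist() for p in pieces])  # [[1.0, 2.0, 3.0], [2.0, 3.0], [1.0]]
```

As in the answer, the split points depend on concrete length values, so the final result is a list of arrays, not a single rectangular array.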