2017-09-15 58 views
1

我想在「無」維度中切割張量。如何在Tensorflow中切割無維度的張量

例如,

tensor = tf.placeholder(tf.float32, shape=[None, None, 10], name="seq_holder") 
sliced_tensor = tensor[:,1:,:] # it works well! 

# Assume that tensor's shape will be [3,10, 10] 
tensor = tf.placeholder(tf.float32, shape=[None, None, 10], name="seq_holder") 
sliced_seq = tf.slice(tensor, [0,1,0],[3, 9, 10]) # it doens't work! 

它是相同,我得到一個消息時,我使用的另一place_holder喂尺寸參數tf.slice()。

第二種方法給了我「輸入大小(輸入深度)必須可通過形狀推斷訪問」錯誤消息。

我想知道兩種方法之間有什麼不同,什麼是更多的張量流法。

將帖子 整個代碼是下面

import tensorflow as tf 
import numpy as np 

print("Tensorflow for tests!") 

vec_dim = 5 
num_hidden = 10 
# method 1 
input_seq1 = np.random.random([3,7,vec_dim]) 

# method 2 
input_seq2 = np.random.random([5,10,vec_dim]) 
shape_seq2 = [5,9,vec_dim] 
# seq: [batch, seq_len] 
seq = tf.placeholder(tf.float32, shape=[None, None, vec_dim], name="seq_holder") 

# Method 1 
sliced_seq = seq[:,1:,:] 

# Method 2 
seq_shape = tf.placeholder(tf.int32, shape=[3]) 
sliced_seq = tf.slice(seq,[0,0,0], seq_shape) 

cell = tf.contrib.rnn.GRUCell(num_units=num_hidden) 
init_state = cell.zero_state(tf.shape(seq)[0], tf.float32) 

outputs, last_state = tf.nn.dynamic_rnn(cell, sliced_seq, initial_state=init_state) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    # method 1 
    # states = sess.run([sliced_seq], feed_dict={seq:input_seq1}) 
    # print(states[0].shape) 

    # method 2 
    states = sess.run([sliced_seq], feed_dict={seq:input_seq2, seq_shape:shape_seq2}) 
    print(states[0].shape) 
+0

當你定義操作(當運行這兩行時)或者當你試圖執行它們(例如在一個會話中調用'.run')時,你會得到錯誤嗎?運行這兩條指令毫無困難。 – jdehesa

+0

當我將切片張量(第二種方法)放入動態rnn函數時出現錯誤。這是dynamic_rnn的問題嗎? –

+0

嗯,很難分辨...你可以編輯你的問題,並添加更多的細節,也許一個完整的最小可重複的例子,你的問題出現? – jdehesa

回答

1

您的問題是精確地通過issue #4590

的問題是,tf.nn.dynamic_rnn需要知道在輸入的最後一個維度的大小(所述的「深度」)。不幸的是,正如問題所指出的那樣,如果任何切片範圍在圖構建時並不完全已知,則當前tf.slice不能推斷出任何輸出大小;因此,sliced_seq最終形狀爲(?, ?, ?)

就你而言,第一個問題是你使用三個元素的佔位符來確定切片的大小;這不是最好的方法,因爲最後一個維度不應該改變(即使你稍後通過vec_dim,它可能會導致錯誤)。最簡單的解決辦法是把seq_shape成大小2(或者甚至是兩個獨立的佔位符)的佔位符,然後做像切片:

sliced_seq = seq[:seq_shape[0], :seq_shape[1], :] 

出於某種原因,NumPy的風格索引似乎有更好的狀態推理能力,這將保留sliced_seq中最後一個維度的大小。

+0

我明白了:)感謝您的詳細解釋! –