2017-07-11 44 views
1

在tensorflow中,我試圖使用tf.slice,但as its documentation states,它要求切片適合輸入數組。例如,如果您嘗試切割張量的前5個位置[1,2,3,4],它會崩潰。我想要獲得與python列表或numpy數組相同的功能,其中切片可以讓您獲得原始數組與您請求的切片的交集。例如,如果你要求[1,2,3,4]的位置2到6,你會得到[2,3,4]。使用tf.slice檢測越界切片,如numpy

我如何在tensorflow中做到這一點?

謝謝!

回答

3

你可以使用tensorflow的蟒蛇切片運算符,這是稍微強過tf.slice,特別是做一些綁定檢查的行爲類似於numpy

x = tf.zeros((10,)) 
y = tf.slice(x, [5], [15]) 
print(y.shape) 
# (15,) 
z = x[5:42] 
print(z.shape) 
# (5,) 
0

你可以使用tf.clip_by_value

def robust_slice(t, begin, size, name=None): 
    with tf.name_scope('robust_slice'): 
     new_size = tf.clip_by_value(size + begin, clip_value_min = 0, clip_value_max = t.shape()) - begin 
     return tf.slice(t, begin, new_size, name) 

(未經測試,但這樣的財產以後應該做的工作)

+0

我試過了,但如果部分形狀未知,它似乎不起作用。 – etal