0
考慮下面的代碼:切片張量與INT32形狀的int64標量
x = tf.Variable([1.0,2.0,3.0])
i = tf.Variable([1], dtype = tf.int64)
x[i]
tensorflow拋出錯誤,很顯然,因爲x的形狀的類型是從類型i的不同。我可以通過將我轉換爲int32來解決它,但是有其他方法嗎?例如,我可以改變x的形狀類型嗎?
考慮下面的代碼:切片張量與INT32形狀的int64標量
x = tf.Variable([1.0,2.0,3.0])
i = tf.Variable([1], dtype = tf.int64)
x[i]
tensorflow拋出錯誤,很顯然,因爲x的形狀的類型是從類型i的不同。我可以通過將我轉換爲int32來解決它,但是有其他方法嗎?例如,我可以改變x的形狀類型嗎?
據我所知,tensorflow不支持通過__getitem__
作爲numpy的切片。替代方法是使用tf.gather
:
x = tf.Variable([1.0,2.0,3.0])
i = tf.Variable([1], dtype = tf.int64)
tf.gather(x, i)
「我可以更改x形狀的類型嗎?」 ...這就是你通過將'i'投射到'int32'所做的事情。 'x [tf.cast(i,tf.int32)]'應該在不修改'i'的情況下做你想做的事情(儘管如果它們對於'int32'來說太大,你將在'i'中的值被截斷) – GPhilo