2017-06-06 76 views
-1

我有三個數組,X,YZ。如果Z的對應元素爲真,我想要放入resX的元素;否則,我會放入一個來自Y的元素。其中()需要1到2個位置參數,但有3個被給出

我實現這樣的:

X = tf.constant([[1, 2], [3, 4]]) 
Y = tf.constant([[5, 6], [7, 8]]) 
Z = tf.constant([[True, False], [False, True]], tf.bool) 
res = tf.where(Z, X, Y) 
print(res.eval()) 

不過,我得到這個錯誤:

TypeError: where() takes from 1 to 2 positional arguments but 3 were given 

我看着tf.where的definiton從here和我的使用似乎罰款。

任何想法可能是什麼問題?

+0

你可以試試'tf.where(Z,X = X,Y = Y)' – pramod

+0

您的代碼工作正常TensorFlow 1.0.1,所以我很好奇:這你使用TF版本? – npf

回答

相關問題