我有三個數組,X
,Y
和Z
。如果Z
的對應元素爲真,我想要放入res
和X
的元素;否則,我會放入一個來自Y
的元素。其中()需要1到2個位置參數,但有3個被給出
我實現這樣的:
X = tf.constant([[1, 2], [3, 4]])
Y = tf.constant([[5, 6], [7, 8]])
Z = tf.constant([[True, False], [False, True]], tf.bool)
res = tf.where(Z, X, Y)
print(res.eval())
不過,我得到這個錯誤:
TypeError: where() takes from 1 to 2 positional arguments but 3 were given
我看着tf.where
的definiton從here和我的使用似乎罰款。
任何想法可能是什麼問題?
你可以試試'tf.where(Z,X = X,Y = Y)' – pramod
您的代碼工作正常TensorFlow 1.0.1,所以我很好奇:這你使用TF版本? – npf