2016-06-19 24 views
4

我基本上有一批形狀爲[batch_size, layer_size]的張量A中的一層神經元激活。讓B = tf.square(A)。現在我要計算以下條件:批次中每個向量中的每個元素:if abs(e) < 1: e ← 0 else e ← B(e)其中eB中與e位於相同位置的元素。我能以某種方式通過單一的tf.cond操作矢量化整個操作嗎?如何計算TensorFlow中批量元素的條件?

回答

9

你可能想看看tf.where(condition, x, y)

對於您的問題:

A = tf.placeholder(tf.float32, [batch_size, layer_size]) 
B = tf.square(A) 

condition = tf.less(tf.abs(A), 1.) 

res = tf.where(condition, tf.zeros_like(B), B) 
+2

在Tensorflow 1.0,它現在'tf.where' https://www.tensorflow.org/versions/主/ api_docs /蟒/ control_flow_ops/comparison_operators#,其中 –