2016-06-23 54 views
2

如何編寫分段TensorFlow函數,即其中包含if語句的函數?在TensorFlow中寫入分段函數/如果在TensorFlow中寫入

當前代碼

import tensorflow as tf 

my_fn = lambda x : x ** 2 if x > 0 else x + 5 
with tf.Session() as sess: 
    x = tf.Variable(tf.random_normal([100, 1])) 
    output = tf.map_fn(my_fn, x) 

錯誤:

類型錯誤:使用tf.Tensor爲Python bool是不允許的。使用if t is not None:而不是if t:來測試張量是否已定義,並使用邏輯TensorFlow操作來測試張量的值。

回答

3

你應該看看tf.select

對於你的榜樣,你可以這樣做:

condition = tf.greater(x, 0) 
res = tf.select(condition, tf.square(x), x + 5) 
1

這裏的問題是,my_fn有檢查條件x>0沒有辦法,因爲xtf.Tensor,這意味着,這將是充滿值僅如果張量流會話已啓動,並且您請求運行包含x的圖的一部分。要在圖中包含if-then邏輯,你必須使用tensorflow提供的操作,例如tf.select

0

一個最簡單的方式實現這一目標使用tf.cond

res = tf.cond(tf.greater(x, 0), lambda: tf.square(x), lambda: x + 5)