2017-10-15 21 views
1

如果你有一些昂貴的操作的條件,你可能想要懶惰的行爲,即只評估選擇的分支。在Tensorflow中做懶惰條件

下面的作品,是懶惰:

>>> a. tf.zeros(0) 
>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.argmax(a)).eval() 
-1 

你可以看到,這是懶惰的,因爲argmax沒有評估,因爲它會導致錯誤。因爲argmax取的張量是空的。如果移動argmax了拉姆達的,它產生這種非常錯誤:

>>> am = tf.argmax(a) 
>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.add(am, 1)).eval() 
... Reduction axis 0 is empty in shape [0] 

哪一個不是由tf.add操作引起的。移動它內聯,它再次運作:

>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.add(tf.argmax(a), 1)).eval() 
-1 

然後,問題是如何做一個更清潔的方式做懶惰條件?

回答

1

當條件函數變長時,上述方法會變得有點混亂。你可以做的是在條件之外定義一個lambda表達式。 請注意,以下內容在Python交互式REPL中不起作用,其結果爲ValueError: Operation 'cond_14/Merge' has been marked as not fetchable.

它可以工作當你把代碼放入python文件並以正常方式運行時。

import tensorflow as tf 

sess = tf.InteractiveSession() 

a = tf.zeros(0) 
fn = lambda: tf.argmax(a) 

res = tf.cond(
    tf.equal(tf.size(a), tf.constant(0)), 
    lambda: tf.constant(-1, dtype=tf.int64), 
    fn 
    ).eval() 
print(res) 

res2 = tf.cond(
    tf.equal(tf.size(a), tf.constant(0)), 
    lambda: tf.constant(-1, dtype=tf.int64), 
    lambda: tf.add(fn(), tf.constant(1, dtype=tf.int64)) 
    ).eval() 
print(res2) 
# Output: 
# -1 
# -1