2017-05-04 88 views
6

我正在使用Python API for Tensorflow。我想實現以下給出不使用Python的循環Rosenbrock function瞭解Tensorflow中的while循環

Rosenbrock function

我當前實現如下:

def rosenbrock(data_tensor): 
    columns = tf.unstack(data_tensor) 

    summation = 0 
    for i in range(1, len(columns) - 1): 
     first_term = tf.square(tf.subtract(columns[i + 1], tf.square(columns[i]))) 
     second_term = tf.square(tf.subtract(columns[i], 1.0)) 
     summation += tf.add(tf.multiply(100.0, first_term), second_term) 

    return summation 

我試圖實現在tf.while_loop()的總和;然而,我發現這個API在使用索引整數時有點不直觀,因爲索引整數意味着與數據保持獨立。在documentation給出的示例使用數據作爲索引(或反之亦然):

i = tf.constant(0) 
c = lambda i: tf.less(i, 10) 
b = lambda i: tf.add(i, 1) 
r = tf.while_loop(c, b, [i]) 
+0

僅僅使用for循環是否合適?使用while_loop有什麼好處?還是有必要? – lerner

+0

在上面的代碼中,for循環將執行python代碼。如果我們調用他的for循環「f」的主體,那麼你可以將python代碼視爲執行f,f,f,f,f,... f。所以它會將這個「主體」函數調用N次,並且函數的圖形將因此具有該函數N次。如果你使用tf.while_loop,那麼你只會在圖中看到該函數一次。 –

+0

tf.while_loop的優點是:1)可以並行運行迭代,2)可以在條件語句中使用運行時常量。例如,如果你想運行優化器直到滿足一定的容差,那麼你必須使用tf.while_loop變體,因爲python最初不能評估條件 –

回答

10

這可以通過使用可實現的tf.while_loop()和標準tuples按照在documentation第二示例。

def rosenbrock(data_tensor): 
    columns = tf.unstack(data_tensor) 

    # Track both the loop index and summation in a tuple in the form (index, summation) 
    index_summation = (tf.constant(1), tf.constant(0.0)) 

    # The loop condition, note the loop condition is 'i < n-1' 
    def condition(index, summation): 
     return tf.less(index, tf.subtract(tf.shape(columns)[0], 1)) 

    # The loop body, this will return a result tuple in the same form (index, summation) 
    def body(index, summation): 
     x_i = tf.gather(columns, index) 
     x_ip1 = tf.gather(columns, tf.add(index, 1)) 

     first_term = tf.square(tf.subtract(x_ip1, tf.square(x_i))) 
     second_term = tf.square(tf.subtract(x_i, 1.0)) 
     summand = tf.add(tf.multiply(100.0, first_term), second_term) 

     return tf.add(index, 1), tf.add(summation, summand) 

    # We do not care about the index value here, return only the summation 
    return tf.while_loop(condition, body, index_summation)[1] 

重要的是要注意,索引增量應該出現在循環體中,類似於標準while循環。在給出的解決方案中,它是由body()函數返回的元組中的第一項。

此外,循環條件函數必須爲總和分配一個參數,雖然它在此特定示例中未使用。