2017-01-28 76 views
0

我工作在網絡上與4類特定的自動編碼(3層前饋),以及訓練迭代中,存在這樣一種情況檢查來決定,它的自動編碼必須更新:TensorFlow:tf.case()f_default應該什麼都不做

def f(k): return tf.train.AdamOptimizer(learning_rate=lernrate).minimize(Cost_List[k]), n_List[k].assign_add(1.0), Cost_List[k] 

def g(): ??? 

nothing = g() 



min_index = tf.argmin(Cost_List, 0) 

Case_0 = (tf.equal(min_index,0), lambda: f(0)) 
Case_1 = (tf.equal(min_index,1), lambda: f(1)) 
Case_2 = (tf.equal(min_index,2), lambda: f(2)) 
Case_3 = (tf.equal(min_index,3), lambda: f(3)) 


Case_List = [Case_0, Case_1, Case_2, Case_3] 

[optimizer, update, cost] = tf.case(Case_List, nothing) 

在這種情況下,沒有條件滿足,什麼都不應該做。在這種情況下,四種情況中的一種會被實現,所以這裏還沒有實際問題,但我想修改代碼,然後重要的是,訓練樣本將在默認情況下跳過。問題是,f_default和所有其他返回類型的返回類型必須相同,因爲sess.run([optimizer,update,cost])需要某種類型。我該如何做到這一點,在默認情況下確實沒有任何事情發生?我已經嘗試使用tf.no_op(),但不工作...

感謝,

Meridius

回答

1

要簽名匹配,你可以定義g()如下:

def g(): 
    return tf.no_op(), tf.no_op(), tf.constant(0.0) 

請注意,直接通過g作爲f_default(而不是像當前代碼那樣傳遞g())會更有效率,但行爲應該相同。