2017-01-08 118 views
11

我想在Tensorflow中使用MomentumOptimizer。然而,由於這種優化使用一些內部變量,嘗試使用它沒有初始化這個變量會產生錯誤:如何初始化Tensorflow中的優化器變量?

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_2/Momentum

這可以通過初始化所有的變量很容易解決,例如:

tf.global_variables_initializer().run()

但是,我不想初始化全部的變量 - 只有那些Optimizer。有沒有辦法做到這一點?

回答

9

您可以按名稱篩選變量並僅初始化這些變量。 IE

momentum_initializers = [var.initializer for var in tf.global_variables() if 'Momentum' in var.name] 
sess.run(momentum_initializers) 
2

tf.variables_initializer似乎是初始化一組特定的變量的首選方式:

var_list = [var for var in tf.global_variables() if 'Momentum' in var.name] 
var_list_init = tf.variables_initializer(var_list) 
... 
sess = tf.Session() 
sess.run(var_list_init) 
+0

它應該是:var_list = [在tf.global_variables VAR爲VAR()如果var.name '勢頭'] –

+1

@MichaelPresečan固定的,謝謝! –

10

兩個電流的答案有點通過過濾使用「動力」字符串變量名工作。但是,這是非常脆弱的雙方:

  1. 它可以默默(重新)初始化一些其他變量,你實際上不想重置!無論是因爲名稱衝突,還是因爲您有更復雜的圖形並分別優化不同的零件,例如。
  2. 它只適用於一個特定的優化器,您如何知道名稱以尋找其他人?
  3. 獎勵:張量流的更新可能是默默打破你的代碼。

幸運的是,tensorflow的抽象Optimizer類有一個機制,這些額外的優化變量被稱爲"slots",並且可以使用get_slot_names()方法得到一個優化的所有插槽名稱:

opt = tf.train.MomentumOptimizer(...) 
print(opt.get_slot_names()) 
# prints ['momentum'] 

opt.get_slot(some_var, 'momentum') 
:可以使用 get_slot(var, slot_name)方法獲得對應於該時隙中的變量爲特定(可訓練)可變

把所有這些組合起來,你可以創建初始化優化的狀態如下運算:

var_list = # list of vars to optimize, e.g. 
      # tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 
opt = tf.train.MomentumOptimizer(0.1, 0.95) 
step_op = opt.minimize(loss, var_list=var_list) 
reset_opt_op = tf.variables_initializer([opt.get_slot(var, name) for name in opt.get_slot_names() for var in var_list]) 

這真的只有重新設置正確的變量,並橫跨優化穩健。

不包括一個unfortunate caveatAdamOptimizer。那個人也會保留一個被叫的頻率。這意味着你應該真的認真思考你在這裏做什麼,但爲了完整起見,你可以得到其額外的狀態,如opt._get_beta_accumulators()。返回的列表應添加到上面的reset_opt_op行中的列表中。

2

大廈關閉LucasB的約AdamOptimizer答案,這個功能需要一個AdamOptimizer實例adam_opt有其Variables創建(這兩個一個叫:adam_opt.minimize(loss, var_list=var_list)adam_opt.apply_gradients(zip(grads, var_list))該函數創建一個Op,調用它時,重新初始化優化的。變量傳遞的變量,以及全球計數狀態

def adam_variables_initializer(adam_opt, var_list): 
    adam_vars = [adam_opt.get_slot(var, name) 
       for name in adam_opt.get_slot_names() 
       for var in var_list if var is not None] 
    adam_vars.extend(list(adam_opt._get_beta_accumulators())) 
    return tf.variables_initializer(adam_vars) 

如:

opt = tf.train.AdamOptimizer(learning_rate=1e-4) 
fit_op = opt.minimize(loss, var_list=var_list) 
reset_opt_vars = adam_variables_initializer(opt, var_list) 
+1

在我的情況下,adam_vars列表可能包含無類型的變量,不知道是否有一個優雅的方式來解決它...目前我只是過濾它們全部 –

+0

@TamakiSakura嗯哪些?我用列表理解中的過濾器更新了答案 – eqzx

+0

''[adam_opt.get_slot(var,name)for name in adam_opt.get_slot_names()for var in var_list]'''part,我確定我的var_list不包含沒有。我目前做的事情非常難看:'''在調用''''''''''''''''''''''''''''''''''''''''''''''''之前,'''adam_vars = filter(lambda x:x不是None,adam_vars) –

0

要解決該問題無剛做:

self.opt_vars = [opt.get_slot(var, name) for name in opt.get_slot_names() 
        for var in self.vars_to_train 
        if opt.get_slot(var, name) is not None]