兩個電流的答案有點通過過濾使用「動力」字符串變量名工作。但是,這是非常脆弱的雙方:
- 它可以默默(重新)初始化一些其他變量,你實際上不想重置!無論是因爲名稱衝突,還是因爲您有更復雜的圖形並分別優化不同的零件,例如。
- 它只適用於一個特定的優化器,您如何知道名稱以尋找其他人?
- 獎勵:張量流的更新可能是默默打破你的代碼。
幸運的是,tensorflow的抽象Optimizer
類有一個機制,這些額外的優化變量被稱爲"slots",並且可以使用get_slot_names()
方法得到一個優化的所有插槽名稱:
opt = tf.train.MomentumOptimizer(...)
print(opt.get_slot_names())
# prints ['momentum']
你
opt.get_slot(some_var, 'momentum')
:可以使用
get_slot(var, slot_name)
方法獲得對應於該時隙中的變量爲特定(可訓練)可變
把所有這些組合起來,你可以創建初始化優化的狀態如下運算:
var_list = # list of vars to optimize, e.g.
# tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
opt = tf.train.MomentumOptimizer(0.1, 0.95)
step_op = opt.minimize(loss, var_list=var_list)
reset_opt_op = tf.variables_initializer([opt.get_slot(var, name) for name in opt.get_slot_names() for var in var_list])
這真的只有重新設置正確的變量,並橫跨優化穩健。
不包括一個unfortunate caveat:AdamOptimizer
。那個人也會保留一個被叫的頻率。這意味着你應該真的認真思考你在這裏做什麼,但爲了完整起見,你可以得到其額外的狀態,如opt._get_beta_accumulators()
。返回的列表應添加到上面的reset_opt_op
行中的列表中。
它應該是:var_list = [在tf.global_variables VAR爲VAR()如果var.name '勢頭'] –
@MichaelPresečan固定的,謝謝! –