2017-01-15 39 views
0

在線文檔中說,moving_average和moving_variance都是model_variables,而tf.model_variables()返回local_variables類型的張量。這意味着,當我保存我的狀態時,model_variables不會被保存嗎?Tensorflow的batch_norm中的模型變量

我正在嘗試將批量歸一化應用於幾個3D卷積和完全連接的圖層。我用batch_norm訓練了我的網絡並保存了一個檢查點文件,但是當我去恢復我保存的狀態時,它說move_mean找不到。確切的錯誤是,當TF將恢復的值分配給moving_mean時,lhs張量的形狀[]不能與rhs的形狀一致,[20]。

當我不在圖層周圍添加batch_norm時,圖恢復正常。 我打算在訓練結束時添加一個全局變量,以節省我的move_mean和moving_variance值。這是TF爲我準備使用batch_norm的方式嗎?

謝謝!

回答

1

變量moving_mean和moving_variance不在我保存的聲明中,因爲我已將updates_collections設置爲默認值。由於我在運行圖層時從未包含控件依賴項,因此這些變量從未更新過。

的代碼,包括是:

from tensorflow.python import control_flow_ops 

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 
if update_ops: 
    updates = tf.tuple(update_ops) 
    total_loss = control_flow_ops.with_dependencies(updates, total_loss) 

或者設置

updates_collection=None 

爲就地更新。

有關更多信息,請參見the API descriptioncurrent github discussion