2017-07-19 37 views
2

我使用的是Tensorflow v1.1,我一直在試圖弄清楚如何使用我的EMA'ed權重進行推理,但不管我做什麼,我總是得到錯誤在檢查點找不到密鑰<變量名> Tensorflow

找不到:重點W/ExponentialMovingAverage檢查點未發現

即使當我循環並打印出所有的tf.global_variables主要存在

這是一個可重複腳本重調整從Facenet's單元測試:

import tensorflow as tf 
import numpy as np 


tf.reset_default_graph() 

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3 
x_data = np.random.rand(100).astype(np.float32) 
y_data = x_data * 0.1 + 0.3 

# Try to find values for W and b that compute y_data = W * x_data + b 
# (We know that W should be 0.1 and b 0.3, but TensorFlow will 
# figure that out for us.) 
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W') 
b = tf.Variable(tf.zeros([1]), name='b') 
y = W * x_data + b 

# Minimize the mean squared errors. 
loss = tf.reduce_mean(tf.square(y - y_data)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
opt_op = optimizer.minimize(loss) 

# Track the moving averages of all trainable variables. 
ema = tf.train.ExponentialMovingAverage(decay=0.9999) 
variables = tf.trainable_variables() 
print(variables) 
averages_op = ema.apply(tf.trainable_variables()) 
with tf.control_dependencies([opt_op]): 
    train_op = tf.group(averages_op) 

# Before starting, initialize the variables. We will 'run' this first. 
init = tf.global_variables_initializer() 

saver = tf.train.Saver(tf.trainable_variables()) 

# Launch the graph. 
sess = tf.Session() 
sess.run(init) 

# Fit the line. 
for _ in range(201): 
    sess.run(train_op) 

w_reference = sess.run('W/ExponentialMovingAverage:0') 
b_reference = sess.run('b/ExponentialMovingAverage:0') 

saver.save(sess, os.path.join("model_ex1")) 

tf.reset_default_graph() 

tf.train.import_meta_graph("model_ex1.meta") 
sess = tf.Session() 

print('------------------------------------------------------') 
for var in tf.global_variables(): 
    print('all variables: ' + var.op.name) 
for var in tf.trainable_variables(): 
    print('normal variable: ' + var.op.name) 
for var in tf.moving_average_variables(): 
    print('ema variable: ' + var.op.name) 
print('------------------------------------------------------') 

mode = 1 
restore_vars = {} 
if mode == 0: 
    ema = tf.train.ExponentialMovingAverage(1.0) 
    for var in tf.trainable_variables(): 
     print('%s: %s' % (ema.average_name(var), var.op.name)) 
     restore_vars[ema.average_name(var)] = var 
elif mode == 1: 
    for var in tf.trainable_variables(): 
     ema_name = var.op.name + '/ExponentialMovingAverage' 
     print('%s: %s' % (ema_name, var.op.name)) 
     restore_vars[ema_name] = var 

saver = tf.train.Saver(restore_vars, name='ema_restore') 

saver.restore(sess, os.path.join("model_ex1")) # error happens here! 

w_restored = sess.run('W:0') 
b_restored = sess.run('b:0') 

print(w_reference) 
print(w_restored) 
print(b_reference) 
print(b_restored) 

回答

2

key not found in checkpoint錯誤意味着該變量在內存模型,但不是在磁盤上的序列化的檢查點文件存在。

您應該使用inspect_checkpoint tool來了解什麼張量被保存在檢查點中,以及爲什麼某些指數移動平均值不會在此處保存。

這不是從線應該引發錯誤

+0

嗨,謝謝,我一定會看看,花了,也更新了我的問題! – YellowPillow

+0

我想我明白你的錯誤可能來自哪裏。您只用可訓練變量初始化保存程序。嘗試使用默認構建的保護程序。移動平均變量不可訓練,所以不會在您的檢查點結束。 –

+0

默認構建的保護程序是什麼意思? – YellowPillow

0

我想補充到最好使用檢查點訓練有素的變量的方法您的攝製例子清楚。

請記住,保存程序var_list中的所有變量都應包含在您配置的檢查點中。您可以通過檢查那些在保護程序:

print(restore_vars) 

,並在檢查站這些變量是:在你的情況

vars_in_checkpoint = tf.train.list_variables(os.path.join("model_ex1")) 

如果restore_vars都包括在vars_in_checkpoint那麼它不會引發錯誤,否則首先初始化所有的變量:

all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) 
sess.run(tf.variables_initializer(all_variables)) 

的所有變量將被初始化是這樣的,或者未在檢查點,那麼你就可以篩選出不包含在該檢查站restore_vars這些變量(假設與ExponentialMovingAverage所有變量在他們的名字沒有在檢查站):

temp_saver = tf.train.Saver(
    var_list=[v for v in all_variables if "ExponentialMovingAverage" not in v.name]) 
ckpt_state = tf.train.get_checkpoint_state(os.path.join("model_ex1"), lastest_filename) 
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path) 
temp_saver.restore(sess, ckpt_state.model_checkpoint_path) 

這可以節省相比,訓練模型一段時間 從頭開始​​。 (在我的情況下,恢復的變量與一開始從頭開始的培訓相比沒有明顯的改善,因爲所有舊的優化器變量都被放棄了,但它可以顯着加速優化過程,我認爲,因爲它就像是預訓練一些變量)

無論如何,一些變量是有用的恢復像嵌入和一些圖層等

相關問題