我使用的是Tensorflow v1.1,我一直在試圖弄清楚如何使用我的EMA'ed權重進行推理,但不管我做什麼,我總是得到錯誤在檢查點找不到密鑰<變量名> Tensorflow
找不到:重點W/ExponentialMovingAverage檢查點未發現
即使當我循環並打印出所有的tf.global_variables
主要存在
這是一個可重複腳本重調整從Facenet's單元測試:
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
# Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b
# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
opt_op = optimizer.minimize(loss)
# Track the moving averages of all trainable variables.
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
variables = tf.trainable_variables()
print(variables)
averages_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([opt_op]):
train_op = tf.group(averages_op)
# Before starting, initialize the variables. We will 'run' this first.
init = tf.global_variables_initializer()
saver = tf.train.Saver(tf.trainable_variables())
# Launch the graph.
sess = tf.Session()
sess.run(init)
# Fit the line.
for _ in range(201):
sess.run(train_op)
w_reference = sess.run('W/ExponentialMovingAverage:0')
b_reference = sess.run('b/ExponentialMovingAverage:0')
saver.save(sess, os.path.join("model_ex1"))
tf.reset_default_graph()
tf.train.import_meta_graph("model_ex1.meta")
sess = tf.Session()
print('------------------------------------------------------')
for var in tf.global_variables():
print('all variables: ' + var.op.name)
for var in tf.trainable_variables():
print('normal variable: ' + var.op.name)
for var in tf.moving_average_variables():
print('ema variable: ' + var.op.name)
print('------------------------------------------------------')
mode = 1
restore_vars = {}
if mode == 0:
ema = tf.train.ExponentialMovingAverage(1.0)
for var in tf.trainable_variables():
print('%s: %s' % (ema.average_name(var), var.op.name))
restore_vars[ema.average_name(var)] = var
elif mode == 1:
for var in tf.trainable_variables():
ema_name = var.op.name + '/ExponentialMovingAverage'
print('%s: %s' % (ema_name, var.op.name))
restore_vars[ema_name] = var
saver = tf.train.Saver(restore_vars, name='ema_restore')
saver.restore(sess, os.path.join("model_ex1")) # error happens here!
w_restored = sess.run('W:0')
b_restored = sess.run('b:0')
print(w_reference)
print(w_restored)
print(b_reference)
print(b_restored)
嗨,謝謝,我一定會看看,花了,也更新了我的問題! – YellowPillow
我想我明白你的錯誤可能來自哪裏。您只用可訓練變量初始化保存程序。嘗試使用默認構建的保護程序。移動平均變量不可訓練,所以不會在您的檢查點結束。 –
默認構建的保護程序是什麼意思? – YellowPillow