I am trying to save a TensorFlow model and reuse it later, but loading the saved model produces a very high cost value.
To pin the problem down, I created a binary dataset of 10 elements and train repeatedly on those same 10 elements, saving the model every 100 iterations. I then run a test on the same set. Ideally I would expect the test to reproduce the cost that was reported when the model was saved. However, I must be missing something, because loading the model saved at iteration 300 does not give the expected cost value. The training run prints:
Step 0, cost 1.10902
Step 100, cost 0.83170
Step 200, cost 0.00003
Step 300, cost 0.00000
Now the code:
import tensorflow as tf

def model(X, w1, w2, w3, w4, wo, p_keep_conv, p_keep_hidden):
    l1 = tf.nn.relu(tf.nn.conv2d(X, w1, strides=[1, 1, 1, 1], padding='SAME'))
    l1 = tf.nn.max_pool(l1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    l1 = tf.nn.dropout(l1, p_keep_conv)
    # ... other layer def.s
    l4 = tf.nn.relu(tf.matmul(l3, w4))
    l4 = tf.nn.dropout(l4, p_keep_hidden)
    return tf.matmul(l4, wo, name="pyx")

X = tf.placeholder("float", [None, size1, size2, size3], name="X")
Y = tf.placeholder("float", [None, 1], name="Y")
py_x = model(X, w1, w2, w3, w4, wo, p_keep_conv, p_keep_hidden)
cost = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
    logits=py_x, targets=Y, pos_weight=POS_WEIGHT))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)

with tf.Session() as sess:
    batch_x, batch_y = read_file('train.dat', 10)
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables())
    for step in range(NUM_TRAIN_BATCHES):
        x, y = sess.run([batch_x, batch_y])
        _, cost_value = sess.run([train_op, cost],
                                 feed_dict={X: x, Y: y, p_keep_conv: 0.8, p_keep_hidden: 0.5})
        if step % 100 == 0:
            print("Step %d, cost %1.5f" % (step, cost_value))
            saver.save(sess, './train.model', global_step=step)
The code above produces the printout shown earlier. I then try to evaluate the saved model on the same data:
model_no = 300
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./train.model-%d.meta' % (model_no))
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    batch_x, batch_y = read_file('train.dat', 10)
    sess.run(tf.global_variables_initializer())
    x, y = sess.run([batch_x, batch_y])
    cost_value = sess.run(cost, feed_dict={"X:0": x, "Y:0": y, p_keep_conv: 0.8, p_keep_hidden: 0.5})
    print("cost %1.5f" % (cost_value))
which prints:
cost 1.10895
This is very close to the cost at the very first iteration of training.
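For comparison, here is a minimal restore sketch that skips the re-initialization step and looks every tensor up by name in the imported graph. It assumes the dropout placeholders were created with names (p_keep_conv, p_keep_hidden) and that the cost op was given name="cost", neither of which is shown in the abridged code above, and it reuses the read_file helper from the question:

import tensorflow as tf

model_no = 300  # checkpoint step to restore

with tf.Session() as sess:
    # Rebuild the graph from the meta file and load the matching weights.
    saver = tf.train.import_meta_graph('./train.model-%d.meta' % model_no)
    saver.restore(sess, './train.model-%d' % model_no)
    # Do NOT run tf.global_variables_initializer() here: it would overwrite
    # the restored weights with fresh random values.

    graph = tf.get_default_graph()
    # Look tensors up by the names given at construction time.
    # The placeholder and cost names below are assumptions; adjust them to
    # whatever names your graph actually uses.
    X = graph.get_tensor_by_name("X:0")
    Y = graph.get_tensor_by_name("Y:0")
    p_keep_conv = graph.get_tensor_by_name("p_keep_conv:0")
    p_keep_hidden = graph.get_tensor_by_name("p_keep_hidden:0")
    cost = graph.get_tensor_by_name("cost:0")

    batch_x, batch_y = read_file('train.dat', 10)
    x, y = sess.run([batch_x, batch_y])
    # Dropout is usually disabled (keep probability 1.0) at evaluation time.
    cost_value = sess.run(cost, feed_dict={X: x, Y: y,
                                           p_keep_conv: 1.0,
                                           p_keep_hidden: 1.0})
    print("cost %1.5f" % cost_value)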
One more thing I cannot wrap my head around is the checkpoint file, which only contains the following:
model_checkpoint_path: "train.model-300"
all_model_checkpoint_paths: "train.model-0"
all_model_checkpoint_paths: "train.model-100"
all_model_checkpoint_paths: "train.model-200"
all_model_checkpoint_paths: "train.model-300"
How does this help, and what is the idea behind calling saver.restore(sess, tf.train.latest_checkpoint('./')) if the checkpoint file only contains paths to model files and I am explicitly loading a specific model?
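For what it is worth, the checkpoint file is just a small index that tf.train.latest_checkpoint() parses; a quick sketch of how those pieces fit together (the printed values are illustrative):

import tensorflow as tf

# tf.train.latest_checkpoint() reads the "checkpoint" index file and returns
# the prefix stored in model_checkpoint_path, e.g. "./train.model-300".
print(tf.train.latest_checkpoint('./'))

# The same information is exposed by get_checkpoint_state().
state = tf.train.get_checkpoint_state('./')
print(state.model_checkpoint_path)        # most recent checkpoint prefix
print(state.all_model_checkpoint_paths)   # every checkpoint still retained

# To load a specific step instead of the latest one, pass that prefix
# directly to Saver.restore(), e.g.:
# saver.restore(sess, './train.model-100')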
On a side note: for testing, you usually want to set the keep probability of all dropout layers to 1. – aseipel
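(One way to make that the default behavior, as a sketch; the placeholder names are illustrative, not taken from the original code:)

# A common pattern: give the keep probabilities a default of 1.0, so that
# evaluation code can simply omit them from the feed_dict and dropout
# becomes a no-op at test time.
p_keep_conv = tf.placeholder_with_default(1.0, shape=(), name="p_keep_conv")
p_keep_hidden = tf.placeholder_with_default(1.0, shape=(), name="p_keep_hidden")

# Training still overrides the defaults explicitly:
# sess.run(train_op, feed_dict={X: x, Y: y, p_keep_conv: 0.8, p_keep_hidden: 0.5})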
OK, I have moved the sess.run(tf.global_variables_initializer()) line to before the saver = ... line and the cost is still 1.10919 –
Nothing obvious, apart from the lack of random sampling during training. Are you sure your data is properly randomized? – etarion