3
我想用tf.train.Saver(),使張量檢查點,這裏是我的代碼片段:如何在Tensorflow的檢查點保存張量?
import tensorflow as tf
with tf.Graph().as_default():
var = tf.Variable(tf.zeros([10]), name="biases")
temp = tf.add(var, 0.1)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver({'w':temp})
with tf.Session() as sess:
sess.run(init_op)
print(sess.run(temp))
,但得到了一個錯誤如下:
Traceback (most recent call last):
File "./test_counter.py", line 61, in <module>
saver = tf.train.Saver({'w':temp})
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1043, in __init__
self.build()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1073, in build
restore_sequentially=self._restore_sequentially)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 649, in build
saveables = self._ValidateAndSliceInputs(names_to_saveables)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 578, in _ValidateAndSliceInputs
variable)
TypeError: names_to_saveables must be a dict mapping string names to Tensors/Variables. Not a variable: Tensor("TransformFeatureToIndex:0", shape=(100,), dtype=string)
我想到的一種方式是通過sess.run(temp)存儲客戶端的張量並保存,但有沒有更重要的方法?
謝謝,tf.assign的作品!對於我的真實情況,我使用了比tf.add複雜得多的自定義操作,因此評估'temp'非常昂貴,tf.assign很適合我的情況。 –