2017-01-11 71 views
3

我想用tf.train.Saver(),使張量檢查點,這裏是我的代碼片段:如何在Tensorflow的檢查點保存張量?

import tensorflow as tf 

with tf.Graph().as_default(): 
    var = tf.Variable(tf.zeros([10]), name="biases") 
    temp = tf.add(var, 0.1) 
    init_op = tf.global_variables_initializer() 

    saver = tf.train.Saver({'w':temp}) 

    with tf.Session() as sess: 
     sess.run(init_op) 
     print(sess.run(temp)) 

,但得到了一個錯誤如下:

Traceback (most recent call last): 
    File "./test_counter.py", line 61, in <module> 
    saver = tf.train.Saver({'w':temp}) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1043, in __init__ 
    self.build() 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1073, in build 
    restore_sequentially=self._restore_sequentially) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 649, in build 
    saveables = self._ValidateAndSliceInputs(names_to_saveables) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 578, in _ValidateAndSliceInputs 
    variable) 
    TypeError: names_to_saveables must be a dict mapping string names to Tensors/Variables. Not a variable: Tensor("TransformFeatureToIndex:0", shape=(100,), dtype=string) 

我想到的一種方式是通過sess.run(temp)存儲客戶端的張量並保存,但有沒有更重要的方法?

回答

4

temp不是tf.Variable,而是一個操作。它「沒有」值或狀​​態,它只是圖中的一個節點。如果要明確保存添加到var的結果,則可以通過tf.assigntemp分配給另一個變量,並保存該其他變量。更簡單的方法可能是節省var(或整個會話),並在恢復之後再次評估temp

+0

謝謝,tf.assign的作品!對於我的真實情況,我使用了比tf.add複雜得多的自定義操作,因此評估'temp'非常昂貴,tf.assign很適合我的情況。 –

相關問題