2017-04-26

Answer


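This is possible, but it may be somewhat fragile. In particular, the py_funcs need to be redefined in the same order in which they were defined in the original graph, so that they get the same identifiers in the FuncRegistry.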
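To see why the order matters, here is a simplified pure-Python model of the registry. It is not TensorFlow's code: `ToyFuncRegistry`, `double` and `negate` are hypothetical, and the only detail borrowed from TensorFlow is the `pyfunc_<n>` token format that appears in the error message further down. What it models is that the saved graph stores only a token, while the mapping from token to Python function lives in the running process and is rebuilt from a counter that starts at zero.

```python
import itertools


class ToyFuncRegistry:
    """Hypothetical, simplified stand-in for TensorFlow's FuncRegistry.

    Each registered callable gets a token "pyfunc_<n>" from a per-process
    counter. A saved graph would store only the token, never the function.
    """

    def __init__(self):
        self._funcs = {}
        self._counter = itertools.count()

    def insert(self, func):
        token = "pyfunc_%d" % next(self._counter)
        self._funcs[token] = func
        return token

    def call(self, token, *args):
        # Raises KeyError if nothing was registered under this token
        # in the current process.
        return self._funcs[token](*args)


def double(x):
    return 2 * x


def negate(x):
    return -x


# "Training" process: the graph only remembers the tokens.
train_registry = ToyFuncRegistry()
saved_graph = {"double_op": train_registry.insert(double),
               "negate_op": train_registry.insert(negate)}

# New process, functions re-registered in the SAME order: tokens line up.
same_order = ToyFuncRegistry()
same_order.insert(double)
same_order.insert(negate)

# New process, functions re-registered in the WRONG order: tokens are swapped.
wrong_order = ToyFuncRegistry()
wrong_order.insert(negate)
wrong_order.insert(double)

print(same_order.call(saved_graph["double_op"], 5))   # 10
print(wrong_order.call(saved_graph["double_op"], 5))  # -5 (wrong function!)

# New process, nothing re-registered at all.
try:
    ToyFuncRegistry().call(saved_graph["double_op"], 5)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'pyfunc_0'
```

The wrong-order case is the dangerous one: if the re-created functions happen to have compatible signatures, nothing fails loudly and the graph silently calls the wrong Python code.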

An example. We can define a graph that includes a py_func:

import tensorflow as tf

def my_py_func(x):
    return 13. * x + 2.

def train_model():
    with tf.Graph().as_default():
        some_input = tf.constant([[1., 2., 3., 4.],
                                  [5., 6., 7., 8.]])
        # The graph stores only a registry token for my_py_func, not its code.
        after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
                                   name="my_py_func")
        coefficient = tf.get_variable(
            "coefficient",
            shape=[])
        bias = tf.get_variable(
            "bias",
            shape=[])
        loss = tf.reduce_sum((coefficient * some_input + bias - after_py_func) ** 2)
        global_step = tf.contrib.framework.get_or_create_global_step()
        train_op = tf.group(tf.train.AdamOptimizer(0.1).minimize(loss),
                            tf.assign_add(global_step, 1))
        # Make it easy to retrieve things we care about when the metagraph is reloaded.
        tf.add_to_collection('useful_ops', bias)
        tf.add_to_collection('useful_ops', coefficient)
        tf.add_to_collection('useful_ops', loss)
        tf.add_to_collection('useful_ops', train_op)
        tf.add_to_collection('useful_ops', global_step)
        tf.add_to_collection('useful_ops', some_input)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()
        with tf.Session() as session:
            session.run(init_op)
            for i in range(5000):
                (_, evaled_loss, evaled_coefficient, evaled_bias,
                 evaled_global_step) = session.run(
                     [train_op, loss, coefficient, bias, global_step])
                if i % 1000 == 0:
                    print(evaled_global_step, evaled_loss, evaled_coefficient,
                          evaled_bias)
            # Saved once after the loop, producing ./trained_pyfunc_model-5000.*
            saver.save(session, "./trained_pyfunc_model", global_step=global_step)

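This does some basic training, fitting the linear function found in the py_func. The printed columns are the global step, the loss, the coefficient and the bias: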

1 37350.4 -0.0934748 0.193026 
1001 19.2717 12.3749 5.40368 
2001 0.108373 12.9532 2.2548 
3001 8.28227e-06 12.9996 2.00222 
4001 3.77258e-09 13.0 2.00004 
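Because the loss is a plain sum of squared errors, its minimum is the ordinary least-squares fit of the targets produced by `my_py_func`. As a quick sanity check of the numbers above, here is a closed-form fit in pure Python (no TensorFlow); `fit_line` is a hypothetical helper, and the inputs are the flattened values of `some_input`.

```python
def fit_line(xs, ys):
    """Ordinary least-squares fit of ys ~ coefficient * xs + bias."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    coefficient = sxy / sxx
    bias = mean_y - coefficient * mean_x
    return coefficient, bias


def my_py_func(x):
    # Same formula as in the answer, applied to a single Python float.
    return 13. * x + 2.


# Flattened values of `some_input` from the training graph.
inputs = [1., 2., 3., 4., 5., 6., 7., 8.]
targets = [my_py_func(x) for x in inputs]

coefficient, bias = fit_line(inputs, targets)
print(coefficient, bias)  # 13.0 2.0
```

The exact minimiser is coefficient 13 and bias 2 with zero loss, which is what the Adam run converges towards; the small nonzero losses in the log come from float32 arithmetic and the optimizer not landing exactly on the minimum.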

If we then try to load the metagraph in a new Python session without redefining the py_func, we get an error:

def load_model():
    with tf.Graph().as_default():
        saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta")
        bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops')
        # The py_func is deliberately NOT re-created here:
        #after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
        #                           name="my_py_func")
        with tf.Session() as session:
            saver.restore(session, "./trained_pyfunc_model-5000")
            (_, evaled_loss, evaled_coefficient, evaled_bias,
             evaled_global_step) = session.run(
                 [train_op, loss, coefficient, bias, global_step])
            print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias)

UnknownError (see above for traceback): KeyError: 'pyfunc_0'

However, as long as the py_funcs are defined in the same order and with the same implementations, we should be fine:

def load_model():
    with tf.Graph().as_default():
        saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta")
        bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops')
        # Re-creating the py_func registers my_py_func in this process. As the
        # first py_func created here, it gets the same token ('pyfunc_0') that
        # the imported graph refers to. The new op itself is never run.
        after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
                                   name="my_py_func")
        with tf.Session() as session:
            saver.restore(session, "./trained_pyfunc_model-5000")
            (_, evaled_loss, evaled_coefficient, evaled_bias,
             evaled_global_step) = session.run(
                 [train_op, loss, coefficient, bias, global_step])
            print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias)

This lets us continue training, or do anything else we want with the restored model:

Restored: 5001 1.77897e-09 13.0 2.00003 
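If you are unsure whether the re-created py_funcs line up, you can list the tokens that the py_func ops in the graph refer to. This is a sketch assuming TensorFlow 1.x, where `tf.py_func` creates ops of type `PyFunc` (or `PyFuncStateless` for `stateful=False`) carrying a string attr named `token`; `pyfunc_tokens` is a hypothetical helper name. It only touches `Graph.get_operations()`, `Operation.name`, `Operation.type` and `Operation.get_attr()`, so it does not need to import TensorFlow itself.

```python
def pyfunc_tokens(graph):
    """Return (op name, registry token) pairs for every py_func op in `graph`.

    Assumes TensorFlow 1.x: py_func ops have type "PyFunc" or
    "PyFuncStateless" and store their registry key in a "token" attr.
    """
    result = []
    for op in graph.get_operations():
        if op.type in ("PyFunc", "PyFuncStateless"):
            token = op.get_attr("token")
            # String attrs are returned as bytes under Python 3.
            if isinstance(token, bytes):
                token = token.decode("utf-8")
            result.append((op.name, token))
    return result
```

Calling `print(pyfunc_tokens(tf.get_default_graph()))` inside `load_model`, right after the `tf.py_func` call, should show the imported `my_py_func` op and the newly created one referring to the same token. If they differ, the functions were re-created in a different order from the original graph.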

Note that stateful py_funcs will be hard to deal with: TensorFlow does not save any Python variables that may be associated with them!
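One possible workaround, which is not something TensorFlow does for you, is to keep that Python state in an object and write it to your own file next to the checkpoint whenever you call `saver.save`, then read it back after `saver.restore`. The sketch below only simulates this in pure Python; `CallCounter` and the `.pystate.json` suffix are hypothetical.

```python
import json
import os
import tempfile


class CallCounter:
    """Hypothetical Python-side state used by a stateful py_func.

    A TensorFlow checkpoint does not include this object, so it is saved to
    and restored from a separate JSON file next to the checkpoint prefix.
    """

    def __init__(self):
        self.calls = 0

    def __call__(self, x):
        # This is the callable you would pass to tf.py_func; it mutates state.
        self.calls += 1
        return x

    def save(self, checkpoint_prefix):
        with open(checkpoint_prefix + ".pystate.json", "w") as f:
            json.dump({"calls": self.calls}, f)

    def restore(self, checkpoint_prefix):
        with open(checkpoint_prefix + ".pystate.json") as f:
            self.calls = json.load(f)["calls"]


# Simulate one "session" that uses the function and saves its state...
prefix = os.path.join(tempfile.mkdtemp(), "trained_pyfunc_model-5000")
counter = CallCounter()
for value in range(3):
    counter(value)
counter.save(prefix)

# ...and a later session that restores it.
restored = CallCounter()
restored.restore(prefix)
print(restored.calls)  # 3
```

Where the state can be expressed as tensors, keeping it in a `tf.Variable` instead is usually simpler, since `tf.train.Saver` then checkpoints it along with the rest of the model.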