Setting a py_func op after importing a GraphDef in TensorFlow

I have a saved graph definition which I import with tf.train.import_meta_graph. The graph contains a py_func op, which is not serializable. Can I define and attach a Python function to this op without rebuilding the graph from scratch?
This is possible, but may be a bit fragile. In particular, the py_funcs need to be re-defined in the same order they were defined in the original graph (so that they get the same identifiers in the FuncRegistry).
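To make the ordering constraint concrete, here is a small sketch (written against the same TF 1.x API; the lambdas and op names are just placeholders, not part of the original answer) that prints the token under which each py_func op was registered. The token is recorded on the serialized PyFunc node and is handed out in registration order, which is why the re-definition order has to match:

import tensorflow as tf

with tf.Graph().as_default() as g:
  x = tf.constant(1.0)
  # Each tf.py_func call registers its Python callable and records the
  # resulting token in the op's "token" attribute.
  first = tf.py_func(lambda v: v, [x], Tout=tf.float32, name="first_func")
  second = tf.py_func(lambda v: 2. * v, [x], Tout=tf.float32, name="second_func")
  for op in g.get_operations():
    if op.type == "PyFunc":
      # Prints something like: first_func pyfunc_0 / second_func pyfunc_1
      print(op.name, op.get_attr("token"))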
An example. We can define a graph which includes a py_func:
import tensorflow as tf

def my_py_func(x):
  return 13. * x + 2.

def train_model():
  with tf.Graph().as_default():
    some_input = tf.constant([[1., 2., 3., 4.],
                              [5., 6., 7., 8.]])
    after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
                               name="my_py_func")
    coefficient = tf.get_variable(
        "coefficient",
        shape=[])
    bias = tf.get_variable(
        "bias",
        shape=[])
    loss = tf.reduce_sum((coefficient * some_input + bias - after_py_func) ** 2)
    global_step = tf.contrib.framework.get_or_create_global_step()
    train_op = tf.group(tf.train.AdamOptimizer(0.1).minimize(loss),
                        tf.assign_add(global_step, 1))
    # Make it easy to retrieve things we care about when the metagraph is reloaded.
    tf.add_to_collection('useful_ops', bias)
    tf.add_to_collection('useful_ops', coefficient)
    tf.add_to_collection('useful_ops', loss)
    tf.add_to_collection('useful_ops', train_op)
    tf.add_to_collection('useful_ops', global_step)
    tf.add_to_collection('useful_ops', some_input)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as session:
      session.run(init_op)
      for i in range(5000):
        (_, evaled_loss, evaled_coefficient, evaled_bias,
         evaled_global_step) = session.run(
             [train_op, loss, coefficient, bias, global_step])
        if i % 1000 == 0:
          print(evaled_global_step, evaled_loss, evaled_coefficient,
                evaled_bias)
      saver.save(session, "./trained_pyfunc_model", global_step=global_step)
This does some basic training (fitting the linear function computed by the py_func):
1 37350.4 -0.0934748 0.193026
1001 19.2717 12.3749 5.40368
2001 0.108373 12.9532 2.2548
3001 8.28227e-06 12.9996 2.00222
4001 3.77258e-09 13.0 2.00004
If we then, in a new Python session, try to load the metagraph without re-defining the py_func, we get an error:
def load_model():
  with tf.Graph().as_default():
    saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta")
    bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops')
    #after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
    #                           name="my_py_func")
    with tf.Session() as session:
      saver.restore(session, "./trained_pyfunc_model-5000")
      (_, evaled_loss, evaled_coefficient, evaled_bias,
       evaled_global_step) = session.run(
           [train_op, loss, coefficient, bias, global_step])
      print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias)
UnknownError (see above for traceback): KeyError: 'pyfunc_0'
However, as long as the py_funcs are defined in the same order and with the same implementations, everything works:
def load_model():
  with tf.Graph().as_default():
    saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta")
    bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops')
    after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
                               name="my_py_func")
    with tf.Session() as session:
      saver.restore(session, "./trained_pyfunc_model-5000")
      (_, evaled_loss, evaled_coefficient, evaled_bias,
       evaled_global_step) = session.run(
           [train_op, loss, coefficient, bias, global_step])
      print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias)
This lets us continue training, or do anything else we want with the restored model:
Restored: 5001 1.77897e-09 13.0 2.00003
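For example, "continuing training" could look like the following sketch, placed inside load_model() right after the print above (the 100 extra steps and the re-use of the same checkpoint prefix are arbitrary choices, not part of the original answer):

      # Still inside load_model(): run a few more optimization steps against
      # the restored variables, then write a new checkpoint.
      for _ in range(100):
        session.run(train_op)
      print("After extra training:", session.run([global_step, loss, coefficient, bias]))
      saver.save(session, "./trained_pyfunc_model", global_step=global_step)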
Note that stateful py_funcs will be tricky to handle: TensorFlow does not save any Python variables that may be associated with them!
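As a hypothetical illustration of that caveat (the counter below is made up for the example): a py_func that closes over Python state keeps that state only in the current process, so nothing about it is restored from a checkpoint.

import numpy as np
import tensorflow as tf

# Python-side state attached to a py_func: this dict is never written to the
# checkpoint, so it starts over in a freshly loaded process.
call_count = {"n": 0}

def counting_func(x):
  call_count["n"] += 1  # lives only in this Python process
  return (x + call_count["n"]).astype(np.float32)

with tf.Graph().as_default():
  value = tf.constant([1., 2.], dtype=tf.float32)
  counted = tf.py_func(counting_func, [value], Tout=tf.float32,
                       name="counting_func")
  with tf.Session() as session:
    print(session.run(counted))  # computed with n == 1
    print(session.run(counted))  # computed with n == 2; a restored copy starts at 0 again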