我想恢復 VGG_19 的模型參數,將它用作特徵提取器,並在其後附加新初始化的計算圖,然後在分佈式設置中訓練整個模型。如何使用 `MonitoredTrainingSession` / `Scaffold` 微調模型?
如果我使用 slim.learning.train,一切正常,但我無法讓它與 tf.train.MonitoredTrainingSession 所要求的 Scaffold 一起使用。
如果我把一個 restore_fn(按照文檔用 tf.contrib.framework.assign_from_checkpoint_fn 創建)作爲 init_fn 傳給 Scaffold,我會得到:
TypeError: callback() takes 1 positional argument but 2 were given
我試圖通過傳遞 lambda scaffold, sess: restore_fn(sess) 來「修復」它。
如果我嘗試創建一個恢復操作(用 tf.contrib.slim.assign_from_checkpoint 創建),並把它作爲 init_op 傳入,我會得到:
INFO:tensorflow:Create CheckpointSaverHook.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
267 self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 268 fetch, allow_tensor=True, allow_operation=True))
269 except TypeError as e:
/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operatio
n)
2608 if self._finalized:
-> 2609 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
2610
/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_
operation)
2700 raise TypeError("Can not convert a %s into a %s."
-> 2701 % (type(obj).__name__, types_str))
2702
TypeError: Can not convert a ndarray into a Tensor or Operation.
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
/ScanAvoidanceML/ScanAvoidanceML/datasets/project_daphnis/train.py in <module>()
129 )
130 FLAGS, unparsed = parser.parse_known_args()
--> 131 tf.app.run(main=train, argv=[sys.argv[0]] + unparsed)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/platform/app.py in run(main, argv)
46 # Call the main function, passing through any arguments
47 # to the final program.
---> 48 _sys.exit(main(_sys.argv[:1] + flags_passthrough))
49
50
/ScanAvoidanceML/ScanAvoidanceML/datasets/project_daphnis/train.py in train(_)
83 scaffold=tf.train.Scaffold(
84 init_op=restore_op,
---> 85 summary_op=tf.summary.merge_all())) as mon_sess:
86 while not mon_sess.should_stop():
87 # Run a training step asynchronously.
/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in MonitoredTrainingSession(master, is_chief,
checkpoint_dir, scaffold, hooks, chief_only_hooks, save_checkpoint_secs, save_summaries_steps, save_summaries_secs, config, stop_grac
e_period_secs, log_step_count_steps)
351 all_hooks.extend(hooks)
352 return MonitoredSession(session_creator=session_creator, hooks=all_hooks,
--> 353 stop_grace_period_secs=stop_grace_period_secs)
354
355
/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __init__(self, session_creator, hooks, stop
_grace_period_secs)
654 super(MonitoredSession, self).__init__(
655 session_creator, hooks, should_recover=True,
--> 656 stop_grace_period_secs=stop_grace_period_secs)
657
658
/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __init__(self, session_creator, hooks, shou
ld_recover, stop_grace_period_secs)
476 stop_grace_period_secs=stop_grace_period_secs)
477 if should_recover:
--> 478 self._sess = _RecoverableSession(self._coordinated_creator)
479 else:
480 self._sess = self._coordinated_creator.create_session()
/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __init__(self, sess_creator)
828 """
829 self._sess_creator = sess_creator
--> 830 _WrappedSession.__init__(self, self._create_session())
831
832 def _create_session(self):
/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in _create_session(self)
833 while True:
834 try:
--> 835 return self._sess_creator.create_session()
836 except _PREEMPTION_ERRORS as e:
837 logging.info('An error was raised while a session was being created. '
/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in create_session(self)
537 """Creates a coordinated session."""
538 # Keep the tf_sess for unit testing.
--> 539 self.tf_sess = self._session_creator.create_session()
540 # We don't want coordinator to suppress any exception.
541 self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in create_session(self)
411 init_op=self._scaffold.init_op,
412 init_feed_dict=self._scaffold.init_feed_dict,
--> 413 init_fn=self._scaffold.init_fn)
414
415
/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/session_manager.py in prepare_session(self, master, init_op, saver,
checkpoint_dir, checkpoint_filename_with_path, wait_for_checkpoint, max_wait_secs, config, init_feed_dict, init_fn)
277 "init_fn or local_init_op was given")
278 if init_op is not None:
--> 279 sess.run(init_op, feed_dict=init_feed_dict)
280 if init_fn:
281 init_fn(sess)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
894 try:
895 result = self._run(None, fetches, feed_dict, options_ptr,
--> 896 run_metadata_ptr)
897 if run_metadata:
898 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_met
adata)
1107 # Create a fetch handler to take care of the structure of fetches.
1108 fetch_handler = _FetchHandler(
-> 1109 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1110
1111 # Run request and get response.
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
409 """
410 with graph.as_default():
--> 411 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
412 self._fetches = []
413 self._targets = []
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
229 elif isinstance(fetch, (list, tuple)):
230 # NOTE(touts): This is also the code path for namedtuples.
--> 231 return _ListFetchMapper(fetch)
232 elif isinstance(fetch, dict):
233 return _DictFetchMapper(fetch)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
336 """
337 self._fetch_type = type(fetches)
--> 338 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
339 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
340
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in <listcomp>(.0)
336 """
337 self._fetch_type = type(fetches)
--> 338 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
339 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
340
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
231 return _ListFetchMapper(fetch)
232 elif isinstance(fetch, dict):
--> 233 return _DictFetchMapper(fetch)
234 else:
235 # Look for a handler in the registered expansions.
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
369 self._keys = fetches.keys()
370 self._mappers = [_FetchMapper.for_fetch(fetch)
--> 371 for fetch in fetches.values()]
372 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
373
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
237 if isinstance(fetch, tensor_type):
238 fetches, contraction_fn = fetch_fn(fetch)
--> 239 return _ElementFetchMapper(fetches, contraction_fn)
240 # Did not find anything.
241 raise TypeError('Fetch argument %r has invalid type %r' %
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
369 self._keys = fetches.keys()
370 self._mappers = [_FetchMapper.for_fetch(fetch)
--> 371 for fetch in fetches.values()]
372 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
373
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in <listcomp>(.0)
369 self._keys = fetches.keys()
370 self._mappers = [_FetchMapper.for_fetch(fetch)
--> 371 for fetch in fetches.values()]
372 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
373
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
237 if isinstance(fetch, tensor_type):
238 fetches, contraction_fn = fetch_fn(fetch)
--> 239 return _ElementFetchMapper(fetches, contraction_fn)
240 # Did not find anything.
241 raise TypeError('Fetch argument %r has invalid type %r' %
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
270 raise TypeError('Fetch argument %r has invalid type %r, '
271 'must be a string or Tensor. (%s)'
--> 272 % (fetch, type(fetch), str(e)))
273 except ValueError as e:
274 raise ValueError('Fetch argument %r cannot be interpreted as a '
TypeError: Fetch argument array([[[[ 0.39416704, -0.08419707, -0.03631314, ..., -0.10720515,
-0.03804016, 0.04690642],
[ 0.46418372, 0.03355668, 0.10245045, ..., -0.06945956,
-0.04020201, 0.04048637],
[ 0.34119523, 0.09563112, 0.0177449 , ..., -0.11436455,
-0.05099866, -0.00299793]],
[[ 0.37740308, -0.07876257, -0.04775979, ..., -0.11827433,
-0.19008617, -0.01889699],
[ 0.41810837, 0.05260524, 0.09755926, ..., -0.09385028,
-0.20492788, -0.0573062 ],
[ 0.33999205, 0.13363543, 0.02129423, ..., -0.13025227,
-0.16508926, -0.06969624]],
[[-0.04594866, -0.11583115, -0.14462094, ..., -0.12290562,
-0.35782176, -0.27979308],
[-0.04806903, -0.00658076, -0.02234544, ..., -0.0878844 ,
-0.3915486 , -0.34632796],
[-0.04484424, 0.06471398, -0.07631404, ..., -0.12629718,
-0.29905206, -0.28253639]]],
[[[ 0.2671299 , -0.07969447, 0.05988706, ..., -0.09225675,
0.31764674, 0.42209673],
[ 0.30511212, 0.05677647, 0.21688674, ..., -0.06828708,
0.3440761 , 0.44033417],
[ 0.23215917, 0.13365699, 0.12134422, ..., -0.1063385 ,
0.28406844, 0.35949969]],
[[ 0.09986369, -0.06240906, 0.07442063, ..., -0.02214639,
0.25912452, 0.42349899],
[ 0.10385381, 0.08851637, 0.2392226 , ..., -0.01210995,
0.27064082, 0.40848857],
[ 0.08978214, 0.18505956, 0.15264879, ..., -0.04266965,
0.25779948, 0.35873157]],
[[-0.34100872, -0.13399366, -0.11510294, ..., -0.11911335,
-0.23109646, -0.19202407],
[-0.37314063, -0.00698938, 0.02153259, ..., -0.09827439,
-0.2535741 , -0.25541356],
[-0.30331427, 0.08002605, -0.03926321, ..., -0.12958746,
-0.19778992, -0.21510386]]],
[[[-0.07573577, -0.07806503, -0.03540679, ..., -0.1208065 ,
0.20088433, 0.09790061],
[-0.07646758, 0.03879711, 0.09974211, ..., -0.08732687,
0.2247974 , 0.10158388],
[-0.07260918, 0.10084777, 0.01313597, ..., -0.12594968,
0.14647409, 0.05009392]],
[[-0.28034249, -0.07094654, -0.0387974 , ..., -0.08843154,
0.18996507, 0.07766484],
[-0.31070709, 0.06031388, 0.10412455, ..., -0.06832542,
0.20279962, 0.05222717],
[-0.246675 , 0.1414054 , 0.02605635, ..., -0.10128672,
0.16340195, 0.02832468]],
[[-0.41602272, -0.11491341, -0.14672887, ..., -0.13079506,
-0.1379628 , -0.26588449],
[-0.46453714, -0.00576723, -0.02660675, ..., -0.10017379,
-0.15603794, -0.32566148],
[-0.33683276, 0.06601517, -0.08144748, ..., -0.13460518,
-0.1342358 , -0.27096185]]]], dtype=float32) has invalid type <class 'numpy.ndarray'>, must be a string or Tensor. (Can not
convert a ndarray into a Tensor or Operation.)
我也試過使用 local_init_op,同樣不起作用。
我的代碼:
import sys
import tensorflow as tf
slim = tf.contrib.slim
import argparse
import model as M
import decoder as D
FLAGS = None
def train(_):
    """Run one task (parameter server or worker) of the distributed job.

    A "ps" task only joins its server and blocks.  A "worker" task builds
    the input pipeline and model, prepares a restore function that loads
    the pre-trained ``vgg_19`` variables from a checkpoint, and then runs
    training steps inside a ``tf.train.MonitoredTrainingSession``.

    Args:
        _: Unused; ``tf.app.run`` passes the remaining argv list here.
    """
    vgg_19_ckpt_path = '/media/data/projects/project_daphnis/pretrained_models/vgg_19.ckpt'
    train_log_dir = "/media/data/projects/project_daphnis/train_log_dir"

    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        if not tf.gfile.Exists(train_log_dir):
            tf.gfile.MakeDirs(train_log_dir)

        # Assigns ops to the local worker by default.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            # Set up the data loading:
            image, c, p, s = D.get_training_dataset_data_provider()
            image, c, p, s = tf.train.batch([image, c, p, s], batch_size=16)

            # Define the model:
            predictions, loss, end_points = M.model_as_in_paper(
                image, c, p, s
            )

            # Callable taking a single session argument that assigns the
            # selected vgg_19 variables from the checkpoint file.
            restore_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                vgg_19_ckpt_path,
                var_list=slim.get_variables_to_restore(
                    include=["vgg_19"],
                    exclude=['vgg_19/conv4_3_X',
                             'vgg_19/conv4_4_X']
                )
            )

            # Specify the optimization scheme:
            optimizer = tf.train.AdamOptimizer(learning_rate=.00001)

            # create_train_op that ensures that when we evaluate it to get
            # the loss, the update_ops are done and the gradient updates are
            # computed.
            train_op = slim.learning.create_train_op(loss, optimizer)
            tf.summary.scalar("losses/total_loss", loss)

        def scaffold_init_fn(scaffold, session):
            # tf.train.Scaffold invokes init_fn(scaffold, session), whereas
            # restore_fn accepts only the session.  Passing restore_fn
            # directly raises "callback() takes 1 positional argument but 2
            # were given", so adapt the signature here.
            restore_fn(session)

        # The StopAtStepHook handles stopping after running given steps.
        hooks = [tf.train.StopAtStepHook(last_step=1000000)]

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing
        # when done or an error occurs.
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(FLAGS.task_index == 0),
                checkpoint_dir=train_log_dir,
                hooks=hooks,
                scaffold=tf.train.Scaffold(
                    init_fn=scaffold_init_fn,
                    summary_op=tf.summary.merge_all())) as mon_sess:
            while not mon_sess.should_stop():
                # Run a training step asynchronously.
                # See `tf.train.SyncReplicasOptimizer` for additional details
                # on how to perform *synchronous* training.
                # mon_sess.run handles AbortedError in case of preempted PS.
                mon_sess.run(train_op)

        # For reference, the slim.learning.train variant this replaces; it
        # accepts the one-argument restore_fn directly as init_fn:
        #
        # slim.learning.train(train_tensor,
        #                     train_log_dir,
        #                     init_fn=restore_fn,
        #                     summary_op=tf.summary.merge_all(),
        #                     is_chief=False)
if __name__ == "__main__":
    # Parse the cluster/task flags, then hand any unrecognised arguments on
    # to tf.app.run, which invokes train().
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")

    # Flags for defining the tf.train.ClusterSpec
    parser.add_argument(
        "--ps_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--worker_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        help="One of 'ps', 'worker'"
    )

    # Flags for defining the tf.train.Server
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )

    # FLAGS is read as a module-level global inside train().
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=train, argv=[sys.argv[0]] + unparsed)