2017-07-14 77 views
0

我想恢復 VGG_19 的模型參數，把它當作特徵提取器接在一個新初始化的圖上，並在分佈式環境中訓練整個模型。

如何使用 `MonitoredTrainingSession`/`Scaffold` 微調模型？

如果我使用 `slim.learning.train`，一切正常；但我無法讓它搭配 `tf.train.MonitoredTrainingSession` 所需的 `Scaffold` 運作。如果我把一個 `restore_fn`（依照文件，以 `tf.contrib.framework.assign_from_checkpoint_fn` 建立）作為 `init_fn` 傳給 `Scaffold`，會得到 `TypeError: callback() takes 1 positional argument but 2 were given`。

我試著傳入 `lambda scaffold, sess: restore_fn(sess)` 來「修正」這個問題。

如果我改為建立一個恢復操作（以 `tf.contrib.slim.assign_from_checkpoint` 建立），並把它作為 `init_op` 傳入，則會得到以下錯誤：

INFO:tensorflow:Create CheckpointSaverHook. 
--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn) 
    267   self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 268    fetch, allow_tensor=True, allow_operation=True)) 
    269  except TypeError as e: 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operatio 
n) 
    2608  if self._finalized: 
-> 2609  return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 
    2610 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_ 
operation) 
    2700  raise TypeError("Can not convert a %s into a %s." 
-> 2701      % (type(obj).__name__, types_str)) 
    2702 

TypeError: Can not convert a ndarray into a Tensor or Operation. 

During handling of the above exception, another exception occurred: 

TypeError         Traceback (most recent call last) 
/ScanAvoidanceML/ScanAvoidanceML/datasets/project_daphnis/train.py in <module>() 
    129  ) 
    130   FLAGS, unparsed = parser.parse_known_args() 
--> 131   tf.app.run(main=train, argv=[sys.argv[0]] + unparsed) 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/platform/app.py in run(main, argv) 
    46 # Call the main function, passing through any arguments 
    47 # to the final program. 
---> 48 _sys.exit(main(_sys.argv[:1] + flags_passthrough)) 
    49 
    50 

/ScanAvoidanceML/ScanAvoidanceML/datasets/project_daphnis/train.py in train(_) 
    83     scaffold=tf.train.Scaffold(
    84      init_op=restore_op, 
---> 85      summary_op=tf.summary.merge_all())) as mon_sess: 
    86    while not mon_sess.should_stop(): 
    87     # Run a training step asynchronously. 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in MonitoredTrainingSession(master, is_chief, 
checkpoint_dir, scaffold, hooks, chief_only_hooks, save_checkpoint_secs, save_summaries_steps, save_summaries_secs, config, stop_grac 
e_period_secs, log_step_count_steps) 
    351  all_hooks.extend(hooks) 
    352 return MonitoredSession(session_creator=session_creator, hooks=all_hooks, 
--> 353       stop_grace_period_secs=stop_grace_period_secs) 
    354 
    355 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __init__(self, session_creator, hooks, stop 
_grace_period_secs) 
    654  super(MonitoredSession, self).__init__(
    655   session_creator, hooks, should_recover=True, 
--> 656   stop_grace_period_secs=stop_grace_period_secs) 
    657 
    658 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __init__(self, session_creator, hooks, shou 
ld_recover, stop_grace_period_secs) 
    476   stop_grace_period_secs=stop_grace_period_secs) 
    477  if should_recover: 
--> 478  self._sess = _RecoverableSession(self._coordinated_creator) 
    479  else: 
    480  self._sess = self._coordinated_creator.create_session() 


/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __init__(self, sess_creator) 
    828  """ 
    829  self._sess_creator = sess_creator 
--> 830  _WrappedSession.__init__(self, self._create_session()) 
    831 
    832 def _create_session(self): 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in _create_session(self) 
    833  while True: 
    834  try: 
--> 835   return self._sess_creator.create_session() 
    836  except _PREEMPTION_ERRORS as e: 
    837   logging.info('An error was raised while a session was being created. ' 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in create_session(self) 
    537  """Creates a coordinated session.""" 
    538  # Keep the tf_sess for unit testing. 
--> 539  self.tf_sess = self._session_creator.create_session() 
    540  # We don't want coordinator to suppress any exception. 
    541  self.coord = coordinator.Coordinator(clean_stop_exception_types=[]) 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in create_session(self) 
    411   init_op=self._scaffold.init_op, 
    412   init_feed_dict=self._scaffold.init_feed_dict, 
--> 413   init_fn=self._scaffold.init_fn) 
    414 
    415 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/session_manager.py in prepare_session(self, master, init_op, saver, 
checkpoint_dir, checkpoint_filename_with_path, wait_for_checkpoint, max_wait_secs, config, init_feed_dict, init_fn) 
    277       "init_fn or local_init_op was given") 
    278  if init_op is not None: 
--> 279   sess.run(init_op, feed_dict=init_feed_dict) 
    280  if init_fn: 
    281   init_fn(sess) 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata) 
    894  try: 
    895  result = self._run(None, fetches, feed_dict, options_ptr, 
--> 896       run_metadata_ptr) 
    897  if run_metadata: 
    898   proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_met 
adata) 
    1107  # Create a fetch handler to take care of the structure of fetches. 
    1108  fetch_handler = _FetchHandler(
-> 1109   self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) 
    1110 
    1111  # Run request and get response. 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles) 
    409  """ 
    410  with graph.as_default(): 
--> 411  self._fetch_mapper = _FetchMapper.for_fetch(fetches) 
    412  self._fetches = [] 
    413  self._targets = [] 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch) 
    229  elif isinstance(fetch, (list, tuple)): 
    230  # NOTE(touts): This is also the code path for namedtuples. 
--> 231  return _ListFetchMapper(fetch) 
    232  elif isinstance(fetch, dict): 
    233  return _DictFetchMapper(fetch) 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches) 
    336  """ 
    337  self._fetch_type = type(fetches) 
--> 338  self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 
    339  self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 
    340 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in <listcomp>(.0) 
    336  """ 
    337  self._fetch_type = type(fetches) 
--> 338  self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 
    339  self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 
    340 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch) 
    231  return _ListFetchMapper(fetch) 
    232  elif isinstance(fetch, dict): 
--> 233  return _DictFetchMapper(fetch) 
    234  else: 
    235  # Look for a handler in the registered expansions. 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches) 
    369  self._keys = fetches.keys() 
    370  self._mappers = [_FetchMapper.for_fetch(fetch) 
--> 371      for fetch in fetches.values()] 
    372  self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 
    373 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch) 
    237   if isinstance(fetch, tensor_type): 
    238   fetches, contraction_fn = fetch_fn(fetch) 
--> 239   return _ElementFetchMapper(fetches, contraction_fn) 
    240  # Did not find anything. 
    241  raise TypeError('Fetch argument %r has invalid type %r' % 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches) 
    369  self._keys = fetches.keys() 
    370  self._mappers = [_FetchMapper.for_fetch(fetch) 
--> 371      for fetch in fetches.values()] 
    372  self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 
    373 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in <listcomp>(.0) 
    369  self._keys = fetches.keys() 
    370  self._mappers = [_FetchMapper.for_fetch(fetch) 
--> 371      for fetch in fetches.values()] 
    372  self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 
    373 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch) 
    237   if isinstance(fetch, tensor_type): 
    238   fetches, contraction_fn = fetch_fn(fetch) 
--> 239   return _ElementFetchMapper(fetches, contraction_fn) 
    240  # Did not find anything. 
    241  raise TypeError('Fetch argument %r has invalid type %r' % 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn) 
    270   raise TypeError('Fetch argument %r has invalid type %r, ' 
    271       'must be a string or Tensor. (%s)' 
--> 272       % (fetch, type(fetch), str(e))) 
    273  except ValueError as e: 
    274   raise ValueError('Fetch argument %r cannot be interpreted as a ' 

TypeError: Fetch argument array([[[[ 0.39416704, -0.08419707, -0.03631314, ..., -0.10720515, 
      -0.03804016, 0.04690642], 
     [ 0.46418372, 0.03355668, 0.10245045, ..., -0.06945956, 
      -0.04020201, 0.04048637], 
     [ 0.34119523, 0.09563112, 0.0177449 , ..., -0.11436455, 
      -0.05099866, -0.00299793]], 

     [[ 0.37740308, -0.07876257, -0.04775979, ..., -0.11827433, 
      -0.19008617, -0.01889699], 
     [ 0.41810837, 0.05260524, 0.09755926, ..., -0.09385028, 
      -0.20492788, -0.0573062 ], 
     [ 0.33999205, 0.13363543, 0.02129423, ..., -0.13025227, 
      -0.16508926, -0.06969624]], 

     [[-0.04594866, -0.11583115, -0.14462094, ..., -0.12290562, 
      -0.35782176, -0.27979308], 
     [-0.04806903, -0.00658076, -0.02234544, ..., -0.0878844 , 
      -0.3915486 , -0.34632796], 
     [-0.04484424, 0.06471398, -0.07631404, ..., -0.12629718, 
      -0.29905206, -0.28253639]]], 


     [[[ 0.2671299 , -0.07969447, 0.05988706, ..., -0.09225675, 
      0.31764674, 0.42209673], 
     [ 0.30511212, 0.05677647, 0.21688674, ..., -0.06828708, 
      0.3440761 , 0.44033417], 
     [ 0.23215917, 0.13365699, 0.12134422, ..., -0.1063385 , 
      0.28406844, 0.35949969]], 

     [[ 0.09986369, -0.06240906, 0.07442063, ..., -0.02214639, 
      0.25912452, 0.42349899], 
     [ 0.10385381, 0.08851637, 0.2392226 , ..., -0.01210995, 
      0.27064082, 0.40848857], 
     [ 0.08978214, 0.18505956, 0.15264879, ..., -0.04266965, 
      0.25779948, 0.35873157]], 

     [[-0.34100872, -0.13399366, -0.11510294, ..., -0.11911335, 
      -0.23109646, -0.19202407], 
     [-0.37314063, -0.00698938, 0.02153259, ..., -0.09827439, 
      -0.2535741 , -0.25541356], 
     [-0.30331427, 0.08002605, -0.03926321, ..., -0.12958746, 
      -0.19778992, -0.21510386]]], 


     [[[-0.07573577, -0.07806503, -0.03540679, ..., -0.1208065 , 
      0.20088433, 0.09790061], 
     [-0.07646758, 0.03879711, 0.09974211, ..., -0.08732687, 
      0.2247974 , 0.10158388], 
     [-0.07260918, 0.10084777, 0.01313597, ..., -0.12594968, 
      0.14647409, 0.05009392]], 

     [[-0.28034249, -0.07094654, -0.0387974 , ..., -0.08843154, 
      0.18996507, 0.07766484], 
     [-0.31070709, 0.06031388, 0.10412455, ..., -0.06832542, 
      0.20279962, 0.05222717], 
     [-0.246675 , 0.1414054 , 0.02605635, ..., -0.10128672, 
      0.16340195, 0.02832468]], 

     [[-0.41602272, -0.11491341, -0.14672887, ..., -0.13079506, 
      -0.1379628 , -0.26588449], 
     [-0.46453714, -0.00576723, -0.02660675, ..., -0.10017379, 
      -0.15603794, -0.32566148], 
     [-0.33683276, 0.06601517, -0.08144748, ..., -0.13460518, 
      -0.1342358 , -0.27096185]]]], dtype=float32) has invalid type <class 'numpy.ndarray'>, must be a string or Tensor. (Can not 
convert a ndarray into a Tensor or Operation.) 

我也試過改用 `local_init_op`，同樣無效。我的程式碼：

import sys 
import tensorflow as tf 
slim = tf.contrib.slim 
import argparse 
import model as M 
import decoder as D 


FLAGS = None


def train(_):
    """Run one task of the distributed VGG_19 fine-tuning job.

    Entry point for ``tf.app.run``; the single positional argument (the
    forwarded argv list) is unused. The cluster layout and this task's role
    come from the module-level ``FLAGS`` parsed in the ``__main__`` block.

    * ``FLAGS.job_name == "ps"``: blocks in ``server.join()``.
    * ``FLAGS.job_name == "worker"``: builds the batched input pipeline and
      the model, then trains inside a ``MonitoredTrainingSession`` whose
      ``Scaffold.init_fn`` restores the ``vgg_19`` variables (except
      ``conv4_3_X`` / ``conv4_4_X``) from the pretrained checkpoint. The loop
      runs until ``StopAtStepHook`` (last_step=1000000) stops the session.
    * any other ``FLAGS.job_name``: does nothing.
    """
    vgg_19_ckpt_path = '/media/data/projects/project_daphnis/pretrained_models/vgg_19.ckpt'
    train_log_dir = "/media/data/projects/project_daphnis/train_log_dir"

    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        if not tf.gfile.Exists(train_log_dir):
            tf.gfile.MakeDirs(train_log_dir)

        # Assigns ops to the local worker by default.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):

            # Set up the data loading:
            image, c, p, s = D.get_training_dataset_data_provider()
            image, c, p, s = tf.train.batch([image, c, p, s], batch_size=16)

            # Define the model (only the loss is used below):
            _, loss, _ = M.model_as_in_paper(image, c, p, s)

            # NOTE(review): built before the optimizer is created, so
            # optimizer slot variables presumably are not matched by the
            # "vgg_19" include filter -- confirm if the order ever changes.
            restore_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                vgg_19_ckpt_path,
                var_list=slim.get_variables_to_restore(
                    include=["vgg_19"],
                    exclude=['vgg_19/conv4_3_X', 'vgg_19/conv4_4_X']))

            # Specify the optimization scheme:
            optimizer = tf.train.AdamOptimizer(learning_rate=.00001)

            # create_train_op that ensures that when we evaluate it to get
            # the loss, the update_ops are done and the gradient updates are
            # computed.
            train_op = slim.learning.create_train_op(loss, optimizer)
        tf.summary.scalar("losses/total_loss", loss)

        # The StopAtStepHook handles stopping after running given steps.
        hooks = [tf.train.StopAtStepHook(last_step=1000000)]

        # Scaffold calls init_fn(scaffold, session), but the callback returned
        # by assign_from_checkpoint_fn accepts only the session. Passing
        # restore_fn directly fails with "TypeError: callback() takes 1
        # positional argument but 2 were given", so adapt the signature.
        def init_from_vgg_19(scaffold, sess):
            del scaffold  # Required by the Scaffold callback signature only.
            restore_fn(sess)

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing
        # when done or an error occurs.
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(FLAGS.task_index == 0),
                checkpoint_dir=train_log_dir,
                hooks=hooks,
                scaffold=tf.train.Scaffold(
                    init_fn=init_from_vgg_19,
                    summary_op=tf.summary.merge_all())) as mon_sess:
            while not mon_sess.should_stop():
                # Run a training step asynchronously.
                # See `tf.train.SyncReplicasOptimizer` for additional details
                # on how to perform *synchronous* training.
                # mon_sess.run handles AbortedError in case of preempted PS.
                mon_sess.run(train_op)

        # Alternative reported as working in the question text
        # (slim.learning.train takes the one-argument restore_fn as-is):
        # slim.learning.train(train_op,
        #                     train_log_dir,
        #                     init_fn=restore_fn,
        #                     summary_op=tf.summary.merge_all(),
        #                     is_chief=False)

if __name__ == "__main__":
    # Parse the cluster/task flags; anything unrecognized is forwarded to
    # tf.app.run together with the program name.
    parser = argparse.ArgumentParser()
    # Registers a "bool" type converter; none of the flags below uses it.
    parser.register("type", "bool", lambda v: v.lower() == "true")
    # Flags for defining the tf.train.ClusterSpec
    parser.add_argument(
        "--ps_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--worker_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    # train() silently does nothing for an unrecognized job name, so reject
    # typos here instead of starting a task that never trains.
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        choices=["ps", "worker"],
        help="One of 'ps', 'worker'"
    )
    # Flags for defining the tf.train.Server
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=train, argv=[sys.argv[0]] + unparsed)

回答

0

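解法是用一個 `tf.train.Saver` 來恢復參數，並把 `saver.restore` 包裝起來，使它可以作為 `Scaffold` 的 `init_fn`。這個包裝函式必須接受兩個參數：`scaffold` 和 `sess`，其中 `sess` 用於恢復參數，`scaffold` 則直接忽略。原因是 `Scaffold` 會以 `init_fn(scaffold, session)` 的形式呼叫 `init_fn`，而 `assign_from_checkpoint_fn` 回傳的 callback 只接受 `session` 一個參數，所以直接傳入時才會出現 `callback() takes 1 positional argument but 2 were given`。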

完整代碼:

import sys 
import tensorflow as tf 
slim = tf.contrib.slim 
import argparse 
import model as M 
import decoder as D 


FLAGS = None


def train(_):
    """Entry point handed to ``tf.app.run``; the argv argument is ignored.

    What happens depends on ``FLAGS.job_name``:
      - ``"ps"``: join the parameter-server task (``server.join()`` blocks).
      - ``"worker"``: build the batched input pipeline and the model, then
        train inside a ``MonitoredTrainingSession`` whose ``Scaffold.init_fn``
        restores the ``vgg_19`` weights (except ``conv4_3_X`` and
        ``conv4_4_X``) from the pretrained VGG_19 checkpoint.
      - anything else: return without doing anything.
    """
    pretrained_ckpt = '/media/data/projects/project_daphnis/pretrained_models/vgg_19.ckpt'
    log_dir = "/media/data/projects/project_daphnis/train_log_dir"

    # Cluster description built from the comma-separated host flags.
    cluster = tf.train.ClusterSpec({
        "ps": FLAGS.ps_hosts.split(","),
        "worker": FLAGS.worker_hosts.split(","),
    })

    # In-process server for this task.
    server = tf.train.Server(
        cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
        return
    if FLAGS.job_name != "worker":
        return

    if not tf.gfile.Exists(log_dir):
        tf.gfile.MakeDirs(log_dir)

    # By default, ops are placed on this worker task.
    device_fn = tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)
    with tf.device(device_fn):
        # Input pipeline: provider tensors, batched 16 at a time.
        image, c, p, s = D.get_training_dataset_data_provider()
        image, c, p, s = tf.train.batch([image, c, p, s], batch_size=16)

        # Only the loss is needed from the model outputs.
        _, loss, _ = M.model_as_in_paper(image, c, p, s)

        # Pretrained VGG_19 variables to warm-start from.
        vgg_vars = slim.get_variables_to_restore(
            include=["vgg_19"],
            exclude=['vgg_19/conv4_3_X', 'vgg_19/conv4_4_X'])

        # Adam (lr 1e-5); create_train_op makes evaluating train_op also run
        # the update_ops and apply the gradient updates.
        train_op = slim.learning.create_train_op(
            loss, tf.train.AdamOptimizer(learning_rate=.00001))
    tf.summary.scalar("losses/total_loss", loss)

    # Stop training at the given last step.
    hooks = [tf.train.StopAtStepHook(last_step=1000000)]

    # Saver restricted to the pretrained variables. Scaffold invokes its
    # init_fn with (scaffold, session), so the wrapper takes both and only
    # uses the session.
    vgg_saver = tf.train.Saver(vgg_vars)

    def restore_vgg_19(_scaffold, session):
        vgg_saver.restore(session, pretrained_ckpt)

    scaffold = tf.train.Scaffold(
        init_fn=restore_vgg_19,
        summary_op=tf.summary.merge_all())

    # The monitored session owns initialization, checkpoint save/restore in
    # log_dir, and cleanup on completion or error.
    with tf.train.MonitoredTrainingSession(
            master=server.target,
            is_chief=(FLAGS.task_index == 0),
            checkpoint_dir=log_dir,
            hooks=hooks,
            scaffold=scaffold) as mon_sess:
        # Asynchronous steps; tf.train.SyncReplicasOptimizer is the
        # synchronous alternative. mon_sess.run handles AbortedError when a
        # PS is preempted.
        while not mon_sess.should_stop():
            mon_sess.run(train_op)

if __name__ == "__main__":
    # Parse the cluster/task flags; anything unrecognized is forwarded to
    # tf.app.run together with the program name.
    parser = argparse.ArgumentParser()
    # Registers a "bool" type converter; none of the flags below uses it.
    parser.register("type", "bool", lambda v: v.lower() == "true")
    # Flags for defining the tf.train.ClusterSpec
    parser.add_argument(
        "--ps_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--worker_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    # train() silently does nothing for an unrecognized job name, so reject
    # typos here instead of starting a task that never trains.
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        choices=["ps", "worker"],
        help="One of 'ps', 'worker'"
    )
    # Flags for defining the tf.train.Server
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=train, argv=[sys.argv[0]] + unparsed)