2017-08-21 21 views
0

我爲自定義數據集重新訓練了初始v3模型。 但是在重新訓練後,當我看着TenosorGraph時,我發現添加了一個命名爲reshape的圖層,然後是完全連接圖層。 我必須使用snapdragonneural處理引擎(SNPE)在嵌入式設備上運行模型,但它目前不支持重塑層在DSP上運行。沒有重塑層的Retraining Incpetion v3模型

有沒有一種可能的方式來重新啓動v3而不添加重塑層。 以下是添加了重塑圖層的培訓代碼。

enter code here 
       def create_model_info(architecture): 
    """Given the name of a model architecture, returns information about it. 

    There are different base image recognition pretrained models that can be 
    retrained using transfer learning, and this function translates from the name 
    of a model to the attributes that are needed to download and train with it. 

    Args: 
    architecture: Name of a model architecture. 

    Returns: 
    Dictionary of information about the model, or None if the name isn't 
    recognized 

    Raises: 
    ValueError: If architecture name is unknown. 
    """ 
    architecture = architecture.lower() 
    if architecture == 'inception_v3': 
    # pylint: disable=line-too-long 
    data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 
    # pylint: enable=line-too-long 
    bottleneck_tensor_name = 'pool_3/_reshape:0' 
    bottleneck_tensor_size = 2048 
    input_width = 299 
    input_height = 299 
    input_depth = 3 
    resized_input_tensor_name = 'Mul:0' 
    model_file_name = 'classify_image_graph_def.pb' 
    input_mean = 128 
    input_std = 128 
     elif architecture.startswith('mobilenet_'): 
     parts = architecture.split('_') 
     if len(parts) != 3 and len(parts) != 4: 
      tf.logging.error("Couldn't understand architecture name '%s'", 
          architecture) 
      return None 
     version_string = parts[1] 
     if (version_string != '1.0' and version_string != '0.75' and 
      version_string != '0.50' and version_string != '0.25'): 
      tf.logging.error(
       """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25', 
     but found '%s' for architecture '%s'""", 
       version_string, architecture) 
      return None 
     size_string = parts[2] 
     if (size_string != '224' and size_string != '192' and 
      size_string != '160' and size_string != '128'): 
      tf.logging.error(
       """The Mobilenet input size should be '224', '192', '160', or '128', 
    but found '%s' for architecture '%s'""", 
       size_string, architecture) 
      return None 
     if len(parts) == 3: 
      is_quantized = False 
     else: 
      if parts[3] != 'quantized': 
      tf.logging.error(
       "Couldn't understand architecture suffix '%s' for '%s'", parts[3], 
       architecture) 
      return None 
      is_quantized = True 
     data_url = 'http://download.tensorflow.org/models/mobilenet_v1_' 
     data_url += version_string + '_' + size_string + '_frozen.tgz' 
     bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' 
     bottleneck_tensor_size = 1001 
     input_width = int(size_string) 
     input_height = int(size_string) 
     input_depth = 3 
     resized_input_tensor_name = 'input:0' 
     if is_quantized: 
      model_base_name = 'quantized_graph.pb' 
     else: 
      model_base_name = 'frozen_graph.pb' 
     model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string 
     model_file_name = os.path.join(model_dir_name, model_base_name) 
     input_mean = 127.5 
     input_std = 127.5 
     else: 
     tf.logging.error("Couldn't understand architecture name '%s'", architecture) 
     raise ValueError('Unknown architecture', architecture) 

     return { 
      'data_url': data_url, 
      'bottleneck_tensor_name': bottleneck_tensor_name, 
      'bottleneck_tensor_size': bottleneck_tensor_size, 
      'input_width': input_width, 
      'input_height': input_height, 
      'input_depth': input_depth, 
      'resized_input_tensor_name': resized_input_tensor_name, 
      'model_file_name': model_file_name, 
      'input_mean': input_mean, 
      'input_std': input_std, 
     } 

的compelete代碼可以在這裏找到: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py

+0

您必須刪除這些重塑線並重新定義inception_v3模型,然後使用一些預先訓練的模型重新訓練模型。 –

+0

謝謝你的評論@IshantMrinal。如果可能的話,你可以更詳細地解釋它嗎?這將是非常有益的。 –

回答

0

他們不添加重塑層,他們選擇從訓練模型的重塑層。之後他們將在該整形圖層的輸出之上添加自己的圖層。

如果要選擇較高層,請將「pool_3/_reshape:0」替換爲所需圖層的名稱。你應該能夠推斷出型號代碼的名稱:https://github.com/tensorflow/models/blob/master/slim/nets/inception_v3.py

或許更容易,你graph_def打印所有節點的名稱,並選擇您想要的:

for node in graph_def.node: 
     print(node.name) 
0

從SNPE​​ SDK V1 .8.0,TensorFlow的reshape層得到支持。