2016-11-01 103 views
2

我有一個訓練有素的凍結圖,我試圖在ARM設備上運行。基本上,我使用的是contrib/pi_examples/label_image,但用我的網絡而不是Inception。我的網絡與輟學,現在使我煩惱的培訓:從TensorFlow圖中消除丟失操作

Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs. Registered kernels: 
    device='CPU'; T in [DT_FLOAT] 
    device='CPU'; T in [DT_INT32] 
    device='GPU'; T in [DT_STRING] 
    device='GPU'; T in [DT_BOOL] 
    device='GPU'; T in [DT_INT32] 
    device='GPU'; T in [DT_FLOAT] 

[[Node: l_fc1_dropout/cond/Switch = Switch[T=DT_BOOL](is_training_pl, is_training_pl)]] 

一個解決方案,我可以看到的是構建包括相應的操作,TF靜態庫。另一方面,爲了使網絡更簡單,更快速,消除網絡中的丟包操作可能會更好。有沒有辦法做到這一點?

謝謝。

+0

你可以在文本編輯器中編輯'graph.pbtxt',擺脫差(即更換標識運差運算) –

回答

3
#!/usr/bin/env python2 

import argparse 

import tensorflow as tf 
from google.protobuf import text_format 
from tensorflow.core.framework import graph_pb2 
from tensorflow.core.framework import node_def_pb2 

def print_graph(input_graph): 
    for node in input_graph.node: 
     print "{0} : {1} ({2})".format(node.name, node.op, node.input) 

def strip(input_graph, drop_scope, input_before, output_after, pl_name): 
    input_nodes = input_graph.node 
    nodes_after_strip = [] 
    for node in input_nodes: 
     print "{0} : {1} ({2})".format(node.name, node.op, node.input) 

     if node.name.startswith(drop_scope + '/'): 
      continue 

     if node.name == pl_name: 
      continue 

     new_node = node_def_pb2.NodeDef() 
     new_node.CopyFrom(node) 
     if new_node.name == output_after: 
      new_input = [] 
      for node_name in new_node.input: 
       if node_name == drop_scope + '/cond/Merge': 
        new_input.append(input_before) 
       else: 
        new_input.append(node_name) 
      del new_node.input[:] 
      new_node.input.extend(new_input) 
     nodes_after_strip.append(new_node) 

    output_graph = graph_pb2.GraphDef() 
    output_graph.node.extend(nodes_after_strip) 
    return output_graph 

def main(): 

    parser = argparse.ArgumentParser() 
    parser.add_argument('--input-graph', action='store', dest='input_graph') 
    parser.add_argument('--input-binary', action='store_true', default=True, dest='input_binary') 
    parser.add_argument('--output-graph', action='store', dest='output_graph') 
    parser.add_argument('--output-binary', action='store_true', dest='output_binary', default=True) 

    args = parser.parse_args() 

    input_graph = args.input_graph 
    input_binary = args.input_binary 
    output_graph = args.output_graph 
    output_binary = args.output_binary 

    if not tf.gfile.Exists(input_graph): 
     print("Input graph file '" + input_graph + "' does not exist!") 
     return 

    input_graph_def = tf.GraphDef() 
    mode = "rb" if input_binary else "r" 
    with tf.gfile.FastGFile(input_graph, mode) as f: 
     if input_binary: 
      input_graph_def.ParseFromString(f.read()) 
     else: 
      text_format.Merge(f.read().decode("utf-8"), input_graph_def) 

    print "Before:" 
    print_graph(input_graph_def) 
    output_graph_def = strip(input_graph_def, u'l_fc1_dropout', u'l_fc1/Relu', u'prediction/MatMul', u'is_training_pl') 
    print "After:" 
    print_graph(output_graph_def) 

    if output_binary: 
     with tf.gfile.GFile(output_graph, "wb") as f: 
      f.write(output_graph_def.SerializeToString()) 
    else: 
     with tf.gfile.GFile(output_graph, "w") as f: 
      f.write(text_format.MessageToString(output_graph_def)) 
    print("%d ops in the final graph." % len(output_graph_def.node)) 


if __name__ == "__main__": 
    main() 
+0

的該腳本似乎刪除了圖層,但是如果我刪除了中間缺失圖層,則下一個圖層需要輸出張量。在我的情況下,當我試圖讀取圖中留下的圖層時,我收到一個錯誤:ValueError:graph_def在節點u'fc7/Conv2D'處無效:在graph_def中找不到輸入張量'dropout/mul_1:0'。 。如何在我的protobuf中更改u'fc7/Conv2D的輸入層張量名稱? –

+0

腳本提供的功能也...效果很好,謝謝。 –

3

這個怎麼樣作爲一個更通用的解決方案:

for node in temp_graph_def.node: 
    for idx, i in enumerate(node.input): 
     input_clean = node_name_from_input(i) 
     if input_clean.endswith('/cond/Merge') and input_clean.split('/')[-3].startswith('dropout'): 
      identity = node_from_map(input_node_map, i).input[0] 
      assert identity.split('/')[-1] == 'Identity' 
      parent = node_from_map(input_node_map, node_from_map(input_node_map, identity).input[0]) 
      pred_id = parent.input[1] 
      assert pred_id.split('/')[-1] == 'pred_id'    
      good = parent.input[0] 
      node.input[idx] = good