2016-11-28 30 views

回答

2

一種可能性是使用tf.import_graph_def()函數及其input_map參數來重寫圖中張量的值。例如,可以按如下方式構建程序:

with tf.Graph().as_default() as training_graph: 
    # Build model. 
    is_training_ph = tf.placeholder(tf.bool, name="is_training") 
    # ... 
training_graph_def = training_graph.as_graph_def() 

with tf.Graph().as_default() as temp_graph: 
    tf.import_graph_def(training_graph_def, 
         input_map={is_training_ph.name: tf.constant(False)}) 
temp_graph_def = temp_graph.as_graph_def() 

建設temp_graph_def後,您可以使用它作爲輸入freeze_graph


的替代,這可能是與freeze_graphoptimize_for_inference腳本(這使有關變量名和檢查點鍵假設)更兼容是修改TensorFlow的graph_util.convert_variables_to_constants()功能,使其轉化佔位符代替:

def convert_placeholders_to_constants(input_graph_def, 
             placeholder_to_value_map): 
    """Replaces placeholders in the given tf.GraphDef with constant values. 

    Args: 
    input_graph_def: GraphDef object holding the network. 
    placeholder_to_value_map: A map from the names of placeholder tensors in 
     `input_graph_def` to constant values. 

    Returns: 
    GraphDef containing a simplified version of the original. 
    """ 

    output_graph_def = tf.GraphDef() 

    for node in input_graph_def.node: 
    output_node = tf.NodeDef() 
    if node.op == "Placeholder" and node.name in placeholder_to_value_map: 
     output_node.op = "Const" 
     output_node.name = node.name 
     dtype = node.attr["dtype"].type 
     data = np.asarray(placeholder_to_value_map[node.name], 
         dtype=tf.as_dtype(dtype).as_numpy_dtype) 
     output_node.attr["dtype"].type = dtype 
     output_node.attr["value"].CopyFrom(tf.AttrValue(
      tensor=tf.contrib.util.make_tensor_proto(data, 
                dtype=dtype, 
                shape=data.shape))) 
    else: 
     output_node.CopyFrom(node) 

    output_graph_def.node.extend([output_node]) 

    return output_graph_def 

...那麼你可以建立training_graph_def如上,並寫:

temp_graph_def = convert_placeholders_to_constants(training_graph_def, 
                {is_training_ph.op.name: False}) 
+0

是否牛逼帽子錯誤來自'freeze_graph'?如果是這樣,這些腳本可能有點太有限,無法處理這種情況。我在這種情況下給這個答案增加了一個建議。 – mrry

+0

我不這麼認爲......除非我犯了一個錯誤,它應該用'placeholder_to_value_map'中的各個值替換其值爲'placeholder_to_value_map'的鍵名的所有佔位符。 (注意:我沒有測試過。) – mrry

+0

好的,謝謝! – mrry