0
我培養了一個佔位符模型is_training
:永久注入恆成Tensorflow圖形的推理
is_training_ph = tf.placeholder(tf.bool)
然而,一旦訓練和驗證完成後,我想在這個值永久注入的false
恆定然後「重新優化」圖(即使用optimize_for_inference
)。請問freeze_graph
會這樣做嗎?
我培養了一個佔位符模型is_training
:永久注入恆成Tensorflow圖形的推理
is_training_ph = tf.placeholder(tf.bool)
然而,一旦訓練和驗證完成後,我想在這個值永久注入的false
恆定然後「重新優化」圖(即使用optimize_for_inference
)。請問freeze_graph
會這樣做嗎?
一種可能性是使用tf.import_graph_def()
函數及其input_map
參數來重寫圖中張量的值。例如,可以按如下方式構建程序:
with tf.Graph().as_default() as training_graph:
# Build model.
is_training_ph = tf.placeholder(tf.bool, name="is_training")
# ...
training_graph_def = training_graph.as_graph_def()
with tf.Graph().as_default() as temp_graph:
tf.import_graph_def(training_graph_def,
input_map={is_training_ph.name: tf.constant(False)})
temp_graph_def = temp_graph.as_graph_def()
建設temp_graph_def
後,您可以使用它作爲輸入freeze_graph
。
的替代,這可能是與freeze_graph
和optimize_for_inference
腳本(這使有關變量名和檢查點鍵假設)更兼容是修改TensorFlow的graph_util.convert_variables_to_constants()
功能,使其轉化佔位符代替:
def convert_placeholders_to_constants(input_graph_def,
placeholder_to_value_map):
"""Replaces placeholders in the given tf.GraphDef with constant values.
Args:
input_graph_def: GraphDef object holding the network.
placeholder_to_value_map: A map from the names of placeholder tensors in
`input_graph_def` to constant values.
Returns:
GraphDef containing a simplified version of the original.
"""
output_graph_def = tf.GraphDef()
for node in input_graph_def.node:
output_node = tf.NodeDef()
if node.op == "Placeholder" and node.name in placeholder_to_value_map:
output_node.op = "Const"
output_node.name = node.name
dtype = node.attr["dtype"].type
data = np.asarray(placeholder_to_value_map[node.name],
dtype=tf.as_dtype(dtype).as_numpy_dtype)
output_node.attr["dtype"].type = dtype
output_node.attr["value"].CopyFrom(tf.AttrValue(
tensor=tf.contrib.util.make_tensor_proto(data,
dtype=dtype,
shape=data.shape)))
else:
output_node.CopyFrom(node)
output_graph_def.node.extend([output_node])
return output_graph_def
...那麼你可以建立training_graph_def
如上,並寫:
temp_graph_def = convert_placeholders_to_constants(training_graph_def,
{is_training_ph.op.name: False})
是否牛逼帽子錯誤來自'freeze_graph'?如果是這樣,這些腳本可能有點太有限,無法處理這種情況。我在這種情況下給這個答案增加了一個建議。 – mrry
我不這麼認爲......除非我犯了一個錯誤,它應該用'placeholder_to_value_map'中的各個值替換其值爲'placeholder_to_value_map'的鍵名的所有佔位符。 (注意:我沒有測試過。) – mrry
好的,謝謝! – mrry