我試了一段時間才能獲得在android上工作的預訓練模型。問題是,我只獲得了預訓練網絡的ckpt和meta文件。在我看來,我需要安卓應用的.pb。所以我試圖將給定的文件轉換爲.pb文件。從開放圖像數據集獲取預訓練初始v3模型在Android上工作
因此我嘗試了freeze_graph.py但沒有成功。所以我使用了https://github.com/openimages/dataset/blob/master/tools/classify.py的示例代碼並對其進行了修改以存儲pb。加載
if not os.path.exists(FLAGS.checkpoint):
tf.logging.fatal(
'Checkpoint %s does not exist. Have you download it? See tools/download_data.sh',
FLAGS.checkpoint)
g = tf.Graph()
with g.as_default():
input_image = tf.placeholder(tf.string)
processed_image = PreprocessImage(input_image)
with slim.arg_scope(inception.inception_v3_arg_scope()):
logits, end_points = inception.inception_v3(
processed_image, num_classes=FLAGS.num_classes, is_training=False)
predictions = end_points['multi_predictions'] = tf.nn.sigmoid(
logits, name='multi_predictions')
init_op = control_flow_ops.group(tf.global_variables_initializer(),
tf.global_variables_initializer(),
data_flow_ops.initialize_all_tables())
saver = tf_saver.Saver()
sess = tf.Session()
saver.restore(sess, FLAGS.checkpoint)
outpt_filename = 'output_graph.pb'
#output_graph_def = sess.graph.as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ["multi_predictions"])
with gfile.FastGFile(outpt_filename, 'wb') as f:
f.write(output_graph_def.SerializeToString())
現在我的問題後,文件是我有.pb文件,但我沒有任何意見什麼是輸入節點名,我不知道如果multi_predictions
是正確的輸出名稱。在示例Android應用程序中,我必須指定兩者。和Android應用程序墜毀:
tensorflow_inference_jni.cc:138 Could not create Tensorflow Graph: Invalid argument: No OpKernel was registered to support Op 'DecodeJpeg' with these attrs.
我不知道是否有更多的問題,試圖解決.pb問題。或者如果有人知道更好的方法將ckpt和元文件移植到我的.pd文件中,或者知道輸入和輸出名稱的最終文件的來源,請給我一個提示以完成此任務。
感謝
對不起忙了一陣子。對於使用optimize_for_inference.py,您的提示絕對正確,我可以使用「Mul」作爲輸入節點。非常感謝 – lampep
嗨@lampep你如何預處理你的圖像?我試圖在ios上運行此操作,並重復相同的預測。我遵循你的代碼,生成了一個優化的圖表,但預測並不正確。謝謝! –