
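TensorFlow Serving: saved model ssd_mobilenet_v1_coco

I've read several posts on Stack Overflow and have been at this for a few days, but alas, I can't get TensorFlow Serving to serve an object detection model correctly.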

I went through the following links: How to properly serve an object detection model from Tensorflow Object Detection API?

https://github.com/tensorflow/tensorflow/issues/11863

Here is what I did.

I downloaded ssd_mobilenet_v1_coco_11_06_2017.tar.gz, which contains the following files:

frozen_inference_graph.pb 
graph.pbtxt 
model.ckpt.data-00000-of-00001 
model.ckpt.index 
model.ckpt.meta 

Using the script below, I was able to convert frozen_inference_graph.pb into a SavedModel (under the directory ssd_mobilenet_v1_coco_11_06_2017/saved):

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

# Specify version 1: TensorFlow Serving expects <model_base_path>/<version>/
export_dir = './saved/1'
graph_pb = 'frozen_inference_graph.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

# Load the frozen GraphDef (weights are baked in as constants)
with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
    # name="" is important to ensure we don't get spurious prefixing
    tf.import_graph_def(graph_def, name="")
    g = tf.get_default_graph()
    inp = g.get_tensor_by_name("image_tensor:0")

    # The detection tensors have different shapes and dtypes, so they cannot
    # be concatenated into one tensor; expose each as a named signature output.
    outputs = {
        "detection_boxes": g.get_tensor_by_name("detection_boxes:0"),
        "detection_scores": g.get_tensor_by_name("detection_scores:0"),
        "detection_classes": g.get_tensor_by_name("detection_classes:0"),
        "num_detections": g.get_tensor_by_name("num_detections:0"),
    }

    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={"in": inp}, outputs=outputs)

    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = signature
    sigs["predict_images"] = signature

    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sigs)

builder.save()

When I start the model server, I get the following error:

bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server \
--port=9000 --model_base_path=/serving/ssd_mobilenet_v1_coco_11_06_2017/saved

2017-09-17 22:33:21.325087: W tensorflow_serving/sources/storage_path/file_system_storage_path_source.cc:268] No versions of servable default found under base path /serving/ssd_mobilenet_v1_coco_11_06_2017/saved/1 
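That warning comes from the server's file-system source: TensorFlow Serving treats each integer-named subdirectory of --model_base_path as a model version (serving the highest one by default), so the base path must be the parent of the version directories, i.e. .../saved containing 1/saved_model.pb. Note that the path printed in the log ends in /saved/1; if that was the value actually passed, the server would look for .../saved/1/<version>/ and find nothing. Below is a minimal stdlib-only sketch, assuming that layout, of a check one could run before starting the server; `find_servable_versions` is a hypothetical helper, not part of TensorFlow Serving.

```python
import os


def find_servable_versions(base_path):
    """Return the integer version directories under base_path that contain a
    SavedModel file, sorted ascending.

    Assumes the layout TensorFlow Serving scans for:
        base_path/<version>/saved_model.pb  (or saved_model.pbtxt)
    """
    versions = []
    if not os.path.isdir(base_path):
        return versions
    for name in os.listdir(base_path):
        path = os.path.join(base_path, name)
        # Only purely numeric subdirectory names are treated as versions.
        if not (name.isdigit() and os.path.isdir(path)):
            continue
        if any(os.path.isfile(os.path.join(path, f))
               for f in ("saved_model.pb", "saved_model.pbtxt")):
            versions.append(int(name))
    return sorted(versions)


if __name__ == "__main__":
    import sys
    base = sys.argv[1] if len(sys.argv) > 1 else "./saved"
    found = find_servable_versions(base)
    if found:
        print("versions found under %s: %s" % (base, found))
    else:
        print("no servable versions under %s" % base)
```

Running it against .../saved should report [1] for the export script above, while running it against .../saved/1 reports nothing, which matches the warning.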

I know I need a client that connects to the server to make predictions. However, I can't even get the model served properly.

Answer


You need to change the export signature as described in the original article. This script makes the necessary changes for you:

$ OBJECT_DETECTION_CONFIG=object_detection/samples/configs/ssd_mobilenet_v1_pets.config

$ python object_detection/export_inference_graph.py \
    --input_type encoded_image_string_tensor \
    --pipeline_config_path ${OBJECT_DETECTION_CONFIG} \
    --trained_checkpoint_prefix ${YOUR_LOCAL_CHK_DIR}/model.ckpt-${CHECKPOINT_NUMBER} \
    --output_directory ${YOUR_LOCAL_EXPORT_DIR}
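For the pre-trained tarball in the question, the checkpoint prefix is simply the path ending in model.ckpt, since it ships model.ckpt.index, model.ckpt.data-00000-of-00001 and model.ckpt.meta without a step number. The pipeline config should describe the same model as that checkpoint (number of classes and so on); otherwise restoring the weights is likely to fail. The sketch below assembles the same invocation as an argument list; `build_export_command` and the local paths are hypothetical, and only the flags come from the command above.

```python
import os
import sys


def build_export_command(pipeline_config, checkpoint_prefix, output_dir,
                         input_type="encoded_image_string_tensor"):
    """Assemble argv for object_detection/export_inference_graph.py.

    checkpoint_prefix is the path without the .index/.data-*/.meta suffix,
    e.g. ".../model.ckpt" or ".../model.ckpt-12345".
    """
    return [
        sys.executable, "object_detection/export_inference_graph.py",
        "--input_type", input_type,
        "--pipeline_config_path", pipeline_config,
        "--trained_checkpoint_prefix", checkpoint_prefix,
        "--output_directory", output_dir,
    ]


if __name__ == "__main__":
    cmd = build_export_command(
        # Hypothetical paths; point these at your own config and extracted tarball.
        pipeline_config="path/to/pipeline.config",
        checkpoint_prefix=os.path.join("ssd_mobilenet_v1_coco_11_06_2017",
                                       "model.ckpt"),
        output_dir="exported",
    )
    print(" ".join(cmd))
```

Passing the list to subprocess.check_call(cmd), from the directory that contains object_detection/, avoids shell quoting and line-continuation problems entirely.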

For details on what the script does, see:

https://cloud.google.com/blog/big-data/2017/09/performing-prediction-with-tensorflow-object-detection-models-on-google-cloud-machine-learning-engine
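With --input_type encoded_image_string_tensor, the exported graph takes the encoded image file contents (for example JPEG bytes) as a string input, so the client no longer decodes the image into a uint8 array itself. When such bytes go through a JSON API such as Cloud ML Engine online prediction, binary values are conventionally wrapped as {"b64": <base64 string>}. The sketch below assumes that convention and a single unnamed input per instance; the exact request shape depends on the signature's input alias and the service, so treat it as illustrative. `make_prediction_request` and `load_image_bytes` are hypothetical helpers.

```python
import base64
import json


def make_prediction_request(encoded_images):
    """Return a JSON request body with one instance per encoded image.

    encoded_images: iterable of raw image file contents (bytes), e.g. JPEG data.
    Each is base64-encoded and wrapped as {"b64": ...}, the convention assumed
    here for binary inputs; other services may expect a different shape.
    """
    instances = [
        {"b64": base64.b64encode(data).decode("ascii")}
        for data in encoded_images
    ]
    return json.dumps({"instances": instances})


def load_image_bytes(path):
    """Read an image file without decoding it; the exported graph decodes it."""
    with open(path, "rb") as f:
        return f.read()


if __name__ == "__main__":
    import sys
    # Usage: python make_request.py image1.jpg [image2.jpg ...] > request.json
    print(make_prediction_request(load_image_bytes(p) for p in sys.argv[1:]))
```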
