2017-05-08 40 views
1

使用TensorFlow LinearClassifier在Python我已經訓練了TensorFlow LinearClassifier並保存它喜歡:如何在Java中

model = tf.contrib.learn.LinearClassifier(feature_columns=columns) 
model.fit(input_fn=train_input_fn, steps=100) 
model.export_savedmodel(export_dir, parsing_serving_input_fn) 

通過使用TensorFlow的Java API我可以用在Java中加載這個模型:

model = SavedModelBundle.load(export_dir, "serve"); 

看來我應該能夠運行使用的東西的圖形像

model.session().runner().feed(???, ???).fetch(???, ???).run() 

但是我應該從圖表中提取/提取哪些變量名稱/數據以提供其功能並獲取類別的概率?據我所知,Java文檔缺少這些信息。

回答

2

要饋入的節點的名稱取決於parsing_serving_input_fn的作用,特別是它們應該是由parsing_serving_input_fn返回的Tensor對象的名稱。要獲取的節點名稱取決於您預測的內容(如果使用來自Python的模型,則參數爲model.predict())。

也就是說,TensorFlow保存的模型格式確實包含模型的「簽名」(即可以提供或提取的所有Tensors的名稱)作爲可以提供提示的元數據。

從Python中可以加載保存的模型,並使用類似列出其簽名:

with tf.Session() as sess: 
    md = tf.saved_model.loader.load(sess, ['serve'], export_dir) 
    sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 
    print(sig) 

這將打印出類似這樣:

inputs { 
    key: "inputs" 
    value { 
    name: "input_example_tensor:0" 
    dtype: DT_STRING 
    tensor_shape { 
     dim { 
     size: -1 
     } 
    } 
    } 
} 
outputs { 
    key: "scores" 
    value { 
    name: "linear/binary_logistic_head/predictions/probabilities:0" 
    dtype: DT_FLOAT 
    tensor_shape { 
     dim { 
     size: -1 
     } 
     dim { 
     size: 2 
     } 
    } 
    } 
} 
method_name: "tensorflow/serving/classify" 

暗示你想用Java做什麼是:

Tensor t = /* Tensor object to be fed */ 
model.session().runner().feed("input_example_tensor", t).fetch("linear/binary_logistic_head/predictions/probabilities").run() 

您還可以純粹在Java中提取此信息如果y我們的計劃包括TensorFlow協議緩衝區生成的Java代碼(封裝在org.tensorflow:proto artifact)使用這樣的事情:

// Same as tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 
// in Python. Perhaps this should be an exported constant in TensorFlow's Java API. 
final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default"; 

final SignatureDef sig = 
     MetaGraphDef.parseFrom(model.metaGraphDef()) 
      .getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY); 

你將不得不補充:

import org.tensorflow.framework.MetaGraphDef; 
import org.tensorflow.framework.SignatureDef; 

因爲Java API和saved-模型格式有點新,在文檔中有很大的改進空間。

希望有所幫助。

+0

感謝您的回答!這看起來有希望。但是,我必須爲input_example_tensor提供什麼?例如,考慮[TensorFlow Iris分類教程](https://www.tensorflow.org/get_started/tflearn):導出該模型會得到與您提供的相同的簽名(輸入,dtype:DT_STRING),但我需要以某種方式喂這個模型4個數字。 –

+0

現在我明白了,模型需要一個序列化的示例協議緩衝區,但是此時(1)協議緩衝區在Java中不可用,(2)使用DataType字符串創建Tensors(這是序列化示例所需的)是尚未支持。 :( –

+0

僅供參考:在[org.tensorflow:proto](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/proto)中提供了Java中的協議緩衝區maven artifact([javadoc]( http://javadoc.io/doc/org.tensorflow/proto/)) DataType.STRING張量支持標量(即,一個字符串),但不是多維數組呢(https://github.com/tensorflow/tensorflow/issues/8531) 希望有幫助。 – ash