使用TensorFlow LinearClassifier在Python我已經訓練了TensorFlow LinearClassifier並保存它喜歡:如何在Java中
model = tf.contrib.learn.LinearClassifier(feature_columns=columns)
model.fit(input_fn=train_input_fn, steps=100)
model.export_savedmodel(export_dir, parsing_serving_input_fn)
通過使用TensorFlow的Java API我可以用在Java中加載這個模型:
model = SavedModelBundle.load(export_dir, "serve");
看來我應該能夠運行使用的東西的圖形像
model.session().runner().feed(???, ???).fetch(???, ???).run()
但是我應該從圖表中提取/提取哪些變量名稱/數據以提供其功能並獲取類別的概率?據我所知,Java文檔缺少這些信息。
感謝您的回答!這看起來有希望。但是,我必須爲input_example_tensor提供什麼?例如,考慮[TensorFlow Iris分類教程](https://www.tensorflow.org/get_started/tflearn):導出該模型會得到與您提供的相同的簽名(輸入,dtype:DT_STRING),但我需要以某種方式喂這個模型4個數字。 –
現在我明白了,模型需要一個序列化的示例協議緩衝區,但是此時(1)協議緩衝區在Java中不可用,(2)使用DataType字符串創建Tensors(這是序列化示例所需的)是尚未支持。 :( –
僅供參考:在[org.tensorflow:proto](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/proto)中提供了Java中的協議緩衝區maven artifact([javadoc]( http://javadoc.io/doc/org.tensorflow/proto/)) DataType.STRING張量支持標量(即,一個字符串),但不是多維數組呢(https://github.com/tensorflow/tensorflow/issues/8531) 希望有幫助。 – ash