2017-04-20 60 views
3

我一直在嘗試導入和利用我的訓練模型(Tensorflow,Python)在Java中。Tensorflow模型導入到Java

我能夠用Python保存模型,但是當我嘗試使用Java中的相同模型進行預測時遇到了問題。

Here,你可以看到用於初始化,訓練和保存模型的python代碼。

Here,您可以看到用於導入和預測輸入值的Java代碼。

該錯誤消息我得到的是: Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7 [[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]] at org.tensorflow.Session.run(Native Method) at org.tensorflow.Session.access$100(Session.java:48) at org.tensorflow.Session$Runner.runHelper(Session.java:285) at org.tensorflow.Session$Runner.run(Session.java:235) at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)

我相信,這個問題是在Python代碼的某個地方,但我沒能找到它。

任何幫助表示讚賞!

謝謝

彼得

+0

我用[本](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java)作爲靈感 – szi

回答

5

Java importGraphDef()函數只導入計算g raph(在你的Python代碼中由tf.train.write_graph寫),它不會加載已訓練變量的值(存儲在檢查點中),這就是爲什麼你會抱怨未初始化變量的原因。

TensorFlow SavedModel format另一方面包含有關模型(圖形,檢查點狀態,其他元數據)的所有信息,並且您希望使用SavedModelBundle.load創建使用受過訓練的變量值初始化的會話。

要導出模型,用Python這種格式,你可能想看看一個相關的問題Deploy retrained inception SavedModel to google cloud ml engine

在你的情況,這應該達到像在Python以下幾點:

def save_model(session, input_tensor, output_tensor): 
    signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)}, 
    outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)}, 
) 
    b = saved_model_builder.SavedModelBuilder('/tmp/model') 
    b.add_meta_graph_and_variables(session, 
           [tf.saved_model.tag_constants.SERVING], 
           signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}) 
    b.save() 

並調用它通過save_model(session, x, yhat)

然後在Java中使用加載模型:

try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) { 
    // b.session().run(...) 
} 

希望有幫助。

+0

警告:雖然這將在Java中工作,但TF目前不支持在Android中加載SavedModel。發現困難的方式。 :/ – Keilaron

+0

相反,請參閱問題#12750或#13079,或者參閱https://www.tensorflow.org/mobile/prepare_models – Keilaron

1

你的Python模型在這肯定會失敗:

sess.run(init) #<---this will fail 
save_model(sess) 
error = tf.reduce_mean(tf.square(prediction - y)) 

#accuracy = tf.reduce_mean(tf.cast(error, 'float')) 
print('Error:', error) 

init未在模型中定義的 - 我不能確定你想要達到什麼這個地方,但這應該給你一個出發點

1

Fwiw,Deeplearning4j可讓您導入使用Keras 1.0(Keras 2.0支持即將推出)在TensorFlow上訓練過的模型。

https://deeplearning4j.org/model-import-keras

我們還建立了一個名爲跳動的圖書館,這大約是numpy的陣列和Pyjnius的包裝,使用指針,而不是複製數據,與張量打交道時,這使得它比Py4j更有效。

https://deeplearning4j.org/jumpy

+0

謝謝! 我是關於在Tensorflow中構建一個GAN的,就我所知,DL4J並不支持GAN,因此我一直在尋找一種解決方法,可以使用Gens訓練過的Tensorflow並在JVM上運行。 – szi