2017-07-04 101 views
0

當Java和Python tensorflow版本均爲1.2.0時,似乎我們可以使用SavedModelBundle(Java)和Saved Model API(Python)爲了在Python tensorflow中保存一個訓練好的模型並在Java tensorflow中加載模型(不適用於Maven)。在Java tensorflow v.1.2.0中使用Python tensorflow v.0.9.0加載預訓練模型

但是,當Python的版本低於1.0時,我無法找到正確加載Java模型的方法。

我訓練了一個模型,並將其保存爲Python tensorflow(0.9.0)中的.pb,.sd和.txt文件,並在tensorflow網站中按照example指令加載模型。但是,我得到了以下錯誤:

Exception in thread "main" java.lang.IllegalStateException: Attempting 
to use uninitialized value policy/mean_network/hidden_1/b 
      [[Node: _retval_policy/mean_network/hidden_1/b_0_0 = 
_Retval[T=DT_FLOAT, index=0, 
_device="/job:localhost/replica:0/task:0/cpu:0"] 
(policy/mean_network/hidden_1/b)]] 
      at org.tensorflow.Session.run(Native Method) 
      at org.tensorflow.Session.access$100(Session.java:48) 
      at org.tensorflow.Session$Runner.runHelper(Session.java:285) 
      at org.tensorflow.Session$Runner.run(Session.java:235) 
      at Carpole.executeGraph(Carpole.java:42) 
      at Carpole.main(Carpole.java:30) 

有誰知道如何正確地加載Java中的預訓練模式在最新的版本(因爲我無法找到以前版本的API了),而無需使用保存的模型API?

在此先感謝!

這是救了我的Python代碼:

with tf.Session() as sess: 
    self.saver = tf.train.Saver(tf.all_variables()) 
    sess.run(tf.initialize_all_variables()) 
    ….. 
    saver_def = self.saver.as_saver_def() 
    print(saver_def.filename_tensor_name) 
    print(saver_def.restore_op_name) 

    self.saver.save(sess, 'trained_model'+str(itr)+'.sd') 
    tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.pb', as_text=False) 
    tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.txt', as_text=True) 

這裏是我的Java代碼

public static void main(String[] args) throws Exception { 
    String dataDirPath = args[0]; 
    byte[] graphDef = readAllBytesOrExit(Paths.get(dataDirPath, "trained_model10.pb")); 
    List<String> labels = readAllLinesOrExit(Paths.get(dataDirPath, "trained_model10.txt")); 
    float[] vector = new float[4]; 
    vector[0] = (float) -0.09341373; 
    vector[1] = (float) -0.07540844; 
    vector[2] = (float) 0.00930138; 
    vector[3] = (float) -0.14317159; 
    Tensor input = Tensor.create(vector); 

    float[] labelProbabilities = executeGraph(graphDef, input); 
    int bestLabelIdx = maxIndex(labelProbabilities); 
    System.out.println(String.format("BEST MATCH: %s (%.2f%% likely)",labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f)); 
} 

private static float[] executeGraph(byte[] graphDef, Tensor input_tensor) { 
    try (Graph g = new Graph()) { 
     g.importGraphDef(graphDef); 
     System.out.println(g); 
     try (Session s = new Session(g); Tensor result = s.runner().feed("policy/mean_network/input/input",input_tensor).fetch("policy/mean_network/hidden_1/b").run().get(0)) { 
     final long[] rshape = result.shape(); 
     } 
     int nlabels = (int) rshape[1]; 
     return result.copyTo(new float[1][nlabels])[0]; 
    } 
} 

回答

0

一般來說,TensorFlow之前的版本1.0,不能保證與TensorFlow版本一起使用> = 1.0(根據TensorFlow Version Semantics和使用語義版本)。這就是說,看着你提供的代碼片斷,你正在用Java加載計算圖表,但沒有加載保存的變量,因此你會得到一個異常,抱怨變量沒有被初始化。

將圖形和保存的變量封裝到一個包中是SavedModel格式的用途。然而,如果你不能使用它,並且如果你只需要該圖形用於Java中的推理,那麼你可能想考慮「凍結」圖形,然後用Java加載。凍結圖將包含單個文件中的所有變量值。

您可以嘗試使用freeze_graph庫(來自0.9發行版分支)來保存這樣的圖。

使用TensorFlow的1.0之前的版本可能是一個挑戰。如果可能,我強烈建議您將模型移動到TensorFlow版本> = 1.0,然後利用API穩定性保證。

希望有所幫助。

+0

謝謝!但是,不知何故,我的tensorflow沒有tensorflow/python/tools模塊(我正在使用anaconda2)。我最終升級到了v.1.1.0,稍微修改了我的代碼並使用了Saved Model API。 –