當Java和Python tensorflow版本均爲1.2.0時,似乎我們可以使用SavedModelBundle(Java)和Saved Model API(Python)爲了在Python tensorflow中保存一個訓練好的模型並在Java tensorflow中加載模型(不適用於Maven)。在Java tensorflow v.1.2.0中使用Python tensorflow v.0.9.0加載預訓練模型
但是,當Python的版本低於1.0時,我無法找到正確加載Java模型的方法。
我訓練了一個模型,並將其保存爲Python tensorflow(0.9.0)中的.pb,.sd和.txt文件,並在tensorflow網站中按照example指令加載模型。但是,我得到了以下錯誤:
Exception in thread "main" java.lang.IllegalStateException: Attempting
to use uninitialized value policy/mean_network/hidden_1/b
[[Node: _retval_policy/mean_network/hidden_1/b_0_0 =
_Retval[T=DT_FLOAT, index=0,
_device="/job:localhost/replica:0/task:0/cpu:0"]
(policy/mean_network/hidden_1/b)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at Carpole.executeGraph(Carpole.java:42)
at Carpole.main(Carpole.java:30)
有誰知道如何正確地加載Java中的預訓練模式在最新的版本(因爲我無法找到以前版本的API了),而無需使用保存的模型API?
在此先感謝!
這是救了我的Python代碼:
with tf.Session() as sess:
self.saver = tf.train.Saver(tf.all_variables())
sess.run(tf.initialize_all_variables())
…..
saver_def = self.saver.as_saver_def()
print(saver_def.filename_tensor_name)
print(saver_def.restore_op_name)
self.saver.save(sess, 'trained_model'+str(itr)+'.sd')
tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.pb', as_text=False)
tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.txt', as_text=True)
這裏是我的Java代碼
public static void main(String[] args) throws Exception {
String dataDirPath = args[0];
byte[] graphDef = readAllBytesOrExit(Paths.get(dataDirPath, "trained_model10.pb"));
List<String> labels = readAllLinesOrExit(Paths.get(dataDirPath, "trained_model10.txt"));
float[] vector = new float[4];
vector[0] = (float) -0.09341373;
vector[1] = (float) -0.07540844;
vector[2] = (float) 0.00930138;
vector[3] = (float) -0.14317159;
Tensor input = Tensor.create(vector);
float[] labelProbabilities = executeGraph(graphDef, input);
int bestLabelIdx = maxIndex(labelProbabilities);
System.out.println(String.format("BEST MATCH: %s (%.2f%% likely)",labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
}
private static float[] executeGraph(byte[] graphDef, Tensor input_tensor) {
try (Graph g = new Graph()) {
g.importGraphDef(graphDef);
System.out.println(g);
try (Session s = new Session(g); Tensor result = s.runner().feed("policy/mean_network/input/input",input_tensor).fetch("policy/mean_network/hidden_1/b").run().get(0)) {
final long[] rshape = result.shape();
}
int nlabels = (int) rshape[1];
return result.copyTo(new float[1][nlabels])[0];
}
}
謝謝!但是,不知何故,我的tensorflow沒有tensorflow/python/tools模塊(我正在使用anaconda2)。我最終升級到了v.1.1.0,稍微修改了我的代碼並使用了Saved Model API。 –