I want to load a pretrained model (trained in Python) into a Java project, but the TensorFlow Java API fails with java.lang.IllegalStateException: Tensor is not a scalar.
The error is:
Exception in thread "Thread-9" java.lang.IllegalStateException: Tensor is not a scalar
at org.tensorflow.Tensor.scalarFloat(Native Method)
at org.tensorflow.Tensor.floatValue(Tensor.java:279)
The Java code:
float[] arr = context.csvintarr(context.getPlayer(playerId));
float[][] martix = {arr};
try (Graph g = model.graph()) {
    try (Session s = model.session()) {
        Tensor y = s.runner().feed("input/input", Tensor.create(martix))
                .fetch("out/predict").run().get(0);
        logger.info("a {}", y.floatValue());
    }
}
The Python code that trains and saves the model:
with tf.Session() as sess:
    with tf.name_scope('input'):
        x=tf.placeholder(tf.float32,[None,bucketlen],name="input")
    ......
    with tf.name_scope('out'):
        y=tf.tanh(tf.matmul(h,hW)+hb,name="predict")
    builder=tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess,['foo-tag'])
    ......after the train process
    builder.save()
It seems I have loaded the model and graph successfully, because
try (Graph g = model.graph()) {
    try (Session s = model.session()) {
        Operation operation = g.operation("input/input");
        logger.info(operation.name());
    }
}
prints the name successfully.
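One possible reading of the error, as a rough sketch rather than a confirmed diagnosis: the placeholder is declared as [None, bucketlen], and tf.matmul always produces a result of rank 2 or higher, so the tensor fetched from "out/predict" carries a batch dimension even when only one row is fed. floatValue() only works on rank-0 tensors, which would explain the exception. The plain-Java snippet below illustrates that shape reasoning without depending on TensorFlow. The class TensorShapeSketch, the helpers isScalar and toMatrix, and the sample shape and value are all hypothetical. In the TensorFlow 1.x Java API, the shape would presumably come from y.shape() and the values from y.copyTo(new float[rows][cols]).

```java
import java.util.Arrays;

public class TensorShapeSketch {

    // A scalar accessor such as floatValue() can only succeed on a rank-0 shape.
    public static boolean isScalar(long[] shape) {
        return shape.length == 0;
    }

    // Rebuilds a [rows, cols] matrix from row-major flat data (hypothetical helper).
    public static float[][] toMatrix(float[] flat, long[] shape) {
        if (shape.length != 2) {
            throw new IllegalArgumentException("expected rank 2, got rank " + shape.length);
        }
        int rows = (int) shape[0];
        int cols = (int) shape[1];
        if (rows * cols != flat.length) {
            throw new IllegalArgumentException("shape does not match data length");
        }
        float[][] out = new float[rows][cols];
        for (int r = 0; r < rows; r++) {
            System.arraycopy(flat, r * cols, out[r], 0, cols);
        }
        return out;
    }

    public static void main(String[] args) {
        // Assumed values: one input row fed, one output unit. The real second
        // dimension depends on hW, which the Python snippet elides.
        long[] shape = {1, 1};
        float[] flat = {0.42f};

        System.out.println("scalar? " + isScalar(shape));
        float[][] result = toMatrix(flat, shape);
        // The single prediction lives at [0][0], not in a scalar.
        System.out.println(Arrays.deepToString(result));
    }
}
```

If that reading is right, replacing y.floatValue() with a copy into a float[1][n] array and logging element [0][0] would be the thing to try, where n is whatever y.shape() reports for the second dimension.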