2017-08-20 55 views
0

我想加載一個pretrained模型(使用python)到一個java項目中。tensorflow java api錯誤:java.lang.IllegalStateException:張量不是標量

問題是

Exception in thread "Thread-9" java.lang.IllegalStateException: Tensor is not a scalar 
    at org.tensorflow.Tensor.scalarFloat(Native Method) 
    at org.tensorflow.Tensor.floatValue(Tensor.java:279) 

代碼

float[] arr=context.csvintarr(context.getPlayer(playerId)); 
    float[][] martix={arr}; 
    try (Graph g=model.graph()){ 
     try(Session s=model.session()){ 

      Tensor y=s.runner().feed("input/input", Tensor.create(martix)) 
      .fetch("out/predict").run().get(0); 
      logger.info("a {}",y.floatValue()); 
     } 
    } 

的Python代碼培養和保存模型

with tf.Session() as sess: 
    with tf.name_scope('input'): 
     x=tf.placeholder(tf.float32,[None,bucketlen],name="input") 
...... 
    with tf.name_scope('out'): 
     y=tf.tanh(tf.matmul(h,hW)+hb,name="predict") 
    builder=tf.saved_model.builder.SavedModelBuilder(export_dir) 
    builder.add_meta_graph_and_variables(sess,['foo-tag']) 
......after the train process 

    builder.save() 

看來,我已成功加載模型和圖形,因爲

try (Graph g=model.graph()){ 
     try(Session s=model.session()){ 
      Operation operation=g.operation("input/input"); 
      logger.info(operation.name()); 

     } 
    } 

成功打印出名稱。

回答

1

錯誤消息表明輸出張量不是浮點值標量,所以它可能是更高維張量(矢量,矩陣)。

您可以使用System.out.println(y.toString())或專門使用y.shape()來了解張量的形狀。在你的Python代碼中,這將對應於y.shape

對於非標量,使用y.copyTo獲得float數組(用於向量),或陣列浮標陣列(爲一個矩陣)等

例如,是這樣的:

System.out.println(y); 
// If the above printed something like: 
// "FLOAT tensor with shape [1]" 
// then you can get the values using: 
float[] vector = y.copyTo(new float[1]); 

// If the shape was something like [2, 3] 
// then you can get the values using: 
float[][] matrix = y.copyTo(new float[2][3]); 

查看Tensor javadoc瞭解更多關於floatValue() vs copyTo vs writeTo的資訊。

希望有所幫助。