2017-06-20 33 views
0

我試圖執行上,我使用「AlexNet細化和微調與TensorFlow」 https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html也不會在Java API運行Tensorflow預測

我保存在使用Python tf.saved_model.builder.SavedModelBuilder模型訓練的模型預測,並加載Java中的模型使用SavedModelBundle.load。 代碼的主要部分是:

SavedModelBundle smb = SavedModelBundle.load(path, "serve"); 
    Session s = smb.session(); 
    byte[] imageBytes = readAllBytesOrExit(Paths.get(path)); 
    Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes); 
    Tensor result = s.runner().feed("input_tensor", image).fetch("fc8/fc8").run().get(0); 
    final long[] rshape = result.shape(); 
    if (result.numDimensions() != 2 || rshape[0] != 1) { 
     throw new RuntimeException(
       String.format(
         "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", 
         Arrays.toString(rshape))); 
    } 
    int nlabels = (int) rshape[1]; 
    float [] a = result.copyTo(new float[1][nlabels])[0];` 

我得到這個異常:在線程

異常「主」 java.lang.IllegalArgumentException異常:您必須養活一個值佔位張量「 Placeholder_1'用dtype float [[Node:Placeholder_1 = Placeholder_output_shapes = [[]],dtype = DT_FLOAT,shape = [],_device =「/ job:localhost/replica:0/task:0/cpu:0」]]

我看到上面的代碼爲某些人工作,我無法弄清楚這裏缺少的東西。 請注意,該網絡熟悉節點「input_tensor」和「fc8/fc8」,因爲它沒有說它不知道它們。

回答

1

從錯誤消息中可以看出,您使用的模型期望得到另一個值(圖中的節點名稱爲Placeholder_1,預期類型爲浮點標量張量)。

看起來你已經定製了你的模型(而不是跟隨你鏈接到逐字的文章)。也就是說,文章顯示需要餵食的多個佔位符,一個用於圖像,另一個用於控制脫落。在文章中定義爲:

keep_prob = tf.placeholder(tf.float32) 

此佔位符的值需要提供。如果您正在進行推理,那麼您想將keep_prob設置爲1.0。類似於:

Tensor keep_prob = Tensor.create(1.0f); 
Tensor result = s.runner() 
    .feed("input_tensor", image) 
    .feed("Placeholder_1", keep_prob) 
    .fetch("fc8/fc8") 
    .run() 
    .get(0); 

希望有所幫助。