我試圖執行上,我使用「AlexNet細化和微調與TensorFlow」 https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html也不會在Java API運行Tensorflow預測
我保存在使用Python tf.saved_model.builder.SavedModelBuilder
模型訓練的模型預測,並加載Java中的模型使用SavedModelBundle.load
。 代碼的主要部分是:
SavedModelBundle smb = SavedModelBundle.load(path, "serve");
Session s = smb.session();
byte[] imageBytes = readAllBytesOrExit(Paths.get(path));
Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes);
Tensor result = s.runner().feed("input_tensor", image).fetch("fc8/fc8").run().get(0);
final long[] rshape = result.shape();
if (result.numDimensions() != 2 || rshape[0] != 1) {
throw new RuntimeException(
String.format(
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
Arrays.toString(rshape)));
}
int nlabels = (int) rshape[1];
float [] a = result.copyTo(new float[1][nlabels])[0];`
我得到這個異常:在線程
異常「主」 java.lang.IllegalArgumentException異常:您必須養活一個值佔位張量「 Placeholder_1'用dtype float [[Node:Placeholder_1 = Placeholder_output_shapes = [[]],dtype = DT_FLOAT,shape = [],_device =「/ job:localhost/replica:0/task:0/cpu:0」]]
我看到上面的代碼爲某些人工作,我無法弄清楚這裏缺少的東西。 請注意,該網絡熟悉節點「input_tensor」和「fc8/fc8」,因爲它沒有說它不知道它們。