我有一個基本Android TensorFlowInference
示例,它可以在單線程中正常運行。在多核設備上運行TensorFlow
public class InferenceExample {
private static final String MODEL_FILE = "file:///android_asset/model.pb";
private static final String INPUT_NODE = "intput_node0";
private static final String OUTPUT_NODE = "output_node0";
private static final int[] INPUT_SIZE = {1, 8000, 1};
public static final int CHUNK_SIZE = 8000;
public static final int STRIDE = 4;
private static final int NUM_OUTPUT_STATES = 5;
private static TensorFlowInferenceInterface inferenceInterface;
public InferenceExample(final Context context) {
inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
}
public float[] run(float[] data) {
float[] res = new float[CHUNK_SIZE/STRIDE * NUM_OUTPUT_STATES];
inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]);
inferenceInterface.run(new String[]{OUTPUT_NODE});
inferenceInterface.fetch(OUTPUT_NODE, res);
return res;
}
}
的例子崩潰,各種異常,包括java.lang.ArrayIndexOutOfBoundsException
和java.lang.NullPointerException
在ThreadPool
按照下面的例子,所以我想這不是線程安全運行時。
InferenceExample inference = new InferenceExample(context);
ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_CORES);
Collection<Future<?>> futures = new LinkedList<Future<?>>();
for (int i = 1; i <= 100; i++) {
Future<?> result = executor.submit(new Runnable() {
public void run() {
inference.call(randomData);
}
});
futures.add(result);
}
for (Future<?> future:futures) {
try { future.get(); }
catch(ExecutionException | InterruptedException e) {
Log.e("TF", e.getMessage());
}
}
是否有可能利用多核Android設備與TensorFlowInferenceInterface
?
我強烈建議不要這種方法。當然,你已經做到了這樣,可以同時調用'run',但只有當你不改變輸入時(通過調用'TensorFlowInferenceInterface.feed()')纔有意義。 假設你想要你的線程提供不同的輸入,以便計算可以在它們上面運行。你提出的方法對此並不安全。 – ash
爲什麼對於不同的輸入不安全?通過按照'id'順序將期貨存儲在循環中的細微變化,我將知道哪個輸入與哪個輸出匹配。 –
噢,對不起,我誤讀了,並沒有注意到'feed()'和'fetch()'調用在你的同步'run()'內。所以我在上面的評論中誤會了。 但是,您的方法會限制並行性,因爲這實際上是串行化使用TensorFlow會話 - 一次只能有一個線程執行模型。 – ash