2016-11-08 51 views
0

我有這樣的代碼:爲什麼這個JNI程序不能複製float值回Java端?

#if defined(NOT_STANDALONE) 
JNIEXPORT void JNICALL sumTraces 
    (JNIEnv* env, jclass caller, jobjectArray jprestackTraces, jint nTracesIn, jobjectArray jsampleShifts, 
    jobjectArray jstartIndices, jobjectArray jnSamples, jobjectArray jstackTracesOut, 
    jobjectArray jpowerTracesOut, jint nTracesOut, jint samplesPerTrace) { 

    jboolean isCopy; 

    float* prestackTraces1D = (float*)malloc(nTracesIn * samplesPerTrace * sizeof(float)); 
    if (prestackTraces1D == NULL) Fatal("Could not malloc prestackTraces1D"); 
    int* sampleShifts1D = (int*)malloc(nTracesIn * nTracesOut * sizeof(int)); 
    if (sampleShifts1D == NULL) Fatal("Could not malloc sampleShifts1D"); 
    int* startIndices1D = (int*)malloc(nTracesIn * nTracesOut * sizeof(int)); 
    if (startIndices1D == NULL) Fatal("Could not malloc startIndices1D"); 
    int* nSamples1D = (int*)malloc(nTracesIn * nTracesOut * sizeof(int)); 
    if (nSamples1D == NULL) Fatal("Could not malloc nSamples1D"); 

    for (int in = 0; in < nTracesIn; in++) { 

    jfloatArray j_prestack = (jfloatArray)env->GetObjectArrayElement(jprestackTraces, in); 
    float* prestackTracesJava = (float*)env->GetPrimitiveArrayCritical(j_prestack, &isCopy); 

    for (int s = 0; s < samplesPerTrace; s++) { 
     int readIndex = s + (in * samplesPerTrace); 
     prestackTraces1D[readIndex] = prestackTracesJava[s]; 
    } 

    env->ReleasePrimitiveArrayCritical(j_prestack, prestackTracesJava, JNI_ABORT); 
    } 

    for (int out = 0; out < nTracesOut; out++) { 

    jintArray j_shift = (jintArray)env->GetObjectArrayElement(jsampleShifts, out); 
    int* sampleShiftsJava = (int*)env->GetPrimitiveArrayCritical(j_shift, &isCopy); 
    jintArray j_start = (jintArray)env->GetObjectArrayElement(jstartIndices, out); 
    int* startIndicesJava = (int*)env->GetPrimitiveArrayCritical(j_start, &isCopy); 
    jintArray j_nSamps = (jintArray)env->GetObjectArrayElement(jnSamples, out); 
    int* nSamplesJava = (int*)env->GetPrimitiveArrayCritical(j_nSamps, &isCopy); 

    for (int in = 0; in < nTracesIn; in++) { 
     int readIndex = in + (out * nTracesIn); 
     sampleShifts1D[readIndex] = sampleShiftsJava[in]; 
     startIndices1D[readIndex] = startIndicesJava[in]; 
     nSamples1D[readIndex] = nSamplesJava[in]; 
    } 

    env->ReleasePrimitiveArrayCritical(j_nSamps, nSamplesJava, JNI_ABORT); 
    env->ReleasePrimitiveArrayCritical(j_start, startIndicesJava, JNI_ABORT); 
    env->ReleasePrimitiveArrayCritical(j_shift, sampleShiftsJava, JNI_ABORT); 
    } 

    float* stackTracesOut1D = (float*)malloc(nTracesOut * samplesPerTrace * sizeof(float)); 
    if (stackTracesOut1D == NULL) Fatal("Could not malloc stackTracesOut1D"); 
    float* powerTracesOut1D = (float*)malloc(nTracesOut * samplesPerTrace * sizeof(float)); 
    if (powerTracesOut1D == NULL) Fatal("Could not malloc powerTracesOut1D"); 

    // Run the OpenCL program 
    ComputeTraces(prestackTraces1D, stackTracesOut1D, powerTracesOut1D, 
    startIndices1D, nSamples1D, sampleShifts1D, 
    samplesPerTrace, nTracesIn, nTracesOut, 
    0, 0, 1000); 

    // Free the arrays that we can 
    free(nSamples1D); 
    free(startIndices1D); 
    free(sampleShifts1D); 
    free(prestackTraces1D); 

    // Copy back the output for Java 
    for (int out = 0; out < nTracesOut; out++) { 
    jfloatArray j_stackOut = (jfloatArray)env->GetObjectArrayElement(jstackTracesOut, out); 
    jfloatArray j_powerOut = (jfloatArray)env->GetObjectArrayElement(jpowerTracesOut, out); 

    float* stackOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float)); 
    float* powerOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float)); 
    for (int s = 0; s < samplesPerTrace; s++) { 
     int readIndex = s + (out * samplesPerTrace); 
     stackOutCopyBack[s] = stackTracesOut1D[readIndex]; 
     powerOutCopyBack[s] = powerTracesOut1D[readIndex]; 
    } 

    for (int s = 0; s < samplesPerTrace; s++) { 
     printf("%d %f/%f\n", s, stackOutCopyBack[s], powerOutCopyBack[s]); 
    } 

    env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0); 
    env->ReleasePrimitiveArrayCritical(j_powerOut, powerOutCopyBack, 0); 

    free(stackOutCopyBack); 
    free(powerOutCopyBack); 
    } 

    // Free the output arrays 
    free(powerTracesOut1D); 
    free(stackTracesOut1D); 
} 

的ComputeTraces(...)方法填充stackTracesOut1D和powerTracesOut1D陣列與值。我知道這些值是正確的,因爲在接近尾聲的for循環中有printf語句,我將它與我想要的值和它們匹配的值進行比較。但是,當我檢查Java端時,所有的值都被清零。爲什麼這個JNI代碼不能複製數據?

請記住,正如您在代碼中所看到的那樣,我必須將通過參數給出的2D數組濃縮到一維數組中,才能傳入函數。因此,在複製數據之前,我將更大的1D數組的一部分複製到較小的數組中,這是ReleasePrimitiveArrayCritical中的參數之一,但是這些值不會被複制回來。

編輯:只是要說清楚,我正在談論從最後10行起的線; env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0);其中我使用的是0.

+1

我沒有看到你在哪裏複製stackTracesOut1D或powerTracesOut1D到一個Java對象。你認爲它應該在哪裏複製? –

+0

在最後一個嵌入for循環內部(除了print語句),我在其中複製了一部分數據的行'stackOutCopyBack [s] = stackTracesOut1D [readIndex];''和'powerOutCopyBack [s] = powerTracesOut1D [readIndex]將這些數組添加到新創建的更小的數組中,該數組僅用於複製。 – danglingPointer

+0

但是stackOutCopyBack和powerTracesOut1D沒有通過GetPrimitiveArrayCritical獲得。當指針沒有被獲得時,JNI文檔不清楚這裏發生了什麼。事實上,mode = 0似乎是錯誤的,因爲ReleasePrimitiveArrayCritical可能釋放指針。有沒有理由不重寫它使用GetPrimitiveArrayCritical? –

回答

1

所以問題只是我忘了在輸出數組上使用GetPrimitiveArrayCritical(...)。所以:

for (int out = 0; out < nTracesOut; out++) { 
    jfloatArray j_stackOut = (jfloatArray)env->GetObjectArrayElement(jstackTracesOut, out); 
    jfloatArray j_powerOut = (jfloatArray)env->GetObjectArrayElement(jpowerTracesOut, out); 

    float* stackOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float)); 
    float* powerOutCopyBack = (float*)malloc(samplesPerTrace * sizeof(float)); 

    for (int s = 0; s < samplesPerTrace; s++) { 
     int readIndex = s + (out * samplesPerTrace); 
     stackOutCopyBack[s] = stackTracesOut1D[readIndex]; 
     powerOutCopyBack[s] = powerTracesOut1D[readIndex]; 
    } 

    env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0); 
    env->ReleasePrimitiveArrayCritical(j_powerOut, powerOutCopyBack, 0); 

    free(stackOutCopyBack); 
    free(powerOutCopyBack); 
    } 

變爲:

for (int out = 0; out < nTracesOut; out++) { 
    jfloatArray j_stackOut = (jfloatArray)env->GetObjectArrayElement(jstackTracesOut, out); 
    jfloatArray j_powerOut = (jfloatArray)env->GetObjectArrayElement(jpowerTracesOut, out); 

    float* stackOutCopyBack = (float*)env->GetPrimitiveArrayCritical(j_stackOut, &isCopy); 
    float* powerOutCopyBack = (float*)env->GetPrimitiveArrayCritical(j_powerOut, &isCopy); 

    for (int s = 0; s < samplesPerTrace; s++) { 
     int readIndex = s + (out * samplesPerTrace); 
     stackOutCopyBack[s] = stackTracesOut1D[readIndex]; 
     powerOutCopyBack[s] = powerTracesOut1D[readIndex]; 
    } 

    env->ReleasePrimitiveArrayCritical(j_stackOut, stackOutCopyBack, 0); 
    env->ReleasePrimitiveArrayCritical(j_powerOut, powerOutCopyBack, 0); 
    } 

同樣重要的是,免費的,否則我們試圖從內存中刪除陣列兩次去除。

-1

由於您指定了JNI_ABORT,所以不會複製回來。見JNI Specification

  • 0複製回內容並釋放elems緩衝區
  • JNI_COMMIT複製回內容但不釋放elems緩衝區
  • JNI_ABORT釋放緩衝區但不復制回變化。

按設計工作。

+0

我對int數組(和浮點數組之一)有JNI_ABORT,因爲我不需要將它們複製回代碼開始附近,但是我正在討論最後一行返回的行11:'env - > ReleasePrimitiveArrayCritical(j_stackOut,stackOutCopyBack,0);'我在哪裏使用零。 – danglingPointer

相關問題