2012-07-04 11 views
1

我一直在嘗試在矩陣乘法中使用jcuda中的cublasSgemmBatched()函數,我不確定如何正確處理批處理矩陣的指針傳遞和向量。如果有人知道如何修改我的代碼以正確處理這個問題,我會非常感激。在這個例子中,C數組在cublasGetVector後保持不變。cublasSgemm與jcuda的配合使用

public static void SsmmBatchJCublas(int m, int n, int k, float A[], float B[]){ 

    // Create a CUBLAS handle 
    cublasHandle handle = new cublasHandle(); 
    cublasCreate(handle); 

    // Allocate memory on the device 
    Pointer d_A = new Pointer(); 
    Pointer d_B = new Pointer(); 
    Pointer d_C = new Pointer(); 


    cudaMalloc(d_A, m*k * Sizeof.FLOAT); 
    cudaMalloc(d_B, n*k * Sizeof.FLOAT); 
    cudaMalloc(d_C, m*n * Sizeof.FLOAT); 

    float[] C = new float[m*n]; 
    // Copy the memory from the host to the device 
    cublasSetVector(m*k, Sizeof.FLOAT, Pointer.to(A), 1, d_A, 1); 
    cublasSetVector(n*k, Sizeof.FLOAT, Pointer.to(B), 1, d_B, 1); 
    cublasSetVector(m*n, Sizeof.FLOAT, Pointer.to(C), 1, d_C, 1); 

    Pointer[] Aarray = new Pointer[]{d_A}; 
    Pointer AarrayPtr = Pointer.to(Aarray); 
    Pointer[] Barray = new Pointer[]{d_B}; 
    Pointer BarrayPtr = Pointer.to(Barray); 
    Pointer[] Carray = new Pointer[]{d_C}; 
    Pointer CarrayPtr = Pointer.to(Carray); 

    // Execute sgemm 
    Pointer pAlpha = Pointer.to(new float[]{1}); 
    Pointer pBeta = Pointer.to(new float[]{0}); 


    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, pAlpha, AarrayPtr, Aarray.length, BarrayPtr, Barray.length, pBeta, CarrayPtr, Carray.length, Aarray.length); 
    // Copy the result from the device to the host 
    cublasGetVector(m*n, Sizeof.FLOAT, d_C, 1, Pointer.to(C), 1); 

    // Clean up 
    cudaFree(d_A); 
    cudaFree(d_B); 
    cudaFree(d_C); 
    cublasDestroy(handle); 
} 

回答

1

我問官方jcuda論壇,並很快收到了答案here

+0

請編輯此答案以包含解決方案。 – ThiefMaster