2015-08-13 23 views
1

我有一個CUDA內核，它接收一個結構列表（結構陣列）作為參數。問題是：如何在JCuda中建立包含原生指針的結構？

// Pseudo-code: launched with blockCount blocks of blockSize threads; the kernel
// takes a pointer to an array of MyStruct.
kernel<<<blockCount,blockSize>>>(MyStruct *structs); 

每個結構包含3個指針。

// Three device pointers per struct. __align__(16) is CUDA's portable alignment
// macro; it pads the struct to a multiple of 16 bytes (16 bytes with 32-bit
// pointers, 32 bytes with 64-bit pointers). The typedef is named so that the
// kernel signature can refer to it as MyStruct.
typedef struct __align__(16) { 
    float* pointer1; 
    float* pointer2; 
    float* pointer3; 
} MyStruct;

我有3個float設備陣列，每個結構中的每個指針都分別指向這三個設備陣列之一內的一段float元素。

這個結構列表表示一個樹/圖結構，內核會依照傳入的結構列表順序對它執行遞迴操作。（這部分在C++中已能正常運作，與我的問題無關。）

我想做的是從JCuda傳送包含指針的結構。我知道JCuda本身並不直接支援這一點，除非像this post所說的那樣，把結構展平成帶填充的陣列。

傳送結構列表時可能遇到的各種對齊與填充問題，我都了解。它本質上就是一個不斷重複的帶填充陣列，這點我可以接受。

我不知道的是，該如何把指針填入展平後的結構緩衝區。例如，我原本以為可以這樣做：

Pointer A = ....(underlying device array1) 
Pointer B = ....(underlying device array2) 
Pointer C = ....(underlying device array3) 

// Each struct is three 64-bit pointers plus one 64-bit padding slot,
// i.e. 4 * 8 = 32 bytes per struct.
ByteBuffer structListBuffer = ByteBuffer.allocate(32*noSteps); 
// A ByteBuffer starts out BIG_ENDIAN, so switch to the native byte order.
// Create ONE view and write through it: every asLongBuffer() call returns a
// new view whose position starts at 0, and LongBuffer has put(), not append().
LongBuffer structList = structListBuffer.order(ByteOrder.nativeOrder()).asLongBuffer(); 
for(int x = 0; x<noSteps; x++) { 
    // Get the underlying pointer values 
    long pointer1 = A.withByteOffset(getStepOffsetA(x)).someGetUnderlyingPointerValueFunction(); 
    long pointer2 = B.withByteOffset(getStepOffsetB(x)).someGetUnderlyingPointerValueFunction(); 
    long pointer3 = C.withByteOffset(getStepOffsetC(x)).someGetUnderlyingPointerValueFunction(); 

    // Build the struct 
    structList.put(pointer1); 
    structList.put(pointer2); 
    structList.put(pointer3); 
    structList.put(0); //padding 
} 

這樣，structListBuffer就會包含按內核預期方式排列的結構列表。

那麼，有沒有辦法實現someGetUnderlyingPointerValueFunction()？也就是從Pointer取得底層的原生指針值，以便寫入ByteBuffer。

回答

2

如果我理解無誤，這個問題的重點在於是否存在這樣一個神奇的函數，例如：

// Hypothetical API: JCuda offers no such method.
long address = pointer.someGetUnderlyingPointerValueFunction(); 

它會返回原生指針的位址。

簡短回答：不，沒有這樣的函數。

註：之前已經有人提議過類似的功能，但我一直沒有加入。主要原因是，這樣的函數對於指向Java陣列或非直接（non-direct）緩衝區的指針沒有意義。此外，在32位與64位機器上手動處理帶有填充、對齊以及大小不一的指針和緩衝區的結構，本身就會帶來無窮的麻煩。

不過我理解這個應用場景，所以很可能會加入類似getAddress()的函數，也許只加到CUdeviceptr類別中。在那裡它肯定有意義，至少比放在Pointer類別中更有意義。當然，有人可能會用它做出讓虛擬機嚴重崩潰的事，但JCuda本身就是一個非常薄的抽象層，在這方面本來就沒有任何安全網。


話雖如此，你可以用下面這樣的方法繞過目前的限制：

/**
 * Returns the raw address that the given device pointer refers to, i.e. its
 * native pointer value plus its byte offset.
 *
 * Workaround until something like CUdeviceptr#getAddress is available: an
 * anonymous subclass of Pointer may call the protected accessors.
 */
private static long getPointerAddress(CUdeviceptr p) 
{ 
    return new Pointer(p)
    {
        long address()
        {
            long base = getNativePointer();
            long offset = getByteOffset();
            return base + offset;
        }
    }.address();
} 

當然，這很醜陋，而且明顯違背了將getNativePointer()與getByteOffset()方法設為protected的初衷。但它最終可能會被某個「官方」方法取代：

/**
 * Hypothetical "official" variant. It assumes a future CUdeviceptr#getAddress()
 * method, which does not exist yet in the JCuda version discussed here.
 */
private static long getPointerAddress(CUdeviceptr p) 
{ 
    long address = p.getAddress();
    return address;
} 

在那之前，這大概是最接近C端做法的解決方案了。


下面是我為測試而寫的範例。內核只是一個示範用的虛擬內核：它透過結構中的指針寫入可辨識的「標記」值（用來檢查它們是否落在正確的位置），並且應該只以1個線程啟動：

// Three device pointers per struct. __align__(16) is CUDA's portable alignment
// macro, so this compiles with MSVC as well as with GCC/Clang host compilers
// (unlike the MSVC-only __declspec(align(16))). The struct is padded to 16
// bytes with 32-bit pointers and to 32 bytes with 64-bit pointers, i.e. to
// four pointer-sized slots.
typedef struct __align__(16) { 
    float* pointer1; 
    float* pointer2; 
    float* pointer3; 
} MyStruct; 

// Stores v0, v1 and v2 into target[0], target[1] and target[2].
static __device__ void storeThree(float* target, float v0, float v1, float v2)
{
    target[0] = v0;
    target[1] = v1;
    target[2] = v2;
}

// Test kernel: writes recognizable values through the three pointers of the
// first two structs. Every thread performs the same writes, so it is meant to
// be launched with a single thread.
extern "C"
__global__ void kernel(MyStruct *structs)
{
    // First struct: values 1.x, 2.x and 3.x
    storeThree(structs[0].pointer1, 1.0f, 1.1f, 1.2f);
    storeThree(structs[0].pointer2, 2.0f, 2.1f, 2.2f);
    storeThree(structs[0].pointer3, 3.0f, 3.1f, 3.2f);

    // Second struct: values 11.x, 12.x and 13.x
    storeThree(structs[1].pointer1, 11.0f, 11.1f, 11.2f);
    storeThree(structs[1].pointer2, 12.0f, 12.1f, 12.2f);
    storeThree(structs[1].pointer3, 13.0f, 13.1f, 13.2f);
}

這個內核在以下程式中啟動。（注意：PTX檔案在此處是即時編譯的，其編譯設定可能與您的應用場景不符。如有疑問，可以手動編譯您的PTX檔案。）

每個結構體的pointer1、pointer2和pointer3指針都經過初始化，分別指向設備緩衝區A、B和C中的連續元素。它們各自帶有不同的偏移，以便辨識內核寫入的值。（請注意，我嘗試同時處理在32位和64位機器上運行這兩種可能情況，這意味著指針大小不同。不過目前我只能測試32位版本。）

import static jcuda.driver.JCudaDriver.*; 

import java.io.ByteArrayOutputStream; 
import java.io.File; 
import java.io.IOException; 
import java.io.InputStream; 
import java.nio.ByteBuffer; 
import java.nio.ByteOrder; 
import java.nio.IntBuffer; 
import java.nio.LongBuffer; 
import java.util.Arrays; 

import jcuda.Pointer; 
import jcuda.Sizeof; 
import jcuda.driver.CUcontext; 
import jcuda.driver.CUdevice; 
import jcuda.driver.CUdeviceptr; 
import jcuda.driver.CUfunction; 
import jcuda.driver.CUmodule; 
import jcuda.driver.JCudaDriver; 


/**
 * Demonstrates passing an array of structs that contain device pointers to a
 * CUDA kernel with JCuda. The raw device addresses (plus padding) are written
 * into a host buffer whose layout matches MyStruct. The buffer is copied to
 * the device, and the kernel receives a pointer to that copy.
 */
public class JCudaPointersInStruct
{
    /**
     * Pointer-sized slots per struct: three pointers plus one padding slot.
     * MyStruct is aligned to 16 bytes, which makes it 16 bytes with 4-byte
     * pointers and 32 bytes with 8-byte pointers.
     */
    private static final int POINTER_SLOTS_PER_STRUCT = 4;

    public static void main(String args[]) throws IOException
    {
        JCudaDriver.setExceptionsEnabled(true);
        String ptxFileName = preparePtxFile("JCudaPointersInStructKernel.cu");
        cuInit(0);
        CUdevice device = new CUdevice();
        cuDeviceGet(device, 0);
        CUcontext context = new CUcontext();
        cuCtxCreate(context, 0, device);
        CUmodule module = new CUmodule();
        cuModuleLoad(module, ptxFileName);
        CUfunction function = new CUfunction();
        cuModuleGetFunction(function, module, "kernel");

        // The three float buffers that the struct pointers point into
        int numElements = 9;
        CUdeviceptr A = allocateZeroedFloats(numElements);
        CUdeviceptr B = allocateZeroedFloats(numElements);
        CUdeviceptr C = allocateZeroedFloats(numElements);

        // Build the flattened struct array on the host, in native byte order
        int numSteps = 2;
        int sizeOfStruct = Sizeof.POINTER * POINTER_SLOTS_PER_STRUCT;
        ByteBuffer hostStructsBuffer =
            ByteBuffer.allocate(numSteps * sizeOfStruct);
        hostStructsBuffer.order(ByteOrder.nativeOrder());
        for (int x = 0; x < numSteps; x++)
        {
            putPointer(hostStructsBuffer,
                getPointerAddress(A.withByteOffset(getStepOffsetA(x))));
            putPointer(hostStructsBuffer,
                getPointerAddress(B.withByteOffset(getStepOffsetB(x))));
            putPointer(hostStructsBuffer,
                getPointerAddress(C.withByteOffset(getStepOffsetC(x))));
            putPointer(hostStructsBuffer, 0); // padding slot
        }
        // The relative puts advanced the position; rewind so that the copy
        // below starts at the first struct
        hostStructsBuffer.rewind();

        CUdeviceptr structs = new CUdeviceptr();
        cuMemAlloc(structs, numSteps * sizeOfStruct);
        cuMemcpyHtoD(structs, Pointer.to(hostStructsBuffer),
            numSteps * sizeOfStruct);

        Pointer kernelParameters = Pointer.to(
            Pointer.to(structs)
        );
        // The test kernel is meant to be run by a single thread
        cuLaunchKernel(function,
            1, 1, 1,
            1, 1, 1,
            0, null, kernelParameters, null);
        cuCtxSynchronize();

        float hostA[] = new float[numElements];
        cuMemcpyDtoH(Pointer.to(hostA), A, numElements * Sizeof.FLOAT);
        float hostB[] = new float[numElements];
        cuMemcpyDtoH(Pointer.to(hostB), B, numElements * Sizeof.FLOAT);
        float hostC[] = new float[numElements];
        cuMemcpyDtoH(Pointer.to(hostC), C, numElements * Sizeof.FLOAT);

        System.out.println("A "+Arrays.toString(hostA));
        System.out.println("B "+Arrays.toString(hostB));
        System.out.println("C "+Arrays.toString(hostC));

        // Release the device memory, the module and the context
        cuMemFree(structs);
        cuMemFree(A);
        cuMemFree(B);
        cuMemFree(C);
        cuModuleUnload(module);
        cuCtxDestroy(context);
    }

    /**
     * Allocates device memory for the given number of floats and sets all of
     * them to 0.
     */
    private static CUdeviceptr allocateZeroedFloats(int numElements)
    {
        CUdeviceptr pointer = new CUdeviceptr();
        cuMemAlloc(pointer, numElements * Sizeof.FLOAT);
        cuMemsetD32(pointer, 0, numElements);
        return pointer;
    }

    /**
     * Appends one device address to the buffer, using 4 or 8 bytes depending
     * on Sizeof.POINTER (the address is truncated to an int for 4-byte
     * pointers).
     */
    private static void putPointer(ByteBuffer buffer, long address)
    {
        if (Sizeof.POINTER == 4)
        {
            buffer.putInt((int) address);
        }
        else
        {
            buffer.putLong(address);
        }
    }

    // Byte offsets of the floats that step x's pointer1/pointer2/pointer3
    // point to in A/B/C. Each step advances by 4 floats, and the three
    // buffers are shifted by 0, 1 and 2 floats, so the values written by the
    // kernel end up at distinguishable positions.
    private static long getStepOffsetA(int x)
    {
        return x * Sizeof.FLOAT * 4 + 0 * Sizeof.FLOAT;
    }
    private static long getStepOffsetB(int x)
    {
        return x * Sizeof.FLOAT * 4 + 1 * Sizeof.FLOAT;
    }
    private static long getStepOffsetC(int x)
    {
        return x * Sizeof.FLOAT * 4 + 2 * Sizeof.FLOAT;
    }

    /**
     * Returns the raw address of the given device pointer: its native pointer
     * value plus its byte offset.
     */
    private static long getPointerAddress(CUdeviceptr p)
    {
        // WORKAROUND until a method like CUdeviceptr#getAddress exists:
        // a subclass of Pointer may call the protected accessors
        class PointerWithAddress extends Pointer
        {
            PointerWithAddress(Pointer other)
            {
                super(other);
            }
            long getAddress()
            {
                return getNativePointer() + getByteOffset();
            }
        }
        return new PointerWithAddress(p).getAddress();
    }


    //-------------------------------------------------------------------------
    // Ignore this - in practice, you'll compile the PTX manually

    /**
     * Compiles the given .cu file into a .ptx file next to it, by invoking
     * nvcc, and returns the name of the PTX file.
     *
     * @throws IOException if the input file does not exist, nvcc cannot be
     * started, nvcc fails, or the thread is interrupted while waiting
     */
    private static String preparePtxFile(String cuFileName) throws IOException
    {
        // Replace the file extension with ".ptx", or append ".ptx" when the
        // file name has no extension (a dot inside a directory name does not
        // count as an extension)
        int dotIndex = cuFileName.lastIndexOf('.');
        int separatorIndex = Math.max(
            cuFileName.lastIndexOf('/'),
            cuFileName.lastIndexOf(File.separatorChar));
        String baseName = (dotIndex > separatorIndex) ?
            cuFileName.substring(0, dotIndex) : cuFileName;
        String ptxFileName = baseName + ".ptx";

        File cuFile = new File(cuFileName);
        if (!cuFile.exists())
        {
            throw new IOException("Input file not found: "+cuFileName);
        }

        // Compile for the same pointer size that the host-side struct layout
        // uses (Sizeof.POINTER), so that both sides agree on it.
        String modelString = "-m" + (Sizeof.POINTER * 8);
        // The arguments are passed individually, so paths containing spaces
        // work. No "-arch" is given: nvcc then uses its default virtual
        // architecture ("sm_11" is not accepted by current nvcc releases).
        // Add an explicit "-arch" here if your GPU needs a specific one.
        String[] command = {
            "nvcc", modelString, "-ptx", "-lineinfo",
            cuFile.getPath(), "-o", ptxFileName
        };
        StringBuilder commandLine = new StringBuilder();
        for (String part : command)
        {
            if (commandLine.length() > 0)
            {
                commandLine.append(' ');
            }
            commandLine.append(part);
        }
        System.out.println("Executing\n"+commandLine);

        Process process = new ProcessBuilder(command).start();
        // Drain stdout on a separate thread while stderr is read here. Reading
        // the two pipes one after the other can block forever if nvcc fills
        // the pipe that is not being read.
        StreamCollector outputCollector =
            new StreamCollector(process.getInputStream());
        outputCollector.start();
        String errorMessage =
            new String(toByteArray(process.getErrorStream()));
        String outputMessage;
        int exitValue;
        try
        {
            exitValue = process.waitFor();
            outputMessage = outputCollector.getContents();
        }
        catch (InterruptedException e)
        {
            Thread.currentThread().interrupt();
            throw new IOException(
                "Interrupted while waiting for nvcc output", e);
        }

        if (exitValue != 0)
        {
            System.out.println("nvcc process exitValue "+exitValue);
            System.out.println("errorMessage:\n"+errorMessage);
            System.out.println("outputMessage:\n"+outputMessage);
            throw new IOException(
                "Could not create .ptx file: "+errorMessage);
        }
        System.out.println("Finished creating PTX file");
        return ptxFileName;
    }

    /**
     * Reads the given stream until its end and returns all bytes that were
     * read. The stream is not closed.
     */
    private static byte[] toByteArray(InputStream inputStream)
        throws IOException
    {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        byte buffer[] = new byte[8192];
        while (true)
        {
            int read = inputStream.read(buffer);
            if (read == -1)
            {
                break;
            }
            baos.write(buffer, 0, read);
        }
        return baos.toByteArray();
    }

    /**
     * Reads an input stream to its end on a separate (daemon) thread.
     */
    private static final class StreamCollector extends Thread
    {
        private final InputStream inputStream;
        private byte[] contents;
        private IOException failure;

        StreamCollector(InputStream inputStream)
        {
            this.inputStream = inputStream;
            setDaemon(true);
        }

        @Override
        public void run()
        {
            try
            {
                contents = toByteArray(inputStream);
            }
            catch (IOException e)
            {
                failure = e;
            }
        }

        /**
         * Waits until the stream has been read completely and returns its
         * contents as a string. join() makes the fields written by run()
         * visible to the calling thread.
         *
         * @throws IOException if reading the stream failed
         * @throws InterruptedException if interrupted while waiting
         */
        String getContents() throws IOException, InterruptedException
        {
            join();
            if (failure != null)
            {
                throw failure;
            }
            return new String(contents);
        }
    }

}

結果符合預期：

A [1.0, 1.1, 1.2, 0.0, 11.0, 11.1, 11.2, 0.0, 0.0] 
B [0.0, 2.0, 2.1, 2.2, 0.0, 12.0, 12.1, 12.2, 0.0] 
C [0.0, 0.0, 3.0, 3.1, 3.2, 0.0, 13.0, 13.1, 13.2] 
+0

這正是我想要的。我原本打算用Pointer.to(一大串帶填充的NativePointer)來偽造結構列表，並寄望底層的JNI部分能把它解析成一塊存放指針的記憶體緩衝區。但現在這個做法更合理，也沒那麼冒險（雖然仍然有點hacky :)）。 – Bam4d