2016-05-19 43 views
0

在tensorflow我寄存器的運算,如下所示:在tensorflow中,如何在移動到GPU之前訪問標量張量值?

REGISTER_OP("RimeBSqrt") 
    .Input("stokes: FT") 
    .Input("alpha: FT") 
    .Input("frequency: FT") 
    .Input("ref_freq: FT") 
    .Output("b_sqrt: CT") 
    .Attr("FT: {float, double} = DT_FLOAT") 
    .Attr("CT: {complex64, complex128} = DT_COMPLEX64"); 

上述所有輸入是張量, 但ref_freq是一個標量或0-d張量。 在我的CPU內核 的計算()方法,我可以做以下的提取標:

const Tensor & in_ref_freq = context->input(3); 
FT ref_freq = in_ref_freq.tensor<FT, 1>()(0); 

然而,相同類型的代碼在我的GPU內核的計算()方法生成一個段錯誤 ,因爲 CPU現在嘗試訪問 GPU設備上的一塊內存。無論如何在將其發送到GPU之前攔截此標量 的值?我想,以避免 內存間接以下額外水平 一個CUDA內核:

template <typename FT> 
__global__ void kernel(..., FT * ref_freq, ...) 
{ 
    FT value = ref_freq[0]; 
} 

我不認爲Attr是用於ref_freq,因爲它是多變的,可配置的值的方法。

  1. CPU Tensorflow內核代碼是here
  2. GPU Tensorflow內核代碼是here
  3. Python變量設置代碼是here

回答

3

可以指定(來自或輸出)的一個或多個的輸入到一個TensorFlow OpKernel是在「主存儲器」,它可以訪問在該值Compute()方法。要做到這一點,你會修改你的REGISTER_KERNEL_BUILDER()電話添加.HostMemory("ref_freq")指令:

REGISTER_KERNEL_BUILDER(
    Name("RimeBSqrt") 
    .Device(tensorflow::DEVICE_GPU) 
    .TypeConstraint<float>("FT") 
    .TypeConstraint<tensorflow::complex64>("CT") 
    .HostMemory("ref_freq"), 
    RimeBSqrt<tensorflow::GPUDevice, float, tensorflow::complex64>); 
+0

Thanks! HostMemory指令是否也阻止了ref_freq到GPU的傳輸? – Simon

+0

HostMemory指令將阻止它由運行時自動複製。我絕不是CUDA專家,但我想你會修改'OpKernel'來從張量中提取float值,並將其作爲float參數傳遞給CUDA內核(就像你爲'int'參數[here](https://github.com/ska-sa/montblanc/blob/53edf2ba505e4b5b10ae89e187c4f11d1e7072db/montblanc/tensorflow/rime_ops/b_sqrt_op_gpu.h#L71))。 – mrry

+0

太好了,那正是我在找的東西。 – Simon