0
在tensorflow我寄存器的運算,如下所示:在tensorflow中,如何在移動到GPU之前訪問標量張量值?
REGISTER_OP("RimeBSqrt")
.Input("stokes: FT")
.Input("alpha: FT")
.Input("frequency: FT")
.Input("ref_freq: FT")
.Output("b_sqrt: CT")
.Attr("FT: {float, double} = DT_FLOAT")
.Attr("CT: {complex64, complex128} = DT_COMPLEX64");
上述所有輸入是張量, 但ref_freq是一個標量或0-d張量。 在我的CPU內核 的計算()方法,我可以做以下的提取標:
const Tensor & in_ref_freq = context->input(3);
FT ref_freq = in_ref_freq.tensor<FT, 1>()(0);
然而,相同類型的代碼在我的GPU內核的計算()方法生成一個段錯誤 ,因爲 CPU現在嘗試訪問 GPU設備上的一塊內存。無論如何在將其發送到GPU之前攔截此標量 的值?我想,以避免 內存間接以下額外水平 一個CUDA內核:
template <typename FT>
__global__ void kernel(..., FT * ref_freq, ...)
{
FT value = ref_freq[0];
}
我不認爲Attr
是用於ref_freq
,因爲它是多變的,可配置的值的方法。
Thanks! HostMemory指令是否也阻止了ref_freq到GPU的傳輸? – Simon
HostMemory指令將阻止它由運行時自動複製。我絕不是CUDA專家,但我想你會修改'OpKernel'來從張量中提取float值,並將其作爲float參數傳遞給CUDA內核(就像你爲'int'參數[here](https://github.com/ska-sa/montblanc/blob/53edf2ba505e4b5b10ae89e187c4f11d1e7072db/montblanc/tensorflow/rime_ops/b_sqrt_op_gpu.h#L71))。 – mrry
太好了,那正是我在找的東西。 – Simon