2017-03-23 40 views
3

我正在編寫一個使用本教程的自定義Tensorflow操作,並且我無法理解如何讀取和寫入張量。Tensorflow自定義操作 - 如何從Tensors讀取和寫入?

比方說,我在我的OpKernel張量,我從拿到 const Tensor& values_tensor = context->input(0);(其中上下文= OpKernelConstruction*

如果張量具有形狀,也就是說,[2,10,20],哪能索引它(例如auto x = values_tensor[1, 4, 12]等)?

等價,如果我有

Tensor *output_tensor = NULL; 
OP_REQUIRES_OK(context, context->allocate_output(
    0, 
    {batch_size, value_len - window_size, window_size}, 
    &output_tensor 
)); 

我怎麼可以分配給output_tensor,像output_tensor[1, 2, 3] = 11等?

抱歉愚蠢的問題,但該文檔真的絆倒了我這裏,並在內置的OPS的Tensorflow內核代碼的例子在某種程度上混淆這是我得到很困惑:)

感謝點您!

回答

1

讀取和寫入tensorflow::Tensor對象的最簡單方法是使用tensorflow::Tensor::tensor<T, NDIMS>()方法將它們轉換爲Eigen tensor。請注意,您必須在張量中指定(C++)類型的元素作爲模板參數T

例如,爲了從一個DT_FLOAT32張量讀取一個特定的值:

const Tensor& values_tensor = context->input(0); 
auto x = value_tensor.tensor<float, 3>()(1, 4, 12); 

爲特定值寫入DT_FLOAT32張量:

Tensor* output_tensor = ...; 
output_tensor->tensor<float, 3>()(1, 2, 3) = 11.0; 

也有用於訪問scalar方便的方法,vectormatrix