1
我在C++中定義了一個新的Op,它採用tensor
類型的單個屬性,大致在these instructions之後。操作碼的剝離版本如下:TensorFlow接受類型爲「張量」的Attr的Python類型是什麼?
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
REGISTER_OP("DoStuff")
.Attr("attr: tensor = { dtype: DT_FLOAT }")
.Input("in: float")
.Output("out: float");
class DoStuffOp : public OpKernel {
public:
explicit DoStuffOp(OpKernelConstruction *context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("attr", &attr_));
// ...
}
void Compute(OpKernelContext *context) override {
// ...
}
private:
Tensor attr_;
};
REGISTER_KERNEL_BUILDER(Name("DoStuff").Device(DEVICE_CPU), DoStuffOp);
我可以編譯運到.so
文件罰款。但是,我無法弄清楚如何成功傳入attr
的值。當我在Python中運行以下代碼:
import tensorflow as tf
dostufflib = tf.load_op_library('build/do_stuff.so')
sess = tf.InteractiveSession()
A = [[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]]
X = tf.Variable(tf.constant(1.0))
Y = dostufflib.do_stuff(X, A)
我得到TypeError: Don't know how to convert [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] to a TensorProto for argument 'attr'
。我所做的任何事似乎都不能滿足類型轉換:list
,numpy
array,tf.Tensor
,tf.Variable
等。如何將Python變量作爲張量屬性傳遞到Op中?