2016-05-29 48 views
0

我已經實現了我的自定義操作的內核,並將其作爲作爲custom_op.cc。在操作內部,我做所有註冊的東西,如REGISTER_OPREGISTER_KERNEL_BUILDER如何使TensorFlow中的自定義操作可以在Python中導入?

然後我在Python中爲這個操作符實現了漸變,並將它放在與custom_op_grad.py相同的文件夾中。我也在這裏做了所有的註冊(@ops.RegisterGradient)。

我已創建的BUILD文件,包含以下內容:

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") 
tf_custom_op_library(
     name = "custom_op.so", 
     srcs = ["custom_op.cc"], 
) 

py_library(
     name = "custom_op_grad", 
     srcs = ["custom_op_grad.py"], 
     srcs_version = "PY2", 
     deps = [ 
     ":custom_op_grad", 
     "//tensorflow:tensorflow_py", 
     ], 
) 

在那之後,我重建Tensorflow:

pip uninstall tensorflow 
bazel clean 
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package 
cp -r bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/__main__/* bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/ 
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg 
pip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-any.whl 

當我嘗試使用我的作品在這一切之後,通過調用它告訴我模塊沒有它。

也許有一些額外的步驟,我必須做的?或者我在BUILD文件中做錯了什麼?

回答

0

好吧,我找到了解決方案。我剛剛刪除了BUILD文件,並且我的自定義操作已成功構建,並且可以使用tensorflow.user_ops.custom_op()在Python中導入。

要使用漸變我不得​​不把它的代碼直接在tensorflow/python/user_ops/user_ops.py內。不是最優雅的解決方案,但現在工作。

相關問題