1
我想在Python中編寫一個自定義Tensorflow操作並將其註冊到Protobuf註冊中心以執行操作,如解釋here。 Protobuf註冊是關鍵,因爲我不會直接從Python使用這個操作,但是如果它被註冊爲像C++ op並加載到Python運行時環境,那麼我可以在我的環境中運行它。在Python中編寫和註冊一個自定義Tensorflow操作
我希望的代碼看起來像,上面
import tensorflow as tf
from google.protobuf import json_format
from tensorflow.python.ops.data_flow_ops import QueueBase, _as_type_list, _as_shape_list, _as_name_list
""" Missing the Python equivalent of,
class HDF5QueueOp : public ResourceOpKernel<QueueInterface> {
public:
// Implementation
};
REGISTER_OP("HDF5Queue")
.Output("handle: resource")
.Attr("filename: string")
.Attr("datasets: list(string)")
.Attr("overwrite: bool = false")
.Attr("component_types: list(type) >= 0 = []")
.Attr("shapes: list(shape) >= 0 = []")
.Attr("shared_name: string = ''")
.Attr("container: string = ''")
.Attr("capacity: int = -1")
.SetIsStateful()
.SetShapeFn(TwoElementOutput);
"""
class HDF5Queue(QueueBase):
def __init__(self, stream_id, stream_columns, dtypes=None, capacity=100,
shapes=None, names=None, name="hdf5_queue"):
if not dtypes:
dtypes = [tf.int64, tf.float32]
if not shapes:
shapes = [[1], [1]]
dtypes = _as_type_list(dtypes)
shapes = _as_shape_list(shapes, dtypes)
names = _as_name_list(names, dtypes)
queue_ref = _op_def_lib.apply_op("HDF5Queue", stream_id=stream_id,
stream_columns=stream_columns, capacity=capacity,
component_types=dtypes, shapes=shapes,
name=name, container=None, shared_name=None)
super(HDF5Queue, self).__init__(dtypes, shapes,
names, queue_ref)
是TF非常標準。例如可以用FIFOQueue來看。 Python Wrapper,Protobuf Registration,C++ Implementation。在編譯期間生成了一個我不喜歡的Python包裝器,但是您可以看到它在運行時使用的位置grep -A 10 -B 10 -n FIFO $(find /usr/local -name "*gen_data_flow*.py") /dev/null
下面將以JSON格式轉儲TF圖形的Protobuf消息。我希望這會轉儲HDF5Queue操作的塊,就像我寫C++操作一樣。
with tf.Session() as sess:
queue = HDF5Queue(stream_id=0xa)
write = queue.enqueue([[1], [1.2]])
read = queue.dequeue()
print json_format.MessageToJson(tf.train.export_meta_graph())