2017-05-01

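Writing and registering a custom Tensorflow op in Python

I would like to write a custom Tensorflow op in Python and register it with the Protobuf registry for ops, as explained here. The Protobuf registration is key because I will not use this op directly from Python. If it is registered like a C++ op and loaded into the Python runtime environment, then I can run it in my environment.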

I expect the code to look something like the following:

import tensorflow as tf 
from google.protobuf import json_format 
from tensorflow.python.ops.data_flow_ops import QueueBase, _as_type_list, _as_shape_list, _as_name_list 

""" Missing the Python equivalent of,

    class HDF5QueueOp : public ResourceOpKernel<QueueInterface> {
    public:
    // Implementation
    };

    REGISTER_OP("HDF5Queue")
    .Output("handle: resource")
    .Attr("filename: string")
    .Attr("datasets: list(string)")
    .Attr("overwrite: bool = false")
    .Attr("component_types: list(type) >= 0 = []")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("shared_name: string = ''")
    .Attr("container: string = ''")
    .Attr("capacity: int = -1")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

"""

class HDF5Queue(QueueBase):
    def __init__(self, stream_id, stream_columns, dtypes=None, capacity=100,
                 shapes=None, names=None, name="hdf5_queue"):
        # Default to two components: an int64 and a float32, each of shape [1].
        if not dtypes:
            dtypes = [tf.int64, tf.float32]

        if not shapes:
            shapes = [[1], [1]]

        dtypes = _as_type_list(dtypes)
        shapes = _as_shape_list(shapes, dtypes)
        names = _as_name_list(names, dtypes)
        # _op_def_lib is undefined in this snippet. It stands in for an
        # op-definition library that knows about "HDF5Queue", i.e. the
        # registration step that is missing.
        queue_ref = _op_def_lib.apply_op("HDF5Queue", stream_id=stream_id,
                                         stream_columns=stream_columns,
                                         capacity=capacity,
                                         component_types=dtypes, shapes=shapes,
                                         name=name, container=None,
                                         shared_name=None)
        super(HDF5Queue, self).__init__(dtypes, shapes, names, queue_ref)

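The above is fairly standard for TF. One can look at FIFOQueue as an example: Python Wrapper, Protobuf Registration, C++ Implementation. A Python wrapper that I do not like is generated at compile time, but you can see where it is used at runtime with:

grep -A 10 -B 10 -n FIFO $(find /usr/local -name "*gen_data_flow*.py") /dev/null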
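For readers who prefer to stay in Python, the same search can be sketched without grep. This is only a rough equivalent of the command above: the helper names `find_files` and `grep_with_context` are made up for illustration, and the match is a plain substring test rather than a regular expression.

```python
import fnmatch
import os


def find_files(root, pattern):
    """Yield paths under root whose file name matches a shell-style pattern."""
    for dirpath, _dirnames, filenames in os.walk(root):
        for filename in fnmatch.filter(filenames, pattern):
            yield os.path.join(dirpath, filename)


def grep_with_context(path, needle, context=10):
    """Return (line_number, text) pairs within `context` lines of a match."""
    with open(path) as handle:
        lines = handle.read().splitlines()
    keep = set()
    for index, line in enumerate(lines):
        if needle in line:
            # Keep the matching line plus `context` lines on either side.
            start = max(0, index - context)
            stop = min(len(lines), index + context + 1)
            keep.update(range(start, stop))
    # Line numbers are 1-based, like grep -n.
    return [(index + 1, lines[index]) for index in sorted(keep)]
```

Iterating over `find_files("/usr/local", "*gen_data_flow*.py")` and passing each path to `grep_with_context(path, "FIFO")` should list roughly the same lines as the grep command.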

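The following dumps the Protobuf message of the TF graph as JSON. I would hope that this dumps a block for the HDF5Queue op, just as it would if I had written a C++ op.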

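with tf.Session() as sess:
    # stream_columns is required by __init__; the names here are placeholders.
    queue = HDF5Queue(stream_id=0xa, stream_columns=["a", "b"])
    write = queue.enqueue([[1], [1.2]])
    read = queue.dequeue()
    # Dump the whole graph, including the queue op, as JSON.
    print(json_format.MessageToJson(tf.train.export_meta_graph()))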
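One way to check whether the dump contains the op is to parse the JSON and collect the op types of all nodes. The sketch below does not call TensorFlow. It runs on a hand-written stand-in for the JSON text, so the node names and the presence of an HDF5Queue entry are assumptions. The `graphDef`, `node` and `op` keys follow the lowerCamelCase names that the protobuf JSON format uses, and `op_types` is a made-up helper name.

```python
import json


def op_types(meta_graph_json):
    """Return the sorted op type names found in a MetaGraphDef JSON dump."""
    message = json.loads(meta_graph_json)
    nodes = message.get("graphDef", {}).get("node", [])
    return sorted({node["op"] for node in nodes})


# Hand-written stand-in for the output of json_format.MessageToJson(...).
# The node names and the HDF5Queue entry are hypothetical.
example_dump = json.dumps({
    "graphDef": {
        "node": [
            {"name": "hdf5_queue", "op": "HDF5Queue"},
            {"name": "hdf5_queue_enqueue", "op": "QueueEnqueueV2"},
            {"name": "hdf5_queue_Dequeue", "op": "QueueDequeueV2"},
        ]
    }
})

# True only if some node in the dump uses the custom op type.
print("HDF5Queue" in op_types(example_dump))
```

With a real graph, the string returned by `json_format.MessageToJson` would take the place of `example_dump`.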

Answer


This can be done with py_func. Here is an example.

import tensorflow as tf 
from google.protobuf import json_format 
import json, base64, numpy
from tensorflow.python.ops.script_ops import _py_funcs as py_func_registry 
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef 

graph = tf.Graph() 
graph2 = tf.Graph() 

def f(x): 
    return x 

def g(x): 
    return 2*x 

with graph.as_default(): 
    x = tf.placeholder(tf.float32, shape=(3,), name='x') 
    y = tf.py_func(f, [x], tf.float32, name='y') 

    # py_func_registry._funcs.clear() # Optional line to clear the Python function registry 
    msg = json.loads(json_format.MessageToJson(tf.train.export_meta_graph())) 

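# Change the function being used by py_func: register g and write its token
# into the node. Protobuf JSON stores bytes fields base64-encoded.
# node[1] is assumed to be the py_func node (node[0] is the placeholder).
token = py_func_registry.insert(g)
msg['graphDef']['node'][1]['attr']['token']['s'] = base64.b64encode(token.encode()).decode()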

with graph2.as_default():  
    # Load graph 
    meta_graph_def = MetaGraphDef() 
    json_format.Parse(json.dumps(msg), meta_graph_def) 
    tf.train.import_meta_graph(meta_graph_def) 

    sess = tf.Session(graph=graph2)
    # Both lines should print the doubled input if y now dispatches to g.
    print(sess.run('y:0', feed_dict={'x:0': numpy.array([1, 2, 3])}))
    print(g(numpy.array([1, 2, 3])))
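The mechanism the example relies on can be illustrated without TensorFlow. The graph node stores only a token, and a registry on the Python side maps that token to a callable. The sketch below is a simplified model under stated assumptions, not TensorFlow's implementation: `ToyFuncRegistry`, `encode_token`, `decode_token`, `find_nodes` and `swap_token` are made-up names, and the message is a hand-written stand-in for the JSON dump. It also looks the node up by op type and name instead of relying on its position in the node list.

```python
import base64
import copy


class ToyFuncRegistry(object):
    """Simplified token -> callable registry (a stand-in, not TensorFlow's code)."""

    def __init__(self):
        self._funcs = {}
        self._next_id = 0

    def insert(self, func):
        """Store func and return the string token that identifies it."""
        token = "pyfunc_%d" % self._next_id
        self._next_id += 1
        self._funcs[token] = func
        return token

    def call(self, token, *args):
        """Look up the callable registered under token and invoke it."""
        return self._funcs[token](*args)


def encode_token(token):
    """Base64-encode a token, as protobuf JSON does for bytes fields."""
    return base64.b64encode(token.encode("ascii")).decode("ascii")


def decode_token(encoded):
    """Invert encode_token."""
    return base64.b64decode(encoded).decode("ascii")


def find_nodes(message, op_type):
    """Return every node dict in the dump whose op type equals op_type."""
    return [node for node in message["graphDef"]["node"] if node["op"] == op_type]


def swap_token(message, node_name, new_token):
    """Return a copy of message with the named PyFunc node pointing at new_token."""
    patched = copy.deepcopy(message)
    for node in find_nodes(patched, "PyFunc"):
        if node["name"] == node_name:
            node["attr"]["token"]["s"] = encode_token(new_token)
            return patched
    raise KeyError("no PyFunc node named %r" % node_name)


registry = ToyFuncRegistry()
token_f = registry.insert(lambda x: x)      # plays the role of f
token_g = registry.insert(lambda x: 2 * x)  # plays the role of g

# Hand-written stand-in for the JSON dump of the graph in the example.
message = {"graphDef": {"node": [
    {"name": "x", "op": "Placeholder", "attr": {}},
    {"name": "y", "op": "PyFunc", "attr": {"token": {"s": encode_token(token_f)}}},
]}}

patched = swap_token(message, "y", token_g)
token = decode_token(patched["graphDef"]["node"][1]["attr"]["token"]["s"])
print(registry.call(token, 3))  # dispatches to the doubling function
```

Because the swap works on a deep copy, the original message still points at the first function. The same lookup-by-name approach could replace the hard-coded `node[1]` index if the real graph contains more nodes.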