2017-05-23 34 views
1

我從tutorial看到,我們可以做到這一點:如何訪問TensorFlow中protos的值?

for node in tf.get_default_graph().as_graph_def().node: print node

在任意網絡上完成的,我們得到許多鍵值對。例如:

name: "conv2d_2/convolution" 
op: "Conv2D" 
input: "max_pooling2d/MaxPool" 
input: "conv2d_1/kernel/read" 
device: "/device:GPU:0" 
attr { 
    key: "T" 
    value { 
    type: DT_FLOAT 
    } 
} 
attr { 
    key: "data_format" 
    value { 
    s: "NHWC" 
    } 
} 
attr { 
    key: "padding" 
    value { 
    s: "SAME" 
    } 
} 
attr { 
    key: "strides" 
    value { 
    list { 
     i: 1 
     i: 1 
     i: 1 
     i: 1 
    } 
    } 
} 
attr { 
    key: "use_cudnn_on_gpu" 
    value { 
    b: true 
    } 
} 

如何訪問所有這些值並將它們放入Python列表中?具體來說,我們如何獲得「strides」屬性並將其中的所有1轉換爲[1,1,1,1]?

回答

1

TLDR:下面的代碼是,你可能想使用什麼:

for n in tf.get_default_graph().as_graph_def().node: if 'strides' in n.attr.keys(): print n.name, [int(a) for a in n.attr['strides'].list.i] if 'shape' in n.attr.keys(): print n.name, [int(a.size) for a in n.attr['shape'].shape.dim]

訣竅這樣做是爲了瞭解protobufs是。我們來看看上面提到的tutorial

首先,有一種說法:

for node in graph_def.node

的每個節點都是一個NodeDef對象,在 tensorflow /核心/框架/ node_def.proto定義。這些是TensorFlow圖表的基本組成部分,每個圖表定義一個單一的 操作及其輸入連接。這裏是一個 NodeDef的成員,以及它們的含義。

注意在node_def.proto如下:

  • 它進口attr_value.proto。
  • 有屬性,如名稱,操作,輸入,設備,attr。具體來說,在輸入前面有一個repeated的術語。我們現在可以忽略這一點。

操作就像一個Python類,因此,我們可以稱之爲node.name,node.op,node.input,node.device,node.attr等

我們想訪問什麼現在將是node.attr中的內容。如果我們再次參考教程,它指定:

這是一個鍵/值存儲區,它包含節點的所有屬性。這些 是節點的永久屬性,不會在 運行時更改,例如卷積過濾器的大小或 常量操作的值。因爲可以有很多不同類型的屬性值,從字符串,到整數,到張量值數組, 有一個單獨的protobuf文件,定義了數據結構 保存它們,在tensorflow/core/framework/attr_value.proto 。

每個屬性都有唯一的名稱字符串,並且在定義操作時列出了預期的屬性 。如果節點中存在的屬性不是 ,但它具有在操作 定義中列出的默認值,則在創建圖形時會使用該默認值。

您可以通過在Python中調用node.name,node.op, 等來訪問所有這些成員。存儲在GraphDef中的節點列表是模型體系結構的完整定義。

由於這是一個鍵值存儲,我們可以調用n.attr.keys()來查看此屬性具有的鍵列表。如果這樣的密鑰可用,我們可以進一步呼叫或許n.attr['strides']訪問步幅。當我們嘗試打印此,我們得到如下:

list { 
    i: 1 
    i: 2 
    i: 2 
    i: 1 
} 

而這正是它開始變得混亂,因爲我們可能會嘗試做list(n.attr['strides'])或這類東西。如果我們看attr_value.proto,我們可以理解發生了什麼。我們看到它是oneof value,在這種情況下它是ListValue list,所以我們可以撥打n.attr['strides'].list。如果我們打印此,我們得到如下:

i: 1 
i: 1 
i: 1 
i: 1 

我們接下來會嘗試這樣做:[a for a in n.attr['strides'].list][a.i for a in n.attr['strides'].list]。但是,沒有任何工作。這是repeated是理解的重要術語。它基本上意味着有一個int64列表,您必須使用i屬性訪問它。做[int(a) for a in n.attr['strides'].list.i]然後給我們我們想要的,我們可以使用一個Python列表:

[1, 1, 1, 1]