2017-08-26 94 views
1

我有一個這樣的函數來構建一個網絡。使用tf.contrib.graph_editor克隆網絡

def build_network(inputs): 
    # Some arbitrary set of variables and ops here. For example... 
    out = tf.contrib.layers.fully_connected(inputs, 123) 
    (...) 
    return out 

然後我用它來構建這樣的網絡。

inputs = tf.placeholder(...) 
outputs = build_network(inputs) 

如果我想建立更多的網絡結構相同,但自變量我只是必須再次調用build_network下,其他一些變量的作用域和可選的其他投入。

我的問題是:如果build_network不再可用,但原始網絡的輸入和輸出是?換句話說:如何將輸出一直到輸入的整個子圖克隆到另一個具有自己獨立變量但相同結構的變量作用域中?

我的理解是,通常tf.contrib.graph_editor和特別是graph_editor.copy正是我需要做這些事情的工具。但是,我找不到任何有用的例子。有什麼建議麼?

回答

1

迴應自己,我發現看起來像複製子圖的方式。

from tensorflow.contrib import graph_editor as ge 

# From the example above. 
inputs = [tf.placeholder(...), ...] 
outputs = build_network(inputs) 

sgv = ge.make_view(ge.get_within_boundary_ops(
    tf.get_default_graph(), 
    [t.op for t in outputs], 
    [t.op for t in inputs])) 

# This could be any new inputs. In this example I build new identical placeholders. 
new_inputs = {p: tf.placeholder(dtype=p.dtype, shape=p.shape) for p in inputs} 
new_sgv, info = ge.copy_with_input_replacements(sgv, new_inputs, dst_scope='copy') 

new_inputs = [info.transformed(t) for t in inputs] 
new_outputs = [info.transformed(t) for t in outputs] 

不過,現在我想利用網絡複製時面臨的一個新問題。副本中的新變量未初始化,並嘗試運行tf.global_variables_initializer()不起作用。

原因是因爲這些的tf.Variable從來沒有構建過,所以它們不是GlobalKeys.GLOBAL_VARIABLES集合的一部分。我可以很容易地找到與這些變量相對應的操作以及原始和副本之間的映射,但是我無法從中創建一個tf.Variable。

我發現了一個hacky解決方法來執行初始化,但它只適用於集合中的變量。

init_ops = [] 
for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): 
    if v.op in sgv.ops: 
    init_ops.append(info.transformed(v.initializer)) 

... 

session.run([tf.global_variables_initializer()] + init_ops) 

有沒有更好的方法來做到這一點?理想情況下,允許爲複製的變量創建tf.Variables將它們添加到全局變量集合中。或者,如果這是不可能的,至少有一種可靠的方法來獲得初始化器的操作,而不必查找原始網絡的tf.Variable對象。