2017-02-15 87 views
4

如何列出節點依賴的所有Tensorflow變量/常量/佔位符?如何列出節點依賴的所有Tensorflow變量?

實施例1(除了常數):

import tensorflow as tf 

a = tf.constant(1, name = 'a') 
b = tf.constant(3, name = 'b') 
c = tf.constant(9, name = 'c') 
d = tf.add(a, b, name='d') 
e = tf.add(d, c, name='e') 

sess = tf.Session() 
print(sess.run([d, e])) 

我想有一個功能list_dependencies()如:

  • list_dependencies(d)返回['a', 'b']
  • list_dependencies(e)返回['a', 'b', 'c']

實施例2(佔位符和權重矩陣之間的矩陣乘法,隨後加入的偏置矢量的):

tf.set_random_seed(1) 
input_size = 5 
output_size = 3 
input  = tf.placeholder(tf.float32, shape=[1, input_size], name='input') 
W   = tf.get_variable(
       "W", 
       shape=[input_size, output_size], 
       initializer=tf.contrib.layers.xavier_initializer()) 
b   = tf.get_variable(
       "b", 
       shape=[output_size], 
       initializer=tf.constant_initializer(2)) 
output  = tf.matmul(input, W, name="output") 
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias") 

sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]})) 

我想有一個功能,諸如list_dependencies()

  • list_dependencies(output)返回['W', 'input']
  • list_dependencies(output_bias)回報['W', 'b', 'input']

回答

6

這裏是我事業(從https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py

# computation flows from parents to children 

def parents(op): 
    return set(input.op for input in op.inputs) 

def children(op): 
    return set(op for out in op.outputs for op in out.consumers()) 

def get_graph(): 
    """Creates dictionary {node: {child1, child2, ..},..} for current 
    TensorFlow graph. Result is compatible with networkx/toposort""" 

    ops = tf.get_default_graph().get_operations() 
    return {op: children(op) for op in ops} 


def print_tf_graph(graph): 
    """Prints tensorflow graph in dictionary form.""" 
    for node in graph: 
    for child in graph[node]: 
     print("%s -> %s" % (node.name, child.name)) 

這些函數在ops上工作。要獲得產生張量t的操作,請使用t.op。要獲得通過運算op產生張量,使用op.outputs

+1

在[graph_util](https://cs.corp.google.com/piper///depot/google3/third_party/tensorflow/python/framework/graph_util_impl.py?q)中貢獻它可能是個好主意= file:third_party/tensorflow。* graph_util&sq = package:piper + file:// depot/google3 + -file:google3/experimental&dr&l = 110)或通過contrib。 – drpng

+0

似乎這個解決方案將返回圖中的所有子操作符,而不僅僅是特定節點的操作符。 –

1

Yaroslav Bulatov's answer是偉大的,我就添加使用雅羅斯拉夫的get_graph()children()方法一個繪圖功能:

import matplotlib.pyplot as plt 
import networkx as nx 
def plot_graph(G): 
    '''Plot a DAG using NetworkX'''   
    def mapping(node): 
     return node.name 
    G = nx.DiGraph(G) 
    nx.relabel_nodes(G, mapping, copy=False) 
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True) 
    plt.show() 

plot_graph(get_graph()) 

從問題繪製的例子1:

import matplotlib.pyplot as plt 
import networkx as nx 
import tensorflow as tf 

def children(op): 
    return set(op for out in op.outputs for op in out.consumers()) 

def get_graph(): 
    """Creates dictionary {node: {child1, child2, ..},..} for current 
    TensorFlow graph. Result is compatible with networkx/toposort""" 
    print('get_graph') 
    ops = tf.get_default_graph().get_operations() 
    return {op: children(op) for op in ops} 

def plot_graph(G): 
    '''Plot a DAG using NetworkX'''   
    def mapping(node): 
     return node.name 
    G = nx.DiGraph(G) 
    nx.relabel_nodes(G, mapping, copy=False) 
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True) 
    plt.show() 

a = tf.constant(1, name = 'a') 
b = tf.constant(3, name = 'b') 
c = tf.constant(9, name = 'c') 
d = tf.add(a, b, name='d') 
e = tf.add(d, c, name='e') 

sess = tf.Session() 
print(sess.run([d, e])) 
plot_graph(get_graph()) 

輸出:

enter image description here

繪製的例子2從問題:

enter image description here

如果您使用Microsoft Windows,你可能會碰到這個問題:Python Error (ValueError: _getfullpathname: embedded null character),在這種情況下,你需要修補matplotlib爲紐帶解釋。

+1

順便說一句,如果你使用jupyter,你也可以使用http://stackoverflow.com/questions/38189119/simple-way-to-visualize-a-tensorflow-graph-in-jupyter,它可以讓你摺疊一些節點 –

相關問題