如何列出節點依賴的所有Tensorflow變量/常量/佔位符?如何列出節點依賴的所有Tensorflow變量?
實施例1(除了常數):
import tensorflow as tf
a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')
sess = tf.Session()
print(sess.run([d, e]))
我想有一個功能list_dependencies()
如:
list_dependencies(d)
返回['a', 'b']
list_dependencies(e)
返回['a', 'b', 'c']
實施例2(佔位符和權重矩陣之間的矩陣乘法,隨後加入的偏置矢量的):
tf.set_random_seed(1)
input_size = 5
output_size = 3
input = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W = tf.get_variable(
"W",
shape=[input_size, output_size],
initializer=tf.contrib.layers.xavier_initializer())
b = tf.get_variable(
"b",
shape=[output_size],
initializer=tf.constant_initializer(2))
output = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]}))
我想有一個功能,諸如list_dependencies()
:
list_dependencies(output)
返回['W', 'input']
list_dependencies(output_bias)
回報['W', 'b', 'input']
在[graph_util](https://cs.corp.google.com/piper///depot/google3/third_party/tensorflow/python/framework/graph_util_impl.py?q)中貢獻它可能是個好主意= file:third_party/tensorflow。* graph_util&sq = package:piper + file:// depot/google3 + -file:google3/experimental&dr&l = 110)或通過contrib。 – drpng
似乎這個解決方案將返回圖中的所有子操作符,而不僅僅是特定節點的操作符。 –