2017-08-07 27 views
1

給定一些要提取的符號變量,我需要知道哪些佔位符是依賴項。如何確定TensorFlow中的佔位符依賴關係

在Theano,我們有:

import theano as th 
import theano.tensor as T 

x, y, z = T.scalars('xyz') 
u, v = x*y, y*z 
w = u + v 

th.gof.graph.inputs([w]) # gives [x, y, z] 
th.gof.graph.inputs([u]) # gives [x, y] 
th.gof.graph.inputs([v]) # gives [y, z] 
th.gof.graph.inputs([u, v]) # gives [x, y, z] 

如何做同樣的事情在TensorFlow?

回答

1

這裏沒有一個內置函數(即我所知道的),但它很容易使一個:

# Setup a graph 
import tensorflow as tf 
placeholder0 = tf.placeholder(tf.float32, []) 
placeholder1 = tf.placeholder(tf.float32, []) 
constant0 = tf.constant(2.0) 
sum0 = tf.add(placeholder0, constant0) 
sum1 = tf.add(placeholder1, sum0) 

# Function to get *all* dependencies of a tensor. 
def get_dependencies(tensor): 
    dependencies = set() 
    dependencies.update(tensor.op.inputs) 
    for sub_op in tensor.op.inputs: 
     dependencies.update(get_dependencies(sub_op)) 
    return dependencies 

print(get_dependencies(sum0)) 
print(get_dependencies(sum1)) 
# Filter on type to get placeholders. 
print([tensor for tensor in get_dependencies(sum0) if tensor.op.type == 'Placeholder']) 
print([tensor for tensor in get_dependencies(sum1) if tensor.op.type == 'Placeholder']) 

當然,你可以扔佔位過濾到功能以及。

+0

這種遞歸在圖形等「斐波那契」上可能效率很低,但確實是一個很好的起點。 – Kh40tiK

相關問題