2016-08-28 49 views
2

我的問題是有關這個Tensorflow: How to get a tensor by name?TensorFlow:如何命名爲操作tf.get_variable

我可以給名稱運營。但實際上他們的名字不同。 例如:

In [11]: with tf.variable_scope('test_scope') as scope: 
    ...:  a = tf.get_variable('a',[1]) 
    ...:  b = tf.maximum(1,2, name='b') 
    ...:  print a.name 
    ...:  print b.name 
    ...:  
    ...:  
    ...:  
test_scope/a:0 
test_scope_1/b:0 

In [12]: with tf.variable_scope('test_scope') as scope: 
    ...:  scope.reuse_variables() 
    ...:  a = tf.get_variable('a',[1]) 
    ...:  b = tf.maximum(1,2, name='b') 
    ...:  print a.name 
    ...:  print b.name 
    ...:  
    ...:  
    ...:  
test_scope/a:0 
test_scope_2/b:0 

tf.get_variable具有完全相同的名稱創建變量我問。操作爲範圍添加前綴。

我想命名我的操作,以便我可以得到它。在我的情況下,我想在我的範圍內獲得btf.get_variable('b')

我該怎麼辦?我不能這樣做tf.Variable由於這個問題https://github.com/tensorflow/tensorflow/issues/1325 可能是我需要設置變量範圍,或操作,或以某種方式使用tf.get_variable增加參數?

回答

3

tf.get_variable()將無法​​正常工作。因此,我將定義一個新的變量存儲tf.maximum(1,2)再取回:

import tensorflow as tf 

with tf.variable_scope('test_scope') as scope: 
    a1 = tf.get_variable('a', [1]) 
    b1 = tf.get_variable('b', initializer=tf.maximum(1, 2)) 

with tf.variable_scope('test_scope') as scope: 
    scope.reuse_variables() 
    a2 = tf.get_variable('a', [1]) 
    b2 = tf.get_variable('b', dtype=tf.int32) 

assert a1 == a2 
assert b1 == b2 

請注意,你需要爲了使用tf.get_variable()再取回定義b

+0

謝謝!我不確定我完全理解初始化程序是什麼。它是否適用於可訓練和不可訓練的變量? – ckorzhik

+0

不客氣!初始化器只是一個將變量設置爲初始值的操作。是的,它應該適用於可訓練和不可訓練的變量。對於可訓練變量,需要從tf.get_variable()中設置參數'trainable = True',以將該變量添加到可訓練變量的集合中。 – rvinas

3

我不同意@rvinas的回答,你不需要創建一個變量來保存你想要檢索的張量的值。你可以只用graph.get_tensor_by_name使用正確的名稱檢索您的張量:

with tf.variable_scope('test_scope') as scope: 
    a = tf.get_variable('a',[1]) 
    b = tf.maximum(1,2, name='b') 

print a.name # should print 'test_scope/a:0' 
print b.name # should print 'test_scope/b:0' 

現在要重新創建相同的範圍內,並取回ab
對於b,你甚至不需要在範圍內,你只需要確切的名字b

with tf.variable_scope('test_scope') as scope: 
    scope.reuse_variables() 
    a2 = tf.get_variable('a', [1]) 

graph = tf.get_default_graph() 
b2 = graph.get_tensor_by_name('test_scope/b:0') 

assert a == a2 
assert b == b2 
+0

您的解決方案看起來很好。只是澄清,我存儲的張量的價值,因爲:「在我的情況下,我想在我的範圍內獲得b與tf.get_variable('b')」 – rvinas

+0

這兩個解決方案的工作,但現在我不知道是哪一個更'張量流'的方式。我剛剛開始用tensorflow編碼,不知道哪些模式好,哪些不好。在你的情況下,我需要維持內部張量變量名稱,在@ rvinas的情況下,我只需要我的變量名稱。現在我以這種方式構建代碼,只需要python變量。我需要在他們的回購中讀取良好的tensorflow代碼,閱讀tensorflow白皮書和教程。可能在此之後,對我而言,特定情況下哪種方式更好。 – ckorzhik