2017-02-06 64 views
3

我有n(例如:n = 3)範圍和x(例如:x = 4)每個範圍中定義的變量號。 的範圍是:tf.get_collection提取一個範圍的變量

model/generator_0 
model/generator_1 
model/generator_2 

一旦我計算的損失,我想提取,並提供一切僅從範圍的一個基於在運行時一個標準的變量。因此範圍idx,我選擇的指數澆鑄成INT32

<tf.Tensor 'model/Cast:0' shape=() dtype=int32> 

的argmin張我已經嘗試:

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'model/generator_'+tf.cast(idx, tf.string)) 

這顯然沒有奏效。 有沒有辦法讓所有屬於該特定範圍的變量都使用idx傳入優化器。

在此先感謝!

維涅什斯里尼瓦桑

回答

3

你可以做這樣的事情在TF 1.0 RC1或更高版本:

v = tf.Variable(tf.ones(())) 
loss = tf.identity(v) 
with tf.variable_scope('adamoptim') as vs: 
    optim = tf.train.AdamOptimizer(learning_rate=0.1).minimize(loss) 
optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name) 
print([v.name for v in optim_vars]) #=> prints lists of vars created 
+0

不幸的是,optim_vars包含所有範圍的所有變量。我的問題是,我有 〜 與tf.variable_scope( 'generator_0'): V1 = tf.Variable(tf.ones(())) loss1 = tf.identity(V1) 與tf.variable_scope( 'generator_1'): v2 = tf.Variable(tf.ones(())) loss2 = tf.identity(v2) 〜 然後我想找到tf.minimum(loss,loss2)並傳遞變量屬於優化者的損失。 – viggie

+0

我剛剛在最新版本(夜間)中測試過,'optim_vars'只包含來自'adapoptim'範圍的變量 –

+0

您正在使用我之前不知道的定義的相同損耗。我必須找到'min_idx = tf.cast(tf.argmin(tf.stack([loss1,loss2]),0),tf.int32)',然後傳遞這個'loss = tf.stack([loss1,loss2 ])[min_idx]'到優化器。這就是爲什麼我得到所有使用你的建議定義的變量。 – viggie