2017-08-25 53 views
1

我需要創建一個變量epsilon_n,該變量基於當前的step更改定義(和值)。由於我有兩個以上的情況,似乎我不能使用tf.cond。我試圖用tf.case如下:Tesnorflow:無法使用帶輸入參數的tf.case

import tensorflow as tf 

#### 
EPSILON_DELTA_PHASE1 = 33e-4 
EPSILON_DELTA_PHASE2 = 2.5 
#### 
step = tf.placeholder(dtype=tf.float32, shape=None) 


def fn1(step): 
    return tf.constant([1.]) 

def fn2(step): 
    return tf.constant([1.+step*EPSILON_DELTA_PHASE1]) 

def fn3(step): 
    return tf.constant([1.+step*EPSILON_DELTA_PHASE2]) 

epsilon_n = tf.case(
     pred_fn_pairs=[ 
      (tf.less(step, 3e4), lambda step: fn1(step)), 
      (tf.less(step, 6e4), lambda step: fn2(step)), 
      (tf.less(step, 1e5), lambda step: fn3(step))], 
      default=lambda: tf.constant([1e5]), 
     exclusive=False) 

不過,我不斷收到此錯誤信息:

TypeError: <lambda>() missing 1 required positional argument: 'step' 

我試過如下:

epsilon_n = tf.case(
     pred_fn_pairs=[ 
      (tf.less(step, 3e4), fn1), 
      (tf.less(step, 6e4), fn2), 
      (tf.less(step, 1e5), fn3)], 
      default=lambda: tf.constant([1e5]), 
     exclusive=False) 

我仍相同的錯誤。 Tensorflow文檔中的示例重點討論了沒有將輸入參數傳遞給可調用函數的情況。我無法在互聯網上找到關於tf.case的足夠信息!請幫忙嗎?

回答

2

這裏有幾個你需要做的改變。 爲了保持一致性,您可以將所有返回值設置爲變量。

# Since step is a scalar, scalar shape [() or [], not None] much be provided 
step = tf.placeholder(dtype=tf.float32, shape=()) 


def fn1(step): 
    return tf.constant([1.]) 

# Here you need to use Variable not constant, since you are modifying the value using placeholder 
def fn2(step): 
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE1]) 

def fn3(step): 
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE2]) 

epsilon_n = tf.case(
    pred_fn_pairs=[ 
     (tf.less(step, 3e4), lambda : fn1(step)), 
     (tf.less(step, 6e4), lambda : fn2(step)), 
     (tf.less(step, 1e5), lambda : fn3(step))], 
     default=lambda: tf.constant([1e5]), 
    exclusive=False) 
+0

修正輕微錯字 –