1
我需要創建一個變量epsilon_n
,該變量基於當前的step
更改定義(和值)。由於我有兩個以上的情況,似乎我不能使用tf.cond
。我試圖用tf.case
如下:Tesnorflow:無法使用帶輸入參數的tf.case
import tensorflow as tf
####
EPSILON_DELTA_PHASE1 = 33e-4
EPSILON_DELTA_PHASE2 = 2.5
####
step = tf.placeholder(dtype=tf.float32, shape=None)
def fn1(step):
return tf.constant([1.])
def fn2(step):
return tf.constant([1.+step*EPSILON_DELTA_PHASE1])
def fn3(step):
return tf.constant([1.+step*EPSILON_DELTA_PHASE2])
epsilon_n = tf.case(
pred_fn_pairs=[
(tf.less(step, 3e4), lambda step: fn1(step)),
(tf.less(step, 6e4), lambda step: fn2(step)),
(tf.less(step, 1e5), lambda step: fn3(step))],
default=lambda: tf.constant([1e5]),
exclusive=False)
不過,我不斷收到此錯誤信息:
TypeError: <lambda>() missing 1 required positional argument: 'step'
我試過如下:
epsilon_n = tf.case(
pred_fn_pairs=[
(tf.less(step, 3e4), fn1),
(tf.less(step, 6e4), fn2),
(tf.less(step, 1e5), fn3)],
default=lambda: tf.constant([1e5]),
exclusive=False)
我仍相同的錯誤。 Tensorflow文檔中的示例重點討論了沒有將輸入參數傳遞給可調用函數的情況。我無法在互聯網上找到關於tf.case的足夠信息!請幫忙嗎?
修正輕微錯字 –