1

我想嘗試一些其他傳輸函數,除了TensorFlow的BasicRNNCell中的默認tanhTensorFlow:將BasicRNNCell的tanh改爲另一個op?

原來的實現是這樣的:

class BasicRNNCell(RNNCell): 
(...) 
def __call__(self, inputs, state, scope=None): 
    """Most basic RNN: output = new_state = tanh(W * input + U * state + B).""" 
    with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell" 
     output = tanh(linear([inputs, state], self._num_units, True)) 
    return output, output 

...我把它改爲:

class MyRNNCell(BasicRNNCell): 
(...) 
def __call__(self, inputs, state, scope=None): 
    """Most basic RNN: output = new_state = tanh(W * input + U * state + B).""" 
    with tf.variable_scope(scope or type(self).__name__): # "BasicRNNCell" 
     output = my_transfer_function(linear([inputs, state], self._num_units, True)) 
    return output, output 

更改vs.variable_scopetf.variable_scope,是成功的,但linear是>rnn_cell.py實現<,而不是在tf本身。

我該如何得到這個工作?

我是否必須完全重新實施linear? (我已經檢查了代碼,我想我也會遇到依賴關係問題......)

+0

我不熟悉張量流,但tensorflow.python.ops.rnn_cell不是內建/基類。您應該能夠擴展該類或將「新」方法「衝入」現有類,以便該類中的所有依賴項都可用。所以,我不知道你爲什麼希望將vsvariable_scope改爲tf.variable_scope。你能解釋一下你從tf.variable_scope中需要什麼嗎?例如:如何實現「my_transfer_function」? – user2133679

+0

'vs'不可見,因此運行腳本時會得到'全局名稱vs未定義'。這就是爲什麼我將'vs'改爲'tf'的原因,因爲'tf'中存在'variable_scope'方法。然而,'linear'是該文件中的一個本地函數,因此它會在運行我的腳本時抱怨'linear'缺少。因此,我的問題是我將如何滿足這種依賴性? – daniel451

回答

2

您不需要爲此更改張量流實現的代碼。

BasicRNNCell具有一個稱爲激活函數的參數。您只需簡單地將其從tf.tanh更改爲您想要的任何激活功能即可。

相關問題