2016-05-24 27 views
1

我有一個奇怪的錯誤,在編譯Theano中的掃描運算符時我無法理解。 當outputs_info與最後的尺寸等於一個初始化的,我得到這個錯誤:當輸出的最後尺寸等於1時,Theano中掃描的奇怪行爲

TypeError: ('The following error happened while compiling the node', forall_inplace,cpu, 
scan_fn}(TensorConstant{4}, IncSubtensor{InplaceSet;:int64:}.0, <TensorType(float32, vector)>), 
'\n', "Inconsistency in the inner graph of scan 'scan_fn' : an input and an output are 
associated with the same recurrent state and should have the same type but have type 
'TensorType(float32, (True,))' and 'TensorType(float32, vector)' respectively.") 

而我如果尺寸設置爲任何大於一沒有得到任何錯誤。

這個錯誤發生在gpu和cpu目標上,其中theano 0.7,0.8.0和0.8.2。

這裏是一段代碼重現該錯誤:

import theano 
import theano.tensor as T 
import numpy as np 

def rec_fun(prev_output, bias):   
    return prev_output + bias 

n_steps = 4 

# with state_size>1, compilation runs smoothly 
state_size = 2 

bias = theano.shared(np.ones((state_size),dtype=theano.config.floatX)) 
(outputs, updates) = theano.scan(fn=rec_fun, 
           sequences=[], 
           outputs_info=T.zeros([state_size,]), 
           non_sequences=[bias], 
           n_steps=n_steps 
          ) 
print outputs.eval() 

# with state_size==1, compilation fails 
state_size = 1 

bias = theano.shared(np.ones((state_size),dtype=theano.config.floatX)) 
(outputs, updates) = theano.scan(fn=rec_fun, 
           sequences=[], 
           outputs_info=T.zeros([state_size,]), 
           non_sequences=[bias], 
           n_steps=n_steps 
          ) 
# compilation fails here 
print outputs.eval() 

彙編具有依賴於「state_size」,從而不同的行爲。 是否有解決方案來處理case_size == 1和state_size> 1?

回答

0

更改

outputs_info=T.zeros([state_size,]) 

outputs_info=T.zeros_like(bias) 

使得它的state_size == 1的情況下正常工作。

次要的解釋和不同的解決方案

所以我注意到這兩種情況之間的重要區別。 在這兩種情況下,都將這些代碼完全添加到偏差宣告行之後。

bias = .... 
print bias.broadcastable 
print T.zeros([state_size,]).broadcastable 

的結果是

因爲你的代碼工作

(False,) 
(False,) 

而對於它似乎打破

(False,) 
(True,) 

到底發生了第二種情況中,第一種情況是當你添加了兩個相同維度(偏差和T.zeros)但張寬不同的張量castable模式,結果繼承的模式是偏離的模式。這最終導致theano錯誤地認定它們不是同一類型。

T.zeros_like適用,因爲它使用bias變量來生成零張量。

另一種方式來解決你的問題是要改變廣播模式,像這樣

outputs_info=T.patternbroadcast(T.zeros([state_size,]), (False,)), 
+0

謝謝!這似乎是一個很好的解決方法,即使在我的特殊情況下,ouputs_info事實上依賴於幾個形狀參數,它使得代碼不易讀。我想解釋爲什麼直接輸出信息output_info = T.zeros([1,])解決方案失敗... –

+0

我編輯了答案,以便它包含解釋和另一個解決方案,以幫助您瞭解問題 –

+0

謝謝非常感謝您的幫助!現在我也更清楚了。顯式設置廣播模式使代碼更具可讀性。謝謝 ! –