2017-06-29 36 views
0

我已經搜索並發現了類似的問題,但沒有一個看起來與我面臨的問題相同。我正在嘗試使用Theano後端實現一個Keras神經網絡(兩者都是最新的),它涉及一個採用圖層的一維輸出的Lambda圖層,並將它轉換爲一個n維向量, d輸出重複n次。添加具有不同尺寸輸入和輸出的Keras Lambda圖層時的廣播問題

我似乎運行到的問題是,在拉姆達層Keras似乎是期待的是,輸入具有相同的尺寸作爲輸出形狀我指定:

x=Input(shape=(2,)) 
V1=Dense(1)(x) 
V2=Lambda(lambda B : B[0,0]*K.ones((3,)),output_shape=(3,))(V1) 
model=Model(inputs=x,outputs=V2) 
rms = RMSprop() 
model.compile(loss='mse', optimizer=rms) 
model.predict(np.array([1,2]).reshape((1,2))) 

這給這錯誤:

--------------------------------------------------------------------------- 
ValueError        Traceback (most recent call last) 
<ipython-input-7-40a7e91d5963> in <module>() 
----> 1 model.predict(np.array([1,2]).reshape((1,2))) 

/Users/user/anaconda/envs/py35/lib/python3.5/site-packages/keras/engine /training.py in predict(self, x, batch_size, verbose) 
    1504   f = self.predict_function 
    1505   return self._predict_loop(f, ins, 
-> 1506         batch_size=batch_size, verbose=verbose) 
    1507 
    1508  def train_on_batch(self, x, y, 

/Users/user/anaconda/envs/py35/lib/python3.5/site-packages/keras/engine/training.py in _predict_loop(self, f, ins, batch_size, verbose) 
    1137 
    1138    for i, batch_out in enumerate(batch_outs): 
-> 1139     outs[i][batch_start:batch_end] = batch_out 
    1140    if verbose == 1: 
    1141     progbar.update(batch_end) 

ValueError: could not broadcast input array from shape (3) into shape (1) 

我知道有其他的方法,試圖做到這一點(K.repeat_elements),但這也給了我關於廣播錯誤消息。請注意,即使我刪除了B[0,0]*(這樣Lambda層完全不依賴於B),問題仍然存在。如果我將K.onesoutput_shape中的(3,)更改爲(1,),那麼它似乎工作。

據我所知,Lambda圖層應該能夠處理不同維度的輸入/輸出對,這是否正確?

回答

0

output_shape中,您不考慮批量大小。所以這是正確的:(3,)

但是在張量中,批量大小不會被忽略。您的表達式結果至少需要兩個維度:(Batch_size,3)。

此外,不要使用張量元素,使用整個張量。我還沒有發現它會是重要的或有用的使用單獨的元素(因爲你應該做同樣的操作,以整批)

我建議你使用K.repeat_elements(B, rep=3, axis=-1)

+0

非常有幫助的情況下,謝謝。我認爲對於我來說混亂的一點是,在'output_shape'中,批量大小佔據了元組中的第二個點,而在lambda函數中定義'K.ones'張量時,批量大小在第一個時隙中。你能給我一個這個爲什麼嗎? 考慮到這一點,您的建議都可以工作(重塑'K.repeat_elements'張量後)。謝謝! – hughes

+0

不,批量大小始終佔據第一位。我不知道爲什麼'(3,)'中有一個逗號,但這並不意味着任何有用的東西。當你只有一個維度時,你只需要輸入逗號。如果有(無,3),則將其定義爲(3,)。如果有(無,3,4),則將其定義爲(3,4)。 - 這可能只是一個符號問題,也許有必要至少有一個逗號來創建一個元組? –

+0

當答案對你有用時,考慮將其標記爲回答:) - 這有助於其他用戶搜索答案時。 –