2017-03-27 81 views
1

我有以下功能:Keras lambda函數積mistmatch

def transpose_dot(vects): 
    x, y = vects 
    # <x,x> + <y,y> - 2<x,y> 

    return K.dot(x, K.transpose(y)) 

當嘗試使用keras評估它,它的工作原理

x = K.variable(np.array(np_x)) 
y = K.variable(np.array(np_x)) 
obj = transpose_dot 
objective_output = obj((x, y)) 
print('-----------------') 
print (K.eval(objective_output)) 

結果有:

[[ 1. 1. 1. 2.] 
[ 1. 2. 2. 4.] 
[ 1. 2. 2. 4.] 
[ 2. 4. 4. 8.] 

,當試圖使用它作爲的功能層它不起作用。

np_x = [[1, 0], [1, 1], [1, 1], [2, 2]] 
features = np.array([np_x]) 
test_input = Input(shape=np.array(np_x).shape) 
dot_layer= Lambda(transpose_dot, output_shape=(4,4))([test_input, test_input]) 
x = Model(inputs=test_input, outputs=dot_layer) 
x.predict(features, batch_size=1) 

self.fn() if output_subset is None else\ 
ValueError: Shape mismatch: x has 2 cols (and 4 rows) but y has 4 rows (and 2 cols) 
Apply node that caused the error: Dot22(Reshape{2}.0, Reshape{2}.0) 
Toposort index: 11 
Inputs types: [TensorType(float32, matrix), TensorType(float32, matrix)] 
Inputs shapes: [(4, 2), (4, 2)] 
Inputs strides: [(8, 4), (8, 4)] 
Inputs values: ['not shown', 'not shown'] 
Outputs clients: [[Reshape{4}(Dot22.0, MakeVector{dtype='int64'}.0)]] 

任何想法,我在這裏失蹤的結果?

編輯:錯誤消息

+0

什麼是錯誤信息?你有一個)在Lambda行中失蹤... –

+0

@NassimBen,我添加了錯誤信息,基本上它抱怨形狀,但'x有2列(和4行),但y有4行(和2列) ' – oak

回答

0

的 新增輸出與人的幫助下,在https://github.com/fchollet/keras/,我發現我的錯誤。 函數期望得到(n,m)。但是,當使用Lambda函數時,它期望得到(樣本,n,m)。