您應該通過set_weights方法將一個numpy數組傳遞給卷積圖層。
請記住,卷積層的權重不僅是每個單獨過濾器的權重,還有偏差。所以如果你想設置你的權重,你需要添加一個額外的維度。
例如,如果你想設置一個1x3x3過濾器與所有的權重,除了中央零元素,你應該讓:
w = np.asarray([
[[[
[0,0,0],
[0,2,0],
[0,0,0]
]]]
])
然後將其設置。
有關代碼可以運行:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import numpy as np
np.random.seed(1234)
from keras.layers import Input
from keras.layers.convolutional import Convolution2D
from keras.models import Model
print("Building Model...")
inp = Input(shape=(1,None,None))
output = Convolution2D(1, 3, 3, border_mode='same', init='normal',bias=False)(inp)
model_network = Model(input=inp, output=output)
print("Weights before change:")
print (model_network.layers[1].get_weights())
w = np.asarray([
[[[
[0,0,0],
[0,2,0],
[0,0,0]
]]]
])
input_mat = np.asarray([
[[
[1.,2.,3.],
[4.,5.,6.],
[7.,8.,9.]
]]
])
model_network.layers[1].set_weights(w)
print("Weights after change:")
print(model_network.layers[1].get_weights())
print("Input:")
print(input_mat)
print("Output:")
print(model_network.predict(input_mat))
嘗試在卷積fillter改變中心元件(在實施例2)。
代碼的作用:
起初建立一個模型。
inp = Input(shape=(1,None,None))
output = Convolution2D(1, 3, 3, border_mode='same', init='normal',bias=False)(inp)
model_network = Model(input=inp, output=output)
打印原來的權重(初始化符合正態分佈時,init = '正常')
print (model_network.layers[1].get_weights())
創建你想要的體重張量w和一些輸入input_mat
w = np.asarray([
[[[
[0,0,0],
[0,2,0],
[0,0,0]
]]]
])
input_mat = np.asarray([
[[
[1.,2.,3.],
[4.,5.,6.],
[7.,8.,9.]
]]
])
設定權重並打印出來
model_network.layers[1].set_weights(w)
print("Weights after change:")
print(model_network.layers[1].get_weights())
最後,用它來生成與輸出預測(預測自動編譯模型)
print(model_network.predict(input_mat))
輸出示例:
Using Theano backend.
Building Model...
Weights before change:
[array([[[[ 0.02357176, -0.05954878, 0.07163535],
[-0.01563259, -0.03602944, 0.04435815],
[ 0.04297942, -0.03182618, 0.00078482]]]], dtype=float32)]
Weights after change:
[array([[[[ 0., 0., 0.],
[ 0., 2., 0.],
[ 0., 0., 0.]]]], dtype=float32)]
Input:
[[[[ 1. 2. 3.]
[ 4. 5. 6.]
[ 7. 8. 9.]]]]
Output:
[[[[ 2. 4. 6.]
[ 8. 10. 12.]
[ 14. 16. 18.]]]]
哦,謝謝!我不清楚這一點。該文檔沒有說明權重形狀的確切要求。感謝你的例子! – displayname
僅供參考:https://github.com/fchollet/keras/issues/1671 – maz