2017-06-12 72 views
1

我想使用Keras-rl庫(https://github.com/matthiasplappert/keras-rl/blob/master/examples/dqn_atari.py)中的DQN代碼和3通道圖像(無灰度轉換)。3通道圖像的卷積層

如何更改代碼來做到這一點?我試圖刪除此行:img = img.resize(INPUT_SHAPE).convert('L') # resize and convert to grayscale但我有尺寸誤差..

我加入from keras import backend as K K.set_image_dim_ordering('th') 我更改網絡achitecture:

model = Sequential() 
model.add(Permute((3, 1, 2),input_shape=(200,200,3))) 
model.add(Lambda(lambda a: a/255.0)) 
model.add(Conv2D(32, (8, 8), strides=(2, 2), activation=activation)) 
model.add(Conv2D(32, (4, 4), strides=(2, 2), activation=activation)) 
model.add(Conv2D(32, (3, 3), strides=(2, 2), activation=activation)) 
model.add(Conv2D(32, (2, 2), strides=(1, 1), activation=activation)) 
model.add(TimeDistributed(Flatten())) 
model.add(LSTM(128)) 
for i in xrange(nb_layers): 
    model.add(Dense(hidden_size, activation=activation)) 
model.add(Dense(env.action_space.n + 1)) 
model.add(Lambda(lambda a: K.expand_dims(a[:, 0], axis=-1) + a[:, 1:], output_shape=(env.action_space.n,))) 
print(model.summary()) 

但在運行時:

ValueError: Error when checking : expected permute_1_input to have 4 dimensions, but got array with shape (1, 1, 200, 200, 3)

如何處理3通道圖像?

+0

用來訓練模型的圖像陣列的形狀是什麼? –

+0

圖像的形狀是(84,84,3) –

+0

這是不可能的,因爲您的input_Shape是(200,200,3)。因此,你的圖像數組的輸入應該是(x,200,200,3)。 x代表圖像的數量。 –

回答

0

我在發佈的代碼中看不到任何錯誤,但我的猜測是,當您刪除調整行時,意外地在上一行初始化圖像數組的末尾添加了,。 Python會將img = abc(),解析爲img = (abc(),)。該錯誤與您的5維形狀相匹配。