2017-03-01 126 views
4

我有以下代碼。與BatchNormalization圖層相關的參數數量是多少?

x = keras.layers.Input(batch_shape = (None, 4096)) 
hidden = keras.layers.Dense(512, activation = 'relu')(x) 
hidden = keras.layers.BatchNormalization()(hidden) 
hidden = keras.layers.Dropout(0.5)(hidden) 
predictions = keras.layers.Dense(80, activation = 'sigmoid')(hidden) 
mlp_model = keras.models.Model(input = [x], output = [predictions]) 
mlp_model.summary() 

而這是模型總結:

____________________________________________________________________________________________________ 
Layer (type)      Output Shape   Param #  Connected to      
==================================================================================================== 
input_3 (InputLayer)    (None, 4096)   0            
____________________________________________________________________________________________________ 
dense_1 (Dense)     (None, 512)   2097664  input_3[0][0]      
____________________________________________________________________________________________________ 
batchnormalization_1 (BatchNorma (None, 512)   2048  dense_1[0][0]      
____________________________________________________________________________________________________ 
dropout_1 (Dropout)    (None, 512)   0   batchnormalization_1[0][0]  
____________________________________________________________________________________________________ 
dense_2 (Dense)     (None, 80)   41040  dropout_1[0][0]     
==================================================================================================== 
Total params: 2,140,752 
Trainable params: 2,139,728 
Non-trainable params: 1,024 
____________________________________________________________________________________________________ 

其輸入爲BatchNormalization(BN)層的尺寸爲512根據Keras documentation,輸出爲BN層的形狀是相同輸入是512.

那麼與BN層相關的參數數量是多少?

+0

有什麼理由不滿意這個問題? –

回答

5

Keras中的批處理標準化實現this paper

正如你可以在那裏閱讀的那樣,爲了使批處理標準化在訓練過程中工作,他們需要跟蹤每個標準化尺寸的分佈。要這樣做,因爲默認情況下你在mode=0之內,所以它們在上一層計算每個特徵的4個參數。這些參數確保您正確傳播和反向傳播信息。

所以4*512 = 2048,這應該回答你的問題。

1

這些2048參數實際上是[gamma weights, beta weights, moving_mean(non-trainable), moving_variance(non-trainable)],每個參數都有512個元素(輸入層的大小)。

相關問題