I'm using slim.batch_norm from layers and trying to trace the code flow for my use case. It looks like the logic that decides whether to use _fused_batch_norm() or the base class will, in my case, only call _fused_batch_norm() if the input rank is 2. The code comment reads as though it should also be used when the rank is 4, and the function itself (_fused_batch_norm()) supports rank 4, but the logic seems to prevent it from ever being called in that case. Below is the snippet I'm referring to. Which TensorFlow batch-norm code path is used when the input rank is 4?
# Only use _fused_batch_norm (1) if fused is set True or if it is
# possible to use (currently it doesn't support batch weights,
# renorm, and the case when rank is neither 2 nor 4),
# and (2) if used with zero_debias_moving_mean, or an input shape of rank 2,
# or non-default updates_collections (not implemented in
# normalization_layers.BatchNormalization yet); otherwise use the fused
# implementation in normalization_layers.BatchNormalization.
inputs = ops.convert_to_tensor(inputs)
rank = inputs.get_shape().ndims
feature_supported = batch_weights is None and not renorm and rank in [2, 4]
possible_to_fuse = fused is None and feature_supported
if (fused or possible_to_fuse) and (
    zero_debias_moving_mean or rank == 2 or
    updates_collections is not ops.GraphKeys.UPDATE_OPS):
  return _fused_batch_norm(...)
For my use case, I have the following parameters, all at their default settings:
batch_weights=None
fused=False
renorm=False
zero_debias_moving_mean=False
updates_collections=ops.GraphKeys.UPDATE_OPS
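To check my reading of the branch, I re-evaluated the dispatch condition from the snippet above as a standalone sketch (my own reconstruction, not TensorFlow code; a plain string stands in for ops.GraphKeys.UPDATE_OPS so no import is needed):

```python
# Stand-in for ops.GraphKeys.UPDATE_OPS; identity comparison works because
# the default argument below is this very object.
UPDATE_OPS = "update_ops"

def uses_fused_fast_path(rank, fused=False, batch_weights=None, renorm=False,
                         zero_debias_moving_mean=False,
                         updates_collections=UPDATE_OPS):
    # Same boolean logic as the snippet above, re-typed for testing.
    feature_supported = batch_weights is None and not renorm and rank in [2, 4]
    possible_to_fuse = fused is None and feature_supported
    return bool((fused or possible_to_fuse) and (
        zero_debias_moving_mean or rank == 2 or
        updates_collections is not UPDATE_OPS))

print(uses_fused_fast_path(rank=2, fused=None))   # True: rank-2 input
print(uses_fused_fast_path(rank=4, fused=None))   # False: falls through
print(uses_fused_fast_path(rank=4, fused=False))  # False: my defaults
```

With everything at the defaults listed above, a rank-4 input never reaches the early `return _fused_batch_norm(...)`.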
If my input is rank 4, it looks like the code will instead use the fused implementation in normalization_layers.BatchNormalization. Is my understanding of the logic correct?
Is this the expected and appropriate behavior? I'm wondering whether the condition rank == 2 should actually be rank in [2, 4]. If the latter is correct, then this would be a potential bug. If the original is correct, then why is rank in [2, 4] used to determine feature_supported?
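To make the asymmetry I'm asking about concrete, here is a hypothetical side-by-side sketch. The rank_gate parameter is my own invention for comparison and does not exist in the TensorFlow source; batch_weights, renorm, zero_debias_moving_mean, and updates_collections are assumed to be at their defaults, so the second clause of the condition reduces to the rank gate alone:

```python
def takes_fused_fast_path(rank, fused=None, rank_gate=lambda r: r == 2):
    # feature_supported with batch_weights=None and renorm=False assumed.
    feature_supported = rank in [2, 4]
    possible_to_fuse = fused is None and feature_supported
    # With zero_debias_moving_mean=False and default updates_collections
    # assumed, only the rank gate can make the second clause true.
    return bool((fused or possible_to_fuse) and rank_gate(rank))

current  = takes_fused_fast_path(4, rank_gate=lambda r: r == 2)       # False
proposed = takes_fused_fast_path(4, rank_gate=lambda r: r in [2, 4])  # True
print(current, proposed)
```

So feature_supported admits rank 4, but the current gate then excludes it; the proposed gate would let a rank-4 input through to _fused_batch_norm().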