Keras 的 BatchNormalization 和 PyTorch 的 BatchNorm2d 的区别?

Posted

技术标签:

【中文标题】Keras 的 BatchNormalization 和 PyTorch 的 BatchNorm2d 的区别?【英文标题】:Difference between Keras' BatchNormalization and PyTorch's BatchNorm2d? 【发布时间】:2020-05-21 14:28:21 【问题描述】:

我有一个在 Keras 和 PyTorch 中实现的微型 CNN 示例。当我打印两个网络的摘要时,可训练参数的总数相同,但参数总数和批量标准化的参数数不匹配。

这是 Keras 中的 CNN 实现:

inputs = Input(shape = (64, 64, 1)). # Channel Last: (NHWC)

model = Conv2D(filters=32, kernel_size=(3, 3), padding='SAME', activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 1))(inputs)
model = BatchNormalization(momentum=0.15, axis=-1)(model)
model = Flatten()(model)

dense = Dense(100, activation = "relu")(model)
head_root = Dense(10, activation = 'softmax')(dense)

上面的模型打印的总结是:

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_9 (InputLayer)         (None, 64, 64, 1)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 64, 64, 32)        320       
_________________________________________________________________
batch_normalization_2 (Batch (None, 64, 64, 32)        128       
_________________________________________________________________
flatten_3 (Flatten)          (None, 131072)            0         
_________________________________________________________________
dense_11 (Dense)             (None, 100)               13107300  
_________________________________________________________________
dense_12 (Dense)             (None, 10)                1010      
=================================================================
Total params: 13,108,758
Trainable params: 13,108,694
Non-trainable params: 64
_________________________________________________________________

下面是 PyTorch 中相同模型架构的实现:

# Image format: Channel first (NCHW) in PyTorch
class CustomModel(nn.Module):
def __init__(self):
    super(CustomModel, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=1),
        nn.ReLU(True),
        nn.BatchNorm2d(num_features=32),
    )
    self.flatten = nn.Flatten()
    self.fc1 = nn.Linear(in_features=131072, out_features=100)
    self.fc2 = nn.Linear(in_features=100, out_features=10)

def forward(self, x):
    output = self.layer1(x)
    output = self.flatten(output)
    output = self.fc1(output)
    output = self.fc2(output)
    return output

以下是上述模型的总结输出:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 64, 64]             320
              ReLU-2           [-1, 32, 64, 64]               0
       BatchNorm2d-3           [-1, 32, 64, 64]              64
           Flatten-4               [-1, 131072]               0
            Linear-5                  [-1, 100]      13,107,300
            Linear-6                   [-1, 10]           1,010
================================================================
Total params: 13,108,694
Trainable params: 13,108,694
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 4.00
Params size (MB): 50.01
Estimated Total Size (MB): 54.02
----------------------------------------------------------------

正如您在上面的结果中看到的,Keras 中的批量标准化比 PyTorch 具有更多的参数(准确地说是 2 倍)。那么上述 CNN 架构有什么区别呢?如果它们是等价的,那么我在这里缺少什么?

【问题讨论】:

【参考方案1】:

Keras 将许多将在层中“保存/加载”的东西视为参数(权重)。

虽然这两种实现自然具有批次的累积“均值”和“方差”,但这些值无法通过反向传播进行训练。

尽管如此,这些值每批都会更新,Keras 将它们视为不可训练的权重,而 PyTorch 只是将它们隐藏起来。此处的“不可训练”一词的意思是“通过反向传播不可训练”,但并不意味着这些值被冻结。

对于BatchNormalization 层,它们总共是 4 组“权重”。考虑到选定的轴(默认 = -1,层大小 = 32)

scale (32) - 可训练 offset (32) - 可训练 accumulated means (32) - 不可训练,但每批次更新 accumulated std (32) - 不可训练,但每批次更新

在 Keras 中这样做的好处是,当您保存图层时,您还可以保存平均值和方差值,就像您自动保存图层中的所有其他权重一样。当你加载图层时,这些权重会一起加载。

【讨论】:

以上是关于Keras 的 BatchNormalization 和 PyTorch 的 BatchNorm2d 的区别?的主要内容,如果未能解决你的问题,请参考以下文章

keras与tensorflow.python.keras - 使用哪一个?

keras是啥

为啥keras安装以后导入失败?

keras 与 tensorflow.python.keras - 使用哪一个?

Tensorflow+Keras用Tensorflow.keras的方法替代keras.layers.merge

Tensorflow+Keras用Tensorflow.keras的方法替代keras.layers.merge