你能用 BatchNormalization 解释神经网络中的 Keras get_weights() 函数吗?

Posted

技术标签:

【中文标题】你能用 BatchNormalization 解释神经网络中的 Keras get_weights() 函数吗?【英文标题】:Can you explain Keras get_weights() function in a Neural Network with BatchNormalization? 【发布时间】:2019-11-26 23:05:00 【问题描述】:

当我在 Keras 中运行神经网络(没有 BatchNormalization)时,我了解 get_weights() 函数如何提供 NN 的权重和偏差。但是使用 BatchNorm 它会产生 4 个额外的参数,我假设 Gamma、Beta、Mean 和 Std。

当我保存这些值时,我尝试手动复制一个简单的 NN,但无法让它们产生正确的输出。有谁知道这些值是如何工作的?

No Batch Norm

With Batch Norm

【问题讨论】:

【参考方案1】:

我将举一个例子来解释 get_weights() 在简单的多层感知器 (MLP) 和带有 Batch Normalization (BN) 的 MLP 的情况下。

示例:假设我们正在研究 MNIST 数据集,并使用 2 层 MLP 架构(即 2 个隐藏层)。隐藏层 1 中的神经元数量为 392,隐藏层 2 中的神经元数量为 196。所以我们的 MLP 的最终架构将是 784 x 512 x 196 x 10

这里784是输入图像维度,10是输出层维度

Case1: MLP without Batch Normalization => 让我的模型名为 model_relu,它使用 ReLU 激活函数。现在在训练 model_relu 之后,我正在使用 get_weights(),这将返回一个大小为 6 的列表,如下面的屏幕截图所示。

get_weights() with simple MLP and without Batch Norm列表值如下:

(784, 392):隐藏层 1 的权重

(392,):与隐藏层 1 权重相关的偏差

(392, 196):隐藏层 2 的权重

(196,):与隐藏层 2 权重相关的偏差

(196, 10):输出层的权重

(10,):与输出层权重相关的偏差

Case2: MLP with Batch Normalization => 让我的模型名称为 model_batch,它也使用 ReLU 激活函数和 Batch Normalization。现在在训练 model_batch 之后,我正在使用 get_weights(),这将返回一个大小为 14 的列表,如下面的屏幕截图所示。

get_weights() with Batch Norm 列表值如下:

(784, 392): 隐藏层 1 的权重 (392,): 与隐藏层 1 权重相关的偏差

(392,) (392,) (392,) (392,):这四个参数分别是 gamma、beta、mean 和 std。大小为 392 的 dev 值,每个都与隐藏层 1 的 Batch Normalization 相关。

(392, 196):隐藏层 2 的权重

(196,): 与隐藏层 2 权重相关的偏差

(196,) (196,) (196,) (196,):这四个参数分别是 gamma、beta、running mean 和 std。 dev 大小为 196,每个都与隐藏层 2 的 Batch Normalization 相关。

(196, 10):输出层的权重

(10,):与输出层权重相关的偏差

所以,在case2中,如果你想获得隐藏层1、隐藏层2和输出层的权重,python代码可以是这样的:

wrights = model_batch.get_weights()      
hidden_layer1_wt = wrights[0].flatten().reshape(-1,1)     
hidden_layer2_wt = wrights[6].flatten().reshape(-1,1)     
output_layer_wt = wrights[12].flatten().reshape(-1,1)

希望这会有所帮助!

Ref: keras-BatchNormalization

【讨论】:

【参考方案2】:

给出的四个值是 gamma、beta、moving_mean 和移动标准差。可以在keras的源码里面查看

【讨论】:

以上是关于你能用 BatchNormalization 解释神经网络中的 Keras get_weights() 函数吗?的主要内容,如果未能解决你的问题,请参考以下文章

你能用python将代码/exe注入到进程中吗?

我无法理解 Iterable 类然后使用 .map 语法。你能用一种简单的语言来表达吗?

你能解释一下STA和MTA吗?

你能解释一下提供的例子中的分类报告(召回率和精度)吗?

你能用链接调用一个servlet吗?

liquibase:你能用 liquibase addColumn 指定“列后”吗?