pytorch代码中的KL-divergence与公式有啥关系?

Posted

技术标签:

【中文标题】pytorch代码中的KL-divergence与公式有啥关系?【英文标题】:How is KL-divergence in pytorch code related to the formula?pytorch代码中的KL-divergence与公式有什么关系? 【发布时间】:2020-08-19 04:20:42 【问题描述】:

在 VAE 教程中,两个正态分布的 kl-散度定义为:

而在here、here和here等很多代码中,代码实现为:

 KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

它们有什么关系?为什么代码中没有“tr”或“.transpose()”?

【问题讨论】:

这是由 Kingma (here) 在附录 B 中由 Kingma (here) 的原始 VAE 论文制定的。请注意,在第二个版本中还有一个额外的缩放比例,它使用 torch.mean 而不是 torch.sum这不是问题,因为缩放不会改变最佳点(尽管这可能意味着您需要不同的学习率)。 @jodag 非常有帮助,谢谢 @jodag 关于torch.sum和torch.mean,你说“这可能意味着你需要不同的学习率”,但是KL损失并不是唯一的损失项,loss=kl_loss+recon_loss,这是否意味着损失实际上是具有不同权重的加权和? 是的,如果您使用均值而不是总和,则 kl_loss 分量的权重将隐式低于原始公式,这可能会影响损失函数的最佳点,并可能影响最终结果。 【参考方案1】:

您发布的代码中的表达式假定 X 是一个不相关多元高斯随机变量。这在协方差矩阵的行列式中没有交叉项是显而易见的。因此均值向量和协方差矩阵的形式为

使用它,我们可以快速推导出原始表达式组件的以下等效表示

将这些替换回原来的表达式得到

【讨论】:

如果 sigma 和 mu 来自非高斯分布,那么最终表达式是否有效? @muammar 此表达式假定 X 中的条目是按 i.i.d 绘制的。来自高斯分布。如果 X 来自不同的分布,我怀疑该表达式是否有效,因为 KL 散度是分布的函数,而不仅仅是第一和第二时刻。 感谢您的回答@jodag,这让我更清楚了——尤其是KL divergence is a function of the distribution and not just the first and second moments

以上是关于pytorch代码中的KL-divergence与公式有啥关系?的主要内容,如果未能解决你的问题,请参考以下文章

概率分布之间的距离度量以及python实现

说话人识别损失函数的PyTorch实现与代码解读

PyTorch 中的 tensordot 以及 einsum 函数介绍

学习笔记Pytorch十二损失函数与反向传播

pytorch 中的常用矩阵操作

吐血整理:PyTorch项目代码与资源列表 | 资源下载