了解 pytorch 中正态分布的 log_prob
Posted
技术标签:
【中文标题】了解 pytorch 中正态分布的 log_prob【英文标题】:Understanding log_prob for Normal distribution in pytorch 【发布时间】:2020-07-01 01:11:48 【问题描述】:我目前正在尝试从具有连续动作空间的 openAi 健身房环境中解决 Pendulum-v0。因此,我需要使用正态分布来对我的操作进行采样。我不明白的是使用它时 log_prob 的尺寸:
import torch
from torch.distributions import Normal
means = torch.tensor([[0.0538],
[0.0651]])
stds = torch.tensor([[0.7865],
[0.7792]])
dist = Normal(means, stds)
a = torch.tensor([1.2,3.4])
d = dist.log_prob(a)
print(d.size())
我期待一个大小为 2 的张量(每个操作一个 log_prob),但它输出一个大小为 (2,2) 的张量。
但是,当对离散环境使用分类分布时,log_prob 具有预期的大小:
logits = torch.tensor([[-0.0657, -0.0949],
[-0.0586, -0.1007]])
dist = Categorical(logits = logits)
a = torch.tensor([1, 1])
print(dist.log_prob(a).size())
给我一个尺寸(2)的张量。
为什么正态分布的 log_prob 大小不同?
【问题讨论】:
我建议你提供一个Minimal, Reproducible Example,即一个可以执行的简单程序,以便我们验证你所描述的行为,而不是程序的截图! 我用代码编辑了我的问题 PyTorch 的文档很差,所以我完全理解你。无论如何,这个文档页面https://pytorch.org/docs/stable/distributions.html 说 PyTorch 分发模块遵循与 TensorFlow Probability 相同的设计。如果确实如此,那么您可以尝试查看 TFP 的文档。我目前正在使用 TFP,我可能能够回答这个问题,但稍后。如果您同时没有收到回复,请稍后联系我。 好的,谢谢您的宝贵时间 【参考方案1】:如果查看torch.distributions.Normal的source code,找到log_prob(value)函数的定义,可以看出计算的主要部分是:
return -((value - self.loc) ** 2) / (2 * var) - some other part
其中 value 是一个变量,其中包含您要为其计算对数概率的值(在您的情况下为 a),self.loc 是分布的平均值(在您的情况下,means),var 是方差,即标准差的平方(在您的情况下,stds**2)。可以看出,这确实是normal distribution的概率密度函数的对数,减去上面我没有写的一些常数和标准差的对数。
在第一个示例中,您将 means 和 stds 定义为列向量,而将 values 定义为行向量
means = torch.tensor([[0.0538],
[0.0651]])
stds = torch.tensor([[0.7865],
[0.7792]])
a = torch.tensor([1.2,3.4])
但是从列向量中减去行向量,代码在 Python 中的 value - self.loc 中会给出一个矩阵(试试!),因此您获得的结果是 log_prob 的值对于您定义的两个分布中的每一个以及 a 中的每个变量。
如果您想获得没有交叉项的 log_prob,则一致地定义变量,即,要么
means = torch.tensor([[0.0538],
[0.0651]])
stds = torch.tensor([[0.7865],
[0.7792]])
a = torch.tensor([[1.2],[3.4]])
或
means = torch.tensor([0.0538,
0.0651])
stds = torch.tensor([0.7865,
0.7792])
a = torch.tensor([1.2,3.4])
这就是您在第二个示例中的做法,这就是您获得预期结果的原因。
【讨论】:
以上是关于了解 pytorch 中正态分布的 log_prob的主要内容,如果未能解决你的问题,请参考以下文章
将 Z 分数(Z 值,标准分数)转换为 Python 中正态分布的 p 值