pytorch-lightning 中的正态分布采样

Posted

技术标签:

【中文标题】pytorch-lightning 中的正态分布采样【英文标题】:Normal distribution sampling in pytorch-lightning 【发布时间】:2020-12-18 23:19:50 【问题描述】:

Pytorch-Lightning 中,您通常不必指定 cuda 或 gpu。但是当我想使用torch.normal 创建一个高斯采样张量时,我得到了

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

那么,我该如何更改 torch.normal 才能使 pytorch-lightning 正常工作?由于我在 cpu 在 gpu 上使用不同机器上的代码

centers = data["centers"] #already on GPU... sometimes...

lights = torch.normal(0, 1, size=[100, 3])
lights += centers

【问题讨论】:

只需将它移动到与它将与之交互的其他张量之一相同的设备上。由于您没有发布任何代码,因此很难更具体地说明应该做什么。 【参考方案1】:

推荐的方法是lights = torch.normal(0, 1, size=[100, 3], device=self.device) 如果这是在闪电类里面。 您也可以这样做:lights = torch.normal(0, 1, size=[100, 3]).type_as(tensor),其中tensor 是 cuda 上的某个张量。

【讨论】:

以上是关于pytorch-lightning 中的正态分布采样的主要内容,如果未能解决你的问题,请参考以下文章

如何禁用 PyTorch-Lightning 记录器的日志记录?

Pytorch-Lightning 是不是具有多处理(或 Joblib)模块?

pytorch-lightning入门—— 初了解

PyTorch-lightning 模型在第一个 epoch 后内存不足

使用 pytorch-lightning 进行简单预测的示例

使用 pytorch-lightning 实现 Network in Network CNN 模型