检测 NaN 的 Pytorch 操作

Posted

技术标签:

【中文标题】检测 NaN 的 Pytorch 操作【英文标题】:Pytorch Operation to detect NaNs 【发布时间】:2018-06-17 21:18:15 【问题描述】:

是否有 Pytorch 内部程序来检测张量中的NaNs? Tensorflow 有 tf.is_nantf.check_numerics 操作...... Pytorch 有类似的东西吗?我在文档中找不到类似的东西......

我正在专门寻找 Pytorch 内部例程,因为我希望这在 GPU 和 CPU 上都发生。这不包括基于 numpy 的解决方案(如 np.isnan(sometensor.numpy()).any())...

【问题讨论】:

这可能会有所帮助:x.isnan().any() 【参考方案1】:

你总是可以利用nan != nan:

>>> x = torch.tensor([1, 2, np.nan])
tensor([  1.,   2., nan.])
>>> x != x
tensor([ 0,  0,  1], dtype=torch.uint8)

在 pytorch 0.4 中还有 torch.isnan:

>>> torch.isnan(x)
tensor([ 0,  0,  1], dtype=torch.uint8)

【讨论】:

我可以确认它也可以在 GPU 上运行。 .any() 然后将其简化为 Python 布尔值。谢谢:-) 哇!我不知道nan != nan。谢谢!【参考方案2】:

从 PyTorch 0.4.1 开始,有 detect_anomaly 上下文管理器,它会在反向传播的所有步骤之间自动插入等同于 assert not torch.isnan(grad).any() 的断言。在向后传递过程中出现问题时,它非常有用。

【讨论】:

【参考方案3】:

正如@cleros 在对@nemo 答案的评论中所建议的那样,您可以使用any() 运算符将其作为布尔值:

torch.isnan(your_tensor).any()

【讨论】:

【参考方案4】:

如果任何值为 nan,则为真:

torch.any(tensor.isnan())

如果都是 nan,则为真:

torch.all(tensor.isnan())

【讨论】:

【参考方案5】:

如果你想直接在张量上调用它:

import torch

x = torch.randn(5, 4)
print(x.isnan().any())

出来:

import torch
x = torch.randn(5, 4)
print(x.isnan().any())
tensor(False)

【讨论】:

以上是关于检测 NaN 的 Pytorch 操作的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch学习----01

pytorch笔记: 处理inf和nan数值

为什么他们要来集智AI学园学习 PyTorch? | 早鸟福利最后一天

pytorch:“不支持多目标”错误消息

PyTorch学习笔记:简介与基础知识

PyTorch——Python 优先的深度学习框架|软件推荐