检测 NaN 的 Pytorch 操作
Posted
技术标签:
【中文标题】检测 NaN 的 Pytorch 操作【英文标题】:Pytorch Operation to detect NaNs 【发布时间】:2018-06-17 21:18:15 【问题描述】:
是否有 Pytorch 内部程序来检测张量中的NaN
s? Tensorflow 有 tf.is_nan
和 tf.check_numerics
操作...... Pytorch 有类似的东西吗?我在文档中找不到类似的东西......
我正在专门寻找 Pytorch 内部例程,因为我希望这在 GPU 和 CPU 上都发生。这不包括基于 numpy 的解决方案(如 np.isnan(sometensor.numpy()).any()
)...
【问题讨论】:
这可能会有所帮助:x.isnan().any()
【参考方案1】:
你总是可以利用nan != nan
:
>>> x = torch.tensor([1, 2, np.nan])
tensor([ 1., 2., nan.])
>>> x != x
tensor([ 0, 0, 1], dtype=torch.uint8)
在 pytorch 0.4 中还有 torch.isnan
:
>>> torch.isnan(x)
tensor([ 0, 0, 1], dtype=torch.uint8)
【讨论】:
我可以确认它也可以在 GPU 上运行。.any()
然后将其简化为 Python 布尔值。谢谢:-)
哇!我不知道nan != nan
。谢谢!【参考方案2】:
从 PyTorch 0.4.1 开始,有 detect_anomaly
上下文管理器,它会在反向传播的所有步骤之间自动插入等同于 assert not torch.isnan(grad).any()
的断言。在向后传递过程中出现问题时,它非常有用。
【讨论】:
【参考方案3】:正如@cleros 在对@nemo 答案的评论中所建议的那样,您可以使用any()
运算符将其作为布尔值:
torch.isnan(your_tensor).any()
【讨论】:
【参考方案4】:如果任何值为 nan,则为真:
torch.any(tensor.isnan())
如果都是 nan,则为真:
torch.all(tensor.isnan())
【讨论】:
【参考方案5】:如果你想直接在张量上调用它:
import torch
x = torch.randn(5, 4)
print(x.isnan().any())
出来:
import torch
x = torch.randn(5, 4)
print(x.isnan().any())
tensor(False)
【讨论】:
以上是关于检测 NaN 的 Pytorch 操作的主要内容,如果未能解决你的问题,请参考以下文章