PyTorch 半精度训练踩坑

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch 半精度训练踩坑相关的知识,希望对你有一定的参考价值。

参考技术A

因为显卡显存不够,所以了解了一些PyTorch节省显存的方法:
拿什么拯救我的 4G 显卡 - OpenMMLab的文章 - 知乎
https://zhuanlan.zhihu.com/p/430123077

其中有一种方法叫做半精度训练:
PyTorch的自动混合精度(AMP) - Gemfield的文章 - 知乎
https://zhuanlan.zhihu.com/p/165152789

什么是半精度训练呢?
PyTorch中默认创建的tensor都是FloatTensor类型。而在PyTorch中,一共有10种类型的tensor:

所谓半精度训练,就是用torch.HalfTensor进行训练,以FP16的方式存储数据(本来是FP32),从而节省显存。

使用半精度训练的方式也很简单:
参考文章 pytorch 使用amp.autocast半精度加速训练

即使用 autocast + GradScaler

训练时loss出现了nan。
一步步向上追溯,发现T5.encoder就已经输出了nan。
Google了一下,发现并不是我一个人遇到这个问题:
https://github.com/huggingface/transformers/issues/4287

1. 取消半精度训练
简单直接,我就用了这种方式
2. 其他思路

参考: 解决pytorch半精度amp训练nan问题

以上是关于PyTorch 半精度训练踩坑的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch AMP——自动混合精度训练

使用 cpu 与 gpu 进行训练的 pytorch 模型精度之间的巨大差异

pytorch量化感知训练(QAT)示例---ResNet

Pytorch自动混合精度(AMP)的使用总结

他山之石在C++平台上部署PyTorch模型流程+踩坑实录

Pytorch自动混合精度(AMP)介绍与使用 - autocast和Gradscaler