PyTorch 半精度训练踩坑
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch 半精度训练踩坑相关的知识,希望对你有一定的参考价值。
参考技术A 因为显卡显存不够,所以了解了一些PyTorch节省显存的方法:
拿什么拯救我的 4G 显卡 - OpenMMLab的文章 - 知乎
https://zhuanlan.zhihu.com/p/430123077
其中有一种方法叫做半精度训练:
PyTorch的自动混合精度(AMP) - Gemfield的文章 - 知乎
https://zhuanlan.zhihu.com/p/165152789
什么是半精度训练呢?
PyTorch中默认创建的tensor都是FloatTensor类型。而在PyTorch中,一共有10种类型的tensor:
所谓半精度训练,就是用torch.HalfTensor进行训练,以FP16的方式存储数据(本来是FP32),从而节省显存。
使用半精度训练的方式也很简单:
参考文章 pytorch 使用amp.autocast半精度加速训练
即使用 autocast + GradScaler
训练时loss出现了nan。
一步步向上追溯,发现T5.encoder就已经输出了nan。
Google了一下,发现并不是我一个人遇到这个问题:
https://github.com/huggingface/transformers/issues/4287
1. 取消半精度训练
简单直接,我就用了这种方式
2. 其他思路
参考: 解决pytorch半精度amp训练nan问题
以上是关于PyTorch 半精度训练踩坑的主要内容,如果未能解决你的问题,请参考以下文章