使用`amp`进行GPU运算优化的学习笔记
Posted songyuc
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用`amp`进行GPU运算优化的学习笔记相关的知识,希望对你有一定的参考价值。
1 AMP训练:torch.cuda.amp
示例代码:
# amp依赖Tensor core架构,所以model参数必须是cuda tensor类型
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# GradScaler对象用来自动做梯度缩放
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# 在autocast enable 区域运行forward
with autocast():
# model做一个FP16的副本,forward
output = model(input)
loss = loss_fn(output, target)
# 用scaler,scale loss(FP16),backward得到scaled的梯度(FP16)
scaler.scale(loss).backward()
# scaler 更新参数,会先自动unscale梯度
# 如果有nan或inf,自动跳过
scaler.step(optimizer)
# scaler factor更新
scaler.update()
2 使用amp
和GradientAccumulation联合进行优化
scaler = GradScaler()
for epoch in epochs:
for i, (input, target) in enumerate(data):
with autocast():
output = model(input)
loss = loss_fn(output, target)
loss = loss / iters_to_accumulate # 看看这个是否可以省略
# Accumulates scaled gradients.
scaler.scale(loss).backward()
if (i + 1) % iters_to_accumulate == 0:
# may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
3 Troubleshooting
3.1 RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
运行时出现错误提示:
RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are safe to autocast.
由提示信息可知,torch
规定无法在autocast
作用域中使用nn.BCELoss(reduction="none")
,于是,需要在代码中单独声明在计算BCE损失时不使用autocast
,示例代码如下:
with autocast(enabled=False):
bce = self.bce_loss(output_map.float(), target_map.float())
Note:
在autocast(enabled=False)
作用域中引用的tensor需要使用其float()
版本;请参考torch
官方示例amp_force_float32。
以上是关于使用`amp`进行GPU运算优化的学习笔记的主要内容,如果未能解决你的问题,请参考以下文章
CS231n 2017 学习笔记03——损失函数与参数优化 Loss Functions and Optimization
斯坦福CS231n—深度学习与计算机视觉----学习笔记 课时8&&9