如何在 PyTorch 中添加自定义定位损失函数?
Posted
技术标签:
【中文标题】如何在 PyTorch 中添加自定义定位损失函数?【英文标题】:How to add a custom localization loss function in PyTorch? 【发布时间】:2021-10-03 23:00:27 【问题描述】:我有一个 PyTorch 网络,它使用 Wi-Fi RSS 数据预测设备的位置。所以输出层包含两个对应于x和y坐标的神经元。我想使用平均定位误差作为损失函数。
即。损失=均值(sqrt((x_predicted - X_real)^2 +(y_predicted - y_real)^2))
该等式计算预测位置与实际位置之间的误差距离。我怎样才能包含这个而不是 MSE?
【问题讨论】:
【参考方案1】:正如您在tutorial 中看到的,只需实现一个criterion
函数(您可以随意命名)并使用它:
def custom_loss(output, label):
return torch.mean(torch.sqrt(torch.sum((output - label)**2)))
以及代码中的(从链接教程中窃取):
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = custom_loss(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
HTH
【讨论】:
以上是关于如何在 PyTorch 中添加自定义定位损失函数?的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch 中自定义后向函数的损失 - 简单 MSE 示例中的爆炸损失