PyTorch:_thnn_nll_loss_forward 未针对类型 torch.LongTensor 实现
Posted
技术标签:
【中文标题】PyTorch:_thnn_nll_loss_forward 未针对类型 torch.LongTensor 实现【英文标题】:PyTorch: _thnn_nll_loss_forward is not implemented for type torch.LongTensor 【发布时间】:2019-09-18 17:11:25 【问题描述】:当我尝试使用 PyTorch 创建模型时,当我尝试实现损失函数 nll_loss
时,它会抛出以下错误
RuntimeError: _thnn_nll_loss_forward is not implemented for type torch.LongTensor
我创建的拟合函数是:
for epoch in tqdm_notebook(range(1, epochs+1)):
for batch_idx, (data, targets) in enumerate(train_loader):
optimizer.zero_grad()
net.float()
output = net(data)
output_x = output.argmax(dim=2) #to convert (64,50,43) -> (64, 50)
loss = F.nll_loss(output_x, targets)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('Train epochs: [/ (:.0f%)]\tLoss: :.6f'.format(
epoch, batch_idx*len(data), len(ds.data),
100.*batch_idx / len(ds), loss.item()
))
输出和目标的形状是 (64, 50) 并且两者的 dtype 都是 torch.int64
。
【问题讨论】:
【参考方案1】:查看description 的F.nll_loss
。它期望获得的输入不是预测的argmax
(类型torch.long
),而是完整的64x50x43 预测向量(类型torch.float
)。请注意,您提供给F.nll_loss
的预测确实比您提供的地面实况目标具有额外的维度。
在您的情况下,只需删除 argmax:
loss = F.nll_loss(output, targets)
【讨论】:
我已经尝试删除 argmax,然后我收到此错误RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'
。
@thanatoz 似乎您将net
转换为float
,但由于某种原因,您的output
类型为torch.double
。
你有什么建议。我应该如何解决这个问题?我发现nn.GRU(input_size=100, hidden_size=50, dropout=0.5, bidirectional=True, num_layers=2, batch_first=True)
对此负责。
@thanatoz 确保您的data
和隐藏状态都是torch.float
类型。请注意,x.to(torch.float)
或 x.float()
不是就地操作,您需要 x = x.to(torch.float)
才能将 x
设为浮点数。【参考方案2】:
看起来您正在处理具有43
类的分类任务,使用的批量大小为64
,“序列长度”为50
。
如果是这样,我相信您对使用argmax()
或F.log_softmax
有点困惑。正如 Shai 给出的参考,鉴于 output
是 logit 值,您可以使用:
output_x = F.log_softmax(output, dim=2)
loss = F.nll_loss(output_x, targets)
这是使用nll_loss
的正确方法,或者如果你不想做log_softmax
你自己,你可以改用nn.CrossEntropyLoss
。
【讨论】:
是的,你没看错。我确实尝试了您的建议,但遇到了以下错误RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'
。我做错了什么?
好的,谢谢大卫。以上是关于PyTorch:_thnn_nll_loss_forward 未针对类型 torch.LongTensor 实现的主要内容,如果未能解决你的问题,请参考以下文章
pytorch PyTorch 1.1.0 源码解析--运行机制
pytorch土堆pytorch教程学习torchvision 中的数据集的使用