Pytorch 预期类型 Long 但得到类型 int
Posted
技术标签:
【中文标题】Pytorch 预期类型 Long 但得到类型 int【英文标题】:Pytorch expected type Long but got type int 【发布时间】:2019-10-16 20:58:58 【问题描述】:我恢复了一个错误
Expected object of scalar type Long but got scalar type Int for argument #3 'index'
这是来自这一行。
targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
我不确定该怎么做,因为我尝试使用多个位置将其转换为 long。我试着放一个
.long
最后以及将 dtype 设置为 torch.long 仍然不起作用。
与此非常相似,但他没有做任何事情来得到答案 "Expected Long but got Int" while running PyTorch script
我已经更改了很多代码,这是我的最后一个版本,但现在给了我同样的问题。
def forward(self, inputs, targets):
"""
Args:
inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
targets: ground truth labels with shape (num_classes)
"""
log_probs = self.logsoftmax(inputs)
targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
if self.use_gpu: targets = targets.to(torch.device('cuda'))
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (- targets * log_probs).mean(0).sum()
return loss
【问题讨论】:
您提到的问题大约是 5 个月大。这次他们不太可能只是坐在那里。对问题发表评论,看看他们是否找到了解决方案 python 3 不再支持 Long,所以请尝试使用 int(var_name) 【参考方案1】:您的索引参数(即targets.unsqueeze(1).data.cpu()
)的数据类型必须是torch.int64
。
(错误信息有点混乱:torch.long
不存在。但 PyTorch 内部的“Long”表示 int64)。
【讨论】:
我该怎么做?我知道在java中它像这样 int(x);但我用python尝试了类似的东西,但没有用。有什么建议?我也尝试将 dtype= int64 添加到它,但也没有用。 "我也尝试添加 dtype= int64 ,但也没有用。"你能显示你使用的确切代码吗? targets = torch.int64(torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)) targets = torch.zeros(log_probs.size()).scatter_(1, torch.int64(targets.unsqueeze(1).data.cpu()), 1) targets = torch.int64(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)【参考方案2】:targets = torch.zeros(log_probs.size()).scatter_(1, (targets.unsqueeze(1).data.cpu()).long(), 1)
【讨论】:
你好!感谢分享答案!对于未来,为答案添加解释肯定会有所帮助! :)以上是关于Pytorch 预期类型 Long 但得到类型 int的主要内容,如果未能解决你的问题,请参考以下文章
我更改了标量类型 float 的预期对象,但在 Pytorch 中仍然得到 Long
pytorch RuntimeError: 标量类型 Double 的预期对象,但得到标量类型 Float
RuntimeError: 标量类型 Long 的预期对象,但参数 #2 'mat2' 的标量类型 Float 如何解决?
如何修复pytorch'RuntimeError:类型为torch.cuda.LongTensor但发现类型为torch.LongTensor的预期对象'
RuntimeError:预期的标量类型 Long 但发现 Float
通过 DataLoader (PyTorch) 迭代:RuntimeError: 标量类型 unsigned char 的预期对象但序列元素 9 的标量类型浮点数