Pytorch 预期类型 Long 但得到类型 int

Posted

技术标签:

【中文标题】Pytorch 预期类型 Long 但得到类型 int【英文标题】:Pytorch expected type Long but got type int 【发布时间】:2019-10-16 20:58:58 【问题描述】:

我恢复了一个错误

 Expected object of scalar type Long but got scalar type Int for argument #3 'index'

这是来自这一行。

targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)

我不确定该怎么做,因为我尝试使用多个位置将其转换为 long。我试着放一个

.long

最后以及将 dtype 设置为 torch.long 仍然不起作用。

与此非常相似,但他没有做任何事情来得到答案 "Expected Long but got Int" while running PyTorch script

我已经更改了很多代码,这是我的最后一个版本,但现在给了我同样的问题。

    def forward(self, inputs, targets):
            """
            Args:
                inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
                targets: ground truth labels with shape (num_classes)
            """
            log_probs = self.logsoftmax(inputs)
            targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
            if self.use_gpu: targets = targets.to(torch.device('cuda'))
            targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
            loss = (- targets * log_probs).mean(0).sum()
            return loss

【问题讨论】:

您提到的问题大约是 5 个月大。这次他们不太可能只是坐在那里。对问题发表评论,看看他们是否找到了解决方案 python 3 不再支持 Long,所以请尝试使用 int(var_name) 【参考方案1】:

您的索引参数(即targets.unsqueeze(1).data.cpu())的数据类型必须是torch.int64

(错误信息有点混乱:torch.long 不存在。但 PyTorch 内部的“Long”表示 int64)。

【讨论】:

我该怎么做?我知道在java中它像这样 int(x);但我用python尝试了类似的东西,但没有用。有什么建议?我也尝试将 dtype= int64 添加到它,但也没有用。 "我也尝试添加 dtype= int64 ,但也没有用。"你能显示你使用的确切代码吗? targets = torch.int64(torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)) targets = torch.zeros(log_probs.size()).scatter_(1, torch.int64(targets.unsqueeze(1).data.cpu()), 1) targets = torch.int64(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)【参考方案2】:
targets = torch.zeros(log_probs.size()).scatter_(1, (targets.unsqueeze(1).data.cpu()).long(), 1)

【讨论】:

你好!感谢分享答案!对于未来,为答案添加解释肯定会有所帮助! :)

以上是关于Pytorch 预期类型 Long 但得到类型 int的主要内容,如果未能解决你的问题,请参考以下文章

我更改了标量类型 float 的预期对象,但在 Pytorch 中仍然得到 Long

pytorch RuntimeError: 标量类型 Double 的预期对象,但得到标量类型 Float

RuntimeError: 标量类型 Long 的预期对象,但参数 #2 'mat2' 的标量类型 Float 如何解决?

如何修复pytorch'RuntimeError:类型为torch.cuda.LongTensor但发现类型为torch.LongTensor的预期对象'

RuntimeError:预期的标量类型 Long 但发现 Float

通过 DataLoader (PyTorch) 迭代:RuntimeError: 标量类型 unsigned char 的预期对象但序列元素 9 的标量类型浮点数