如何修复pytorch'RuntimeError:类型为torch.cuda.LongTensor但发现类型为torch.LongTensor的预期对象'

Posted

技术标签:

【中文标题】如何修复pytorch\'RuntimeError:类型为torch.cuda.LongTensor但发现类型为torch.LongTensor的预期对象\'【英文标题】:How to fix pytorch 'RuntimeError: Expected object of type torch.cuda.LongTensor but found type torch.LongTensor'如何修复pytorch'RuntimeError:类型为torch.cuda.LongTensor但发现类型为torch.LongTensor的预期对象' 【发布时间】:2020-02-29 14:05:32 【问题描述】:

我正在尝试使用 FloydHub 的 GPU 运行 this 代码。 当我在 train_model 文件夹下运行 train.py 脚本时,我得到了提到的 RuntimeError。

这是完整的回溯:

回溯(最近一次通话最后一次): 文件“./train_model/train.py”,第 79 行,在 答案 = 模型(批次)调用中的文件“/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py”,第 477 行 结果 = self.forward(*input, **kwargs) 文件“/floyd/home/train_model/model.py”,第 29 行,向前 vecs = self.embed(batch.text)调用中的文件“/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py”,第 477 行 结果 = self.forward(*input, **kwargs) 文件“/usr/local/lib/python3.6/site-packages/torch/nn/modules/sparse.py”,第 110 行,向前 self.norm_type, self.scale_grad_by_freq, self.sparse) 文件“/usr/local/lib/python3.6/site-packages/torch/nn/functional.py”,第 1110 行,嵌入 返回 torch.embedding(权重、输入、padding_idx、scale_grad_by_freq、稀疏) RuntimeError:预期为 torch.cuda.LongTensor 类型的对象,但发现参数 #3 'index' 的类型为 torch.LongTensor

我了解部分代码正在使用 GPU,而其他部分未使用但不知道如何识别这些代码并让所有代码在 GPU 上运行。

请帮忙!

【问题讨论】:

【参考方案1】:

在调用 forward 函数时出现错误。正如错误所说,转发函数“类型为 torch.cuda.LongTensor 的预期对象”,我认为您的输入 batch 仍在 CPU 中,需要转移到 cuda 设备。

我觉得你已经知道怎么做,但如果你不知道,请阅读https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#training-on-gpu。

希望对您有所帮助,如果您需要更多帮助,请告诉我。 :)

【讨论】:

我们如何将batch 发送到 GPU?我试过batch = batch.cuda()batch.to(device)。但它们不起作用,因为没有这样的属性 不应该这样吗?在调用forward函数之前,能不能先声明print(type(batch)),让我知道它是张量还是numpy数组还是别的什么? 这是type(batch)<class 'torchtext.data.batch.Batch'>的内容 forward函数的输入需要Pytorch张量,而batch是torchtext.data.batch.Batch类的对象。您的输入数据必须是对象实例batch 的属性之一。通常,torchtext,包含输入的属性是名称src,在您的情况下应该以batch.src 访问。告诉我。 这是torchtext.data.batch 的属性列表:torchtext.readthedocs.io/en/latest/… 不确定其中哪一个是张量【参考方案2】:

我通过在 Tensorflow 2 中运行我的代码解决了这个错误,所以它可能是一个 TF 版本问题

【讨论】:

您的答案可以通过额外的支持信息得到改进。请edit 添加更多详细信息,例如引用或文档,以便其他人可以确认您的答案是正确的。你可以找到更多关于如何写好答案的信息in the help center。

以上是关于如何修复pytorch'RuntimeError:类型为torch.cuda.LongTensor但发现类型为torch.LongTensor的预期对象'的主要内容,如果未能解决你的问题,请参考以下文章

如何修复drv?

如何修复漏洞

如何修复WMI

PHP网站漏洞怎么修复 如何修补网站程序代码漏洞

如何修复这些漏洞? (npm audit fix 无法修复这些漏洞)

如何修复AppScan漏洞