pytorch RuntimeError: 标量类型 Double 的预期对象,但得到标量类型 Float

Posted

技术标签:

【中文标题】pytorch RuntimeError: 标量类型 Double 的预期对象,但得到标量类型 Float【英文标题】:pytorch RuntimeError: Expected object of scalar type Double but got scalar type Float 【发布时间】:2020-05-31 00:06:08 【问题描述】:

我正在尝试为我的神经网络实现一个自定义数据集。但是在运行转发功能时出现此错误。代码如下。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ParamData(Dataset):
    def __init__(self,file_name):
        self.data = torch.Tensor(np.loadtxt(file_name,delimiter = ','))    #first place
    def __len__(self):
        return self.data.size()[0]
    def __getitem__(self,i):
        return self.data[i]

class Net(nn.Module):
    def __init__(self,in_size,out_size,layer_size=200):
        super(Net,self).__init__()
        self.layer = nn.Linear(in_size,layer_size)
        self.out_layer = nn.Linear(layer_size,out_size)

    def forward(self,x):
        x = F.relu(self.layer(x))
        x = self.out_layer(x)
        return x

datafile = 'data1.txt'

net = Net(100,1)
dataset = ParamData(datafile)
n_samples = len(dataset)

#dataset = torch.Tensor(dataset,dtype=torch.double)   #second place
#net.float()                                          #thrid place

net.forward(dataset[0])         #fourth place

在文件data1.txt 中是一个包含特定数字的csv 格式文本文件,每个dataset[i] 是一个大小为100 x 1 的torch.Tensor dtype 对象torch.float64。错误信息如下:

Traceback (most recent call last):
  File "Z:\Wrong.py", line 33, in <module>
    net.forward(dataset[0])
  File "Z:\Wrong.py", line 23, in forward
    x = F.relu(self.layer(x))
  File "E:\Python38\lib\site-packages\torch\nn\modules\module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\Python38\lib\site-packages\torch\nn\modules\linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "E:\Python38\lib\site-packages\torch\nn\functional.py", line 1372, in linear
    output = input.matmul(weight.t())
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat2' in call to _th_mm

看来我应该将dataset 中数字的dtype 更改为torch.double。我试过像

将第一行改为self.data = torch.tensor(np.loadtxt(file_name,delimiter = ','),dtype=torch.double)

将第四位的行改为net.forward(dataset[0].double())

在第二个或第三个位置取消注释两行之一

我认为这些是我从类似问题中看到的解决方案,但它们要么给出新的错误,要么什么都不做。我该怎么办?


更新:所以我把第一个位置改为

self.data = torch.from_numpy(np.loadtxt(file_name,delimiter = ',')).float()

这很奇怪,因为它与错误消息完全相反。这是一个错误吗?我还是想解释一下。

【问题讨论】:

.float() 在将输入转换为火炬张量后也对我有用。 也为我工作过......我因类似原因感到困惑 【参考方案1】:

简而言之:您的数据类型为 double,但您的模型类型为 float,这在 pytorch 中是不允许的,因为只有带有可以将相同的 dtype 输入到模型中。

长期: 此问题与 PyTorch 和 Numpy 的默认 dtype 有关。我会先解释为什么会出现这个错误,然后提出一些解决方案(但我认为一旦你理解了原理,你就不需要我的解决方案了。)

PyTorch 有几个 dtypes https://pytorch.org/docs/stable/tensors.html。其中两个与您遇到的问题密切相关:
    torch.float32(又名torch.floattorch.float64(又名torch.double

知道 PyTorch 张量的默认 dtype 是torch.float32(又名torch.float)很重要。这意味着当你创建一个张量时,它的默认 dtype 是 torch.float32.try: torch.ones(1).dtype 。这将在默认情况下打印torch.float32。而且模型的参数默认也是这个dtype。

在您的情况下,net = Net(100,1) 将创建一个模型,其参数的 dtype 为 torch.float32

那么我们需要谈谈 Numpy:

Numpy ndarray 的默认 dtype 是numpy.float64。这意味着当您创建一个 numpy 数组时,它的默认 dtype 是 numpy.float64.try: np.ones(1).dtype 。这将在默认情况下打印dtype('float64')

在您的情况下,您的数据来自np.loadtxt 加载的本地文件,因此数据首先以dtype('float64')(作为numpy 数组)加载,然后转换为dtype torch.float64(aka torch.double)。当您将 numpy 数组转换为 torch 张量时会发生这种情况:它们将具有相应的 dtype。

我认为现在问题已经很清楚了,您有一个模型,其参数为torch.float32(又名torch.float),但试图在torch.float64(又名torch.double)的数据上运行它。这也是错误信息试图表达的意思:Expected object of scalar type Double but got scalar type Float for argument

解决方案:

    您已经找到了:通过调用tensor.float() 将您的数据转换为torch.float32 你也可以在加载数据时指定dtype:np.loadtxt(file_name,delimiter = ',',dtype="float32")

【讨论】:

【参考方案2】:

现在我对pytorch有了更多的经验,我想我可以解释错误信息了。好像这条线

RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat2' in call to _th_mm

实际上是指调用矩阵乘法时线性层的权重。由于输入是double,而权重是float,所以这条线有意义

output = input.matmul(weight.t())

期望权重为double

【讨论】:

以上是关于pytorch RuntimeError: 标量类型 Double 的预期对象,但得到标量类型 Float的主要内容,如果未能解决你的问题,请参考以下文章

pytorch RuntimeError: 标量类型 Double 的预期对象,但得到标量类型 Float

Pytorch RuntimeError:参数#1 'indices' 的预期张量具有标量类型 Long;但得到了 CUDAType

RuntimeError:预期的标量类型 Double 但发现 Float

RuntimeError: 标量类型 Long 的预期对象,但参数 #2 'mat2' 的标量类型 Float 如何解决?

PyTorch:RuntimeError:输入、输出和索引必须在当前设备上

RuntimeError:预期的标量类型 Long 但发现 Float