让 RESNet18 与 float32 数据一起工作 [重复]

Posted

技术标签:

【中文标题】让 RESNet18 与 float32 数据一起工作 [重复]【英文标题】:Getting RESNet18 to work with float32 data [duplicate] 【发布时间】:2021-11-17 06:25:17 【问题描述】:

我有 float32 数据,我正在尝试让 RESNet18 使用。我在 torchvision 中使用 RESNet 模型(并使用 pytorch 闪电)并将其修改为使用一层(灰度)数据,如下所示:

class ResNetMSTAR(pl.LightningModule):
def __init__(self):
  super().__init__()
  # define model and loss
  self.model = resnet18(num_classes=3)
  self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  self.loss = nn.CrossEntropyLoss()

@auto_move_data # this decorator automatically handles moving your tensors to GPU if required
def forward(self, x):
return self.model(x)

def training_step(self, batch, batch_no):
  # implement single training step
  x, y = batch
  logits = self(x)
  loss = self.loss(logits, y)
  return loss

def configure_optimizers(self):
  # choose your optimizer
  return torch.optim.RMSprop(self.parameters(), lr=0.005)

当我尝试运行此模型时,我收到以下错误:

File "/usr/local/lib64/python3.6/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward

我有什么不同的方法可以防止这个错误发生吗?

【问题讨论】:

【参考方案1】:

问题是 y 你喂你的交叉熵损失,不是一个 LongTensor,而是一个 FloatTensor。 CrossEntropy 期望为目标提供一个 LongTensor,并引发错误。

这是一个丑陋的修复:

x, y = batch
y = y.long()

但我建议您转到定义数据集的位置,并确保生成长目标,这样如果您更改训练循环的工作方式,您将不会重现此错误。

【讨论】:

以上是关于让 RESNet18 与 float32 数据一起工作 [重复]的主要内容,如果未能解决你的问题,请参考以下文章

知识蒸馏IRG算法实战:使用ResNet50蒸馏ResNet18

知识蒸馏NST算法实战:使用CoatNet蒸馏ResNet18

知识蒸馏NST算法实战:使用CoatNet蒸馏ResNet18

让 qmake 与 32 位和 64 位并排安装一起工作

resnet18和resnet101的区别

ResNet18迁移学习CIFAR10分类任务(附python代码)