让 RESNet18 与 float32 数据一起工作 [重复]
Posted
技术标签:
【中文标题】让 RESNet18 与 float32 数据一起工作 [重复]【英文标题】:Getting RESNet18 to work with float32 data [duplicate] 【发布时间】:2021-11-17 06:25:17 【问题描述】:我有 float32 数据,我正在尝试让 RESNet18 使用。我在 torchvision 中使用 RESNet 模型(并使用 pytorch 闪电)并将其修改为使用一层(灰度)数据,如下所示:
class ResNetMSTAR(pl.LightningModule):
def __init__(self):
super().__init__()
# define model and loss
self.model = resnet18(num_classes=3)
self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
self.loss = nn.CrossEntropyLoss()
@auto_move_data # this decorator automatically handles moving your tensors to GPU if required
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_no):
# implement single training step
x, y = batch
logits = self(x)
loss = self.loss(logits, y)
return loss
def configure_optimizers(self):
# choose your optimizer
return torch.optim.RMSprop(self.parameters(), lr=0.005)
当我尝试运行此模型时,我收到以下错误:
File "/usr/local/lib64/python3.6/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward
我有什么不同的方法可以防止这个错误发生吗?
【问题讨论】:
【参考方案1】:问题是 y
你喂你的交叉熵损失,不是一个 LongTensor,而是一个 FloatTensor。 CrossEntropy 期望为目标提供一个 LongTensor,并引发错误。
这是一个丑陋的修复:
x, y = batch
y = y.long()
但我建议您转到定义数据集的位置,并确保生成长目标,这样如果您更改训练循环的工作方式,您将不会重现此错误。
【讨论】:
以上是关于让 RESNet18 与 float32 数据一起工作 [重复]的主要内容,如果未能解决你的问题,请参考以下文章
知识蒸馏IRG算法实战:使用ResNet50蒸馏ResNet18
知识蒸馏NST算法实战:使用CoatNet蒸馏ResNet18