Resnet 迁移学习记录

Posted TOPthemaster

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Resnet 迁移学习记录相关的知识,希望对你有一定的参考价值。

在实际应用中,cnn网络的训练是很繁琐且浪费时间的,这时候我们一般会去选择加载网上已经训练得很完善的网络作为自己的cnn网络层,下面例子为使用Resnet预训练模型来做自己的图片分类:

# 网络定义 
class Resnet(nn.Module):

    def __init__(self):
        super(Resnet, self).__init__()
        pretrained_net = torchvision.models.resnet18(pretrained=True)
        model =nn.Sequential(*list(pretrained_net.children())[:-1])
        self.model = model
        self.Linear = nn. Linear(in_features=512, out_features=10, bias=True)
    def forward(self, x):
        x=self.model(x)
        # 这里有个bug,在下载的预训练网络最后一层中,只显示了线性层,但是如果你直接添加一个线性层,会报错,原因为维度的不一致,需要view到适配维度。
        x = x.view(-1, 512)
        x=self.Linear(x)
        return x
X = torch.rand(size=(1, 3, 224, 224))
model=Resnet()
print(model)
print(model(X).shape)

然后进行训练,对比于之前的自己构建的网络重新训练来看,会发现收敛特别快,且很容易得到自己想要的ACC。

[1,   500] loss: 1.301
train_correct=
0.582
train time: 26.080318927764893
Accuracy on test set: 74.03  %
[1,  1000] loss: 0.772
train_correct=
0.7533333333333333
train time: 93.63580584526062
Accuracy on test set: 75.72  %
[1,  1500] loss: 0.697
train_correct=
0.7773333333333333
train time: 159.9850172996521
Accuracy on test set: 78.19  %

这里只跑了1500x3张图,连1/10个epoch都没跑到,但效果已经很强大了,且收敛得非常快速。

以上是关于Resnet 迁移学习记录的主要内容,如果未能解决你的问题,请参考以下文章

Resnet 迁移学习记录

Resnet 迁移学习记录

ResNet 基于迁移学习对CIFAR10 数据集的分类

ResNet18迁移学习CIFAR10分类任务(附python代码)

pytorch--resnet 精准迁移学习 花朵识别

pytorch--resnet 精准迁移学习 花朵识别