针对新数据点更新预训练的深度学习模型
Posted
技术标签:
【中文标题】针对新数据点更新预训练的深度学习模型【英文标题】:Updating pre-trained Deep Learning model with respect to new data points 【发布时间】:2019-05-06 13:51:54 【问题描述】:以 ImageNet 上的图像分类为例,如何使用新数据点更新预训练模型。 我已经加载了预训练模型。我有一个新数据点,它与之前训练模型的原始数据的分布完全不同。所以,我想在新数据点的帮助下更新/微调模型。如何去做?谁能帮我做这件事?我正在使用 pytorch 0.4.0 进行实现,在 GPU Tesla K40C 上运行。
【问题讨论】:
【参考方案1】:如果您不想更改分类器的输出(即类的数量),那么您可以简单地继续使用新的示例图像训练模型,假设它们被重新塑造成与预训练模型相同的形状接受。
另一方面,如果您想更改预训练模型中的类数,则可以用新的完全连接层替换最后一个全连接层,并仅在新样本上训练该特定层。以下是来自PyTorch's autograd mechanics notes 的此案例的示例代码:
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)
# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
【讨论】:
以上是关于针对新数据点更新预训练的深度学习模型的主要内容,如果未能解决你的问题,请参考以下文章
预训练模型代码深度剖析之开宗明义:新学常见误区和正确的学习姿势
预训练模型代码深度剖析之开宗明义:新学常见误区和正确的学习姿势
预训练模型代码深度剖析之开宗明义:新学常见误区和正确的学习姿势
深度学习100例 | 第33天:迁移学习-实战案例教程(必须掌握的一个点)