如何仅训练 RPN 以实现火炬视觉 Faster RCNN 与预训练主干
Posted
技术标签:
【中文标题】如何仅训练 RPN 以实现火炬视觉 Faster RCNN 与预训练主干【英文标题】:How to train only RPN for torch vision Faster RCNN with pretrained backbone 【发布时间】:2021-10-12 02:15:35 【问题描述】:如标题所述,如果我已经预训练了骨干网,并且我只想使用来自 torchvision 的 Faster R-CNN 训练 RPN 而不是分类器。
是否有任何参数可以传递给 create_model 函数,或者我会在我的 train() 函数中停止分类器训练?
我在手机上,所以请原谅我的编辑
这是我的创建模型函数
Create your backbone from timm
backbone = timm.create_model(
“resnet50”,
pretrained=True,
num_classes=0, # this is important to remove fc layers
global_pool="" # this is important to remove fc layers
)
backbone.out_channels = backbone.feature_info[-1][“num_chs”]
anchor_generator = AnchorGenerator(
sizes=((16, 32, 64, 128, 256),), aspect_ratios=((0.25, 0.5, 1.0, 2.0),)
)
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
featmap_names=[“0”], output_size=7, sampling_ratio=2
)
fastercnn_model = FasterRCNN(
backbone=backbone,
num_classes=1000,
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler,
)
【问题讨论】:
【参考方案1】:您可以执行以下操作
# First you can use model.children() method to see the idx of the backbone
for idx, child in enumerate(fastercnn_model.children()):
if idx == 1:
# Now set requires_grad for that idx to False
for param in child.parameters():
param.requires_grad = False
break
# =============== UPDATED ========================
# This will train only the box_predictor not even the RPN. You can try out
# Different strategies and find the best for you.
# setting everything to false
for child in fastercnn_model.children():
for param in child.parameters():
param.requires_grad = False
for idx, child in enumerate(fastercnn_model.children()):
if idx == 3:
for i, param in enumerate(child.parameters()):
if i==1:
param.requires_grad = True
break
【讨论】:
感谢您的回复,我是否也需要将主干设置为评估模式? @AnhMinhTran No. 完成此操作后,您所做的一切将与您的整个模型相对应。 听从您的建议后,我尝试运行 Faster RCNN,在度量日志期间,分类器损失仍在更新,这是否表明模型正在尝试训练主干?请在下面的链接中找到图片imgur.com/a/gSIpXfQ 我们确实在训练模型,除了主干,其余的都会更新。 我已经用更多解释更新了我的答案。以上是关于如何仅训练 RPN 以实现火炬视觉 Faster RCNN 与预训练主干的主要内容,如果未能解决你的问题,请参考以下文章
Faster RCNN超详细入门 02网络细节与训练方法 (anchors,RPN,bbox,bounding box,Region proposal layer……)