有啥方法可以有效地堆叠/集成用于图像分类的预训练模型?

Posted

技术标签:

【中文标题】有啥方法可以有效地堆叠/集成用于图像分类的预训练模型?【英文标题】:Any way to efficiently stack/ensemble pre-trained models for image classification?有什么方法可以有效地堆叠/集成用于图像分类的预训练模型? 【发布时间】:2021-12-24 22:44:26 【问题描述】:

我正在尝试通过获取每个模型的最后一个隐藏层,然后将它们连接在一起,然后将它们插入元学习器模型(例如 XGBoost)来堆叠我拥有的一些预训练模型。

我遇到了一个大问题,即必须多次处理我的数据集的每个图像,因为每个基本模型都需要不同的处理方法。这导致我的模型需要很长时间才能训练并且不可行。有什么办法可以解决这个问题吗?

例如:

model_1, processor_1 = pretrained_model(), pretrained_processor()
model_2, processor_2 = pretrained_model2(), pretrained_processor2()

for img in images:

input_1 = processor_1(img)
input_2 = processor_2(img)

out_1 = model_1(input_1)
out_2 = model_2(input_2)

torch.cat((out1,out2), dim=1) #concatenates hidden representations to feed into another model

【问题讨论】:

【参考方案1】:

如果您想更快地处理图像,这里有一个建议:

注意:我没有对此进行测试

import torch
import torch.nn as nn

# Create a stack nn module
class StackedModel(nn.Module):
  def __init__(self, model1, model2):
    super(StackedModel, self).__init__()

    self.model1 = model1
    self.model2 = model2

  def forward(self, imgs):
    out_1 = model_1(input_1)
    out_2 = model_2(input_2)

    return torch.cat((out1, out2), dim=1)

# Init model
model = StackedModel(model1, model2)

# Try to stack and run in a larger batch assuming u have extra gpu space
stacked_preproc1 = []
stacked_preproc2 = []
max_batch_size = 16
total_output = []

for index, img in enumerate(images):
  input_1 = processor_1(img)
  input_2 = processor_2(img)

  stacked_preproc1.append(input_1)
  stakced_preproc2.appennd(input2)

  if index % max_batch_size == 0:
    stacked_preproc1 = torch.stack(stacked_preproc1)
    stakced_preproc2 = torch.stack(stakced_preproc2)
  else:
    total_output.append(
        model(stacked_preproc1, stacked_preproc2)
    )

    # Reset array
    stacked_preproc1 = []
    stakced_preproc2 = []
      

【讨论】:

以上是关于有啥方法可以有效地堆叠/集成用于图像分类的预训练模型?的主要内容,如果未能解决你的问题,请参考以下文章

加载Pytorch中的预训练模型及部分结构的导入

用于分类的预分类训练 Twitter 评论

有啥方法可以将 PyTorch 中可用的预训练模型下载到特定路径?

使用 BERT 等预训练模型进行文档分类

多标签分类的预训练

如何使用 bagging 集成 SVM 和 CNN 分类器?