修改特征图类型tuple转Tensor

Posted 一颗小树x

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了修改特征图类型tuple转Tensor相关的知识,希望对你有一定的参考价值。

前言

在修改模型结构时,本来想着简单替换主干网络,用轻量级结构的替换原来的复杂模型,但是过程没想象中的顺利;其中比较关键的一点是两个主干网络输出的特征图类型不一致。

问题描述

主干网络A(轻量级),它输出特征图的类型是tuple,输出维度是[1, 3, 640, 640];

主干网络B(复杂的),它输出特征图的类型是torch.Tensor,输出维度也是[1, 3, 640, 640];

但是如果直接把主干网络B替换为主干网络A,后面接着原来的特征提取结构和任务头,会报错的。

tuple 转 torch.Tensor

把主干网络B替换为主干网络A后,加多一步操作,将输出特征图从tuple 转 torch.Tensor即可。

转换的基本思路是:使用 torch.cat( ) 把特征图进行拼接起来,通常是在维度 dim=0 进行拼接的。

A、当特征图的tuple数量为1

import torch

# 假设模型输出的特征图为 feature_map, feature_map 是一个 tuple

# 获取特征图个数
num_maps = len(feature_map)

# 打印原来的特征图信息
print("type feature_raw:", type(outs))
for out in feature_map:
    print(out.size())
print("len feature_raw:", num_maps)

# 按第 0 维度拼接特征图
feature_map = torch.cat([fm for fm in feature_map], dim=0)

# 检查特征图类型
print("type feature_map:", type(feature_map))
# 输出: <class 'torch.Tensor'>

# 检查特征图维度
print("size feature_map:", feature_map.size())

示例输出:

type feature_raw: <class 'tuple'>
torch.Size([8, 32, 640, 640])
len feature_raw: 1


type feature_map: <class 'torch.Tensor'>
feature_map: torch.Size([8, 32, 640, 640])

B、当特征图的tuple数量为多个

如果主干网络输出的特征图类型为tuple,而且它包含多个特征图。我们想把它们变为一个torch.Tensor,可以使用torch.cat函数把它们拼接在一起。 

import torch

# 假设模型输出的特征图为 feature_map, feature_map 是一个 tuple

# 获取特征图个数
num_maps = len(feature_map)

# 打印原来的特征图信息
print("type feature_raw:", type(outs))
for out in feature_map:
    print(out.size())
print("len feature_raw:", num_maps)

# 按第 0 维度拼接特征图
feature_map = torch.cat([fm.unsqueeze(0) for fm in feature_map], dim=0)

# 检查特征图类型
print("type feature_map:", type(feature_map))
# 输出: <class 'torch.Tensor'>

# 检查特征图维度
print("size feature_map:", feature_map.size())

这样就可以将输出的特征图类型由tuple变为torch.Tensor了。拼接时,通过unsqueeze(0)把每个特征图在第0维度上增加一维,这样才能用torch.cat进行拼接。 

示例输出:

type feature_raw: <class 'tuple'>
torch.Size([8, 32, 640, 640])

torch.Size([8, 32, 640, 640])

torch.Size([8, 32, 640, 640])

torch.Size([8, 32, 640, 640])
len feature_raw: 1


type feature_map: <class 'torch.Tensor'>
feature_map: torch.Size([4, 8, 32, 640, 640])

分享完成,欢迎交流~

 

以上是关于修改特征图类型tuple转Tensor的主要内容,如果未能解决你的问题,请参考以下文章

python 中各类型介绍及相互转换 - list, array, tensor, dict, tuple, DataFrame

Java Tuple使用实例(转)

Python中的元组(Tuple)

数据类型 tuple 的常用方法

数据类型 -- Tuple(元组)

Python的数据类型:list和tuple