如何使用切片在图像之间移动特征图?

Posted

技术标签:

【中文标题】如何使用切片在图像之间移动特征图?【英文标题】:How to move feature maps across images using slicing? 【发布时间】:2021-08-26 17:46:18 【问题描述】:

我正在尝试实现这个paper的在线算法,它是关于视频分类的。在每次卷积操作之后,这项工作将 1/8 的通道特征图从每个图像移动到下一个图像中。操作的图像已附在此处-

虽然尝试实现相同的功能,但我已成功提取出前 1/8 通道特征图,但我不知道如何将它们添加到后续图像中。我的代码已附在下面 -

import cv2
import gym
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F

N = 1 # Batch Size
T = 5 # Time Steps. This means that there are 5 frames in the video
C = 3 # RGB Channels
H = 144 # Height
W = 144 # Width

foo = torch.randn(N*T, C, H, W)

print("Shape of foo = ", foo.shape)
#torch.Size([5, 3, 144, 144])

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        print("Shape of x = ", x.shape)
        # torch.Size([5, 8, 140, 140])
        shape_extract = x[:, :1,:,:]
        print("Shape of extract = ", shape_extract.shape)
        # torch.Size([5, 1, 140, 140])
        # 1/8 of the channels have been extracted out from above. But how do I transfer these channel features to the next image?

        return x

net = Net()
output = net(foo)

【问题讨论】:

【参考方案1】:

由于您的整个序列都在批次内,您可以使用torch.roll 第一个轴上的元素移动图层。

>>> rolled = x.roll(shifts=1, dims=1)

axis=1上的这一层布局开始:

[x_0, x_1, x_2, x_3, ..., x_7]

到这个:

[x_7, x_0, x_1, x_2, ..., x_6]

然后将第一个元素替换为x_0:

>>> rolled[:, 0] = x[:, 0]

导致这种布局:

[x_0, x_0, x_1, x_2, ..., x_6]

然后你可以输入张量rolled到下一层。


你可以实现一个自定义层来包装这个逻辑:

class ShiftLayer(nn.Module):
    def forward(self, x):
        out = x.roll(1, 1)
        out[:, 0] = x[:, 0]
        return out

然后在你的模型中使用它:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        ...
        self.shift = ShiftLayer()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.shift(x)
        x = F.relu(self.conv2(x))
        return x

【讨论】:

以上是关于如何使用切片在图像之间移动特征图?的主要内容,如果未能解决你的问题,请参考以下文章

[OpenCV实战]19 使用OpenCV实现基于特征的图像对齐

如何从视频中提取某一帧的图像

如何从视频中提取某一帧的图像

如何在 Javascript/jQuery 中动态切片数组?

如何通过将地图切片图像保存到 sqlite 数据库来使用 osmdroid 实现离线地图?

基于深度学习的图异常检测如何改进