How to move feature maps across images using slicing?
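Posted: 2021-08-26 17:46:18

Question:

I am trying to implement the online algorithm from this paper, which is about video classification. After each convolution, the method shifts 1/8 of the channel feature maps of each image into the next image. A diagram of the operation was attached to the original question.

While trying to implement this, I managed to extract the first 1/8 of the channel feature maps, but I don't know how to add them to the following images. My code is below: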
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N = 1    # Batch size
T = 5    # Time steps: there are 5 frames in the video
C = 3    # RGB channels
H = 144  # Height
W = 144  # Width

foo = torch.randn(N * T, C, H, W)
print("Shape of foo = ", foo.shape)
# torch.Size([5, 3, 144, 144])


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 5)
        self.pool = nn.MaxPool2d(2, 2)
        # in_channels must match the 8 output channels of conv1
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        print("Shape of x = ", x.shape)
        # torch.Size([5, 8, 140, 140])
        shape_extract = x[:, :1, :, :]
        print("Shape of extract = ", shape_extract.shape)
        # torch.Size([5, 1, 140, 140])
        # 1/8 of the channels have been extracted above.
        # But how do I transfer these channel features to the next image?
        return x


net = Net()
output = net(foo)
```
Answer 1:

Since your whole sequence is inside the batch, you can use `torch.roll` to shift the layers along axis 1:

```python
>>> rolled = x.roll(shifts=1, dims=1)
```
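This changes the layout on `axis=1` from:

```
[x_0, x_1, x_2, x_3, ..., x_7]
```

to:

```
[x_7, x_0, x_1, x_2, ..., x_6]
```

Then replace the first element with `x_0`:

```python
>>> rolled[:, 0] = x[:, 0]
```

which results in this layout:

```
[x_0, x_0, x_1, x_2, ..., x_6]
```

You can then feed the tensor `rolled` to the next layer.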
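To see the layout change concretely, here is a minimal check on a toy tensor. The shape `(1, 8, 1, 1)` is an assumption chosen only so that channel `i` holds the value `i` and the result is easy to read:

```python
import torch

# Toy feature map: batch of 1, 8 channels, 1x1 spatial; channel i holds value i.
x = torch.arange(8.0).view(1, 8, 1, 1)

# Shift every channel one slot to the right along axis 1 (wrapping around).
rolled = x.roll(shifts=1, dims=1)
print(rolled.flatten().tolist())  # [7.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Overwrite the wrapped-around slot with the original first channel.
rolled[:, 0] = x[:, 0]
print(rolled.flatten().tolist())  # [0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

`torch.roll` returns a new tensor, so the assignment modifies `rolled` and leaves `x` unchanged.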
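You can implement a custom layer that wraps this logic:

```python
import torch.nn as nn


class ShiftLayer(nn.Module):
    def forward(self, x):
        out = x.roll(1, 1)   # roll by one position along axis 1
        out[:, 0] = x[:, 0]  # restore the original first slice
        return out
```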
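If you prefer to avoid the in-place assignment, the same layout `[x_0, x_0, x_1, ..., x_6]` can be built by concatenating the first slice with all slices except the last. The class name `ConcatShiftLayer` is hypothetical; this is a sketch that should give the same result as `ShiftLayer`:

```python
import torch
import torch.nn as nn


class ShiftLayer(nn.Module):
    # The answer's version: roll along axis 1, then restore slice 0.
    def forward(self, x):
        out = x.roll(1, 1)
        out[:, 0] = x[:, 0]
        return out


class ConcatShiftLayer(nn.Module):
    # Hypothetical out-of-place variant: [x_0] followed by [x_0, ..., x_{C-2}].
    def forward(self, x):
        return torch.cat([x[:, :1], x[:, :-1]], dim=1)


x = torch.randn(5, 8, 140, 140)
a = ShiftLayer()(x)
b = ConcatShiftLayer()(x)
print(torch.equal(a, b))  # True
```

Both layers return a new tensor with the same shape as the input, so either one can serve as the `shift` layer in the model.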
Then use it in your model:

```python
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # conv1, conv2, etc. as defined in the question
        self.shift = ShiftLayer()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.shift(x)
        x = F.relu(self.conv2(x))
        return x
```
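The layer above rolls along axis 1, which for a tensor of shape `(N*T, C, H, W)` is the channel axis within each frame. If the goal is the one described in the question, moving the first 1/8 of the channels from each frame into the next frame, a slicing-based sketch along the time axis might look like the following. It assumes the `T` frames of each video are adjacent in the batch axis (as with `foo` in the question) and that `N*T` is divisible by `T`. The helper name `temporal_shift` is hypothetical, and zero-filling the first frame (which has no predecessor) is an assumption, not something taken from the paper:

```python
import torch


def temporal_shift(x, n_segments, fold_div=8):
    """Hypothetical helper: move the first C // fold_div channels of every
    frame into the next frame of the same video."""
    nt, c, h, w = x.shape
    n = nt // n_segments
    x = x.view(n, n_segments, c, h, w)  # split videos from time steps
    fold = c // fold_div

    out = torch.zeros_like(x)  # assumption: first frame gets zeros in the shifted slice
    out[:, 1:, :fold] = x[:, :-1, :fold]  # frame t receives frame t-1's slice
    out[:, :, fold:] = x[:, :, fold:]     # remaining channels stay in place
    return out.view(nt, c, h, w)


N, T, C, H, W = 1, 5, 8, 140, 140
x = torch.randn(N * T, C, H, W)
y = temporal_shift(x, n_segments=T)

print(y.shape)                             # torch.Size([5, 8, 140, 140])
print(torch.equal(y[1:, :1], x[:-1, :1]))  # True: shifted slice
print(torch.equal(y[:, 1:], x[:, 1:]))     # True: untouched channels
print(y[0, :1].abs().sum().item())         # 0.0: first frame zero-filled
```

In `Net.forward`, such a helper would be called as `x = temporal_shift(x, n_segments=T)` between `conv1` and `conv2`. Reshaping to `(N, T, C, H, W)` first keeps frames from different videos from leaking into each other when `N > 1`.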