MatchNet论文复现过程记录
Posted __init__:
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了MatchNet论文复现过程记录相关的知识,希望对你有一定的参考价值。
MatchNet论文复现过程记录
原文为《Matchnet: Unifying feature and metric learning for patch-based matching》1:本文复现基于PyTorch深度学习框架,版本(1.7.1+cu110)。
I.Network architecture
根据论文中描述,MatchNet包括:
A. Feature network
该特征提取网络类似AlexNet2,具体结构如下:
其中,PS: patch size for convolution and pooling layers; S: stride. Layer types: C: convolution, MP: max-pooling, FC: fully-connected.
B. Metric network
包括三个全连接层,FC3后接Softmax作为输出。
C. MatchNet in training
基于patch的匹配任务通常假设patch在计算相似度之前,先经过相同的特征编码。因此,论文中采用Two-tower structure with tied parameters结构,即,仅采用一个特征提取网络,在训练过程中,可以理解为同时使用了两个参数共享的特征提取网络去连接度量网络,更新任何一个特征提取网络,将会使得两个网络的参数都发生变化。(这里直接讲比较难理解,具体可以看代码实现。)
具体代码实现如下:
import torch
import torch.nn as nn
class FeatureNet(nn.Module):
"""特征提取网络
"""
def __init__(self):
super(FeatureNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=24, kernel_size=7, padding=3, stride=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
nn.Conv2d(in_channels=24, out_channels=64, kernel_size=5, padding=2, stride=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, padding=1, stride=1),
nn.ReLU(),
nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, padding=1, stride=1),
nn.ReLU(),
nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, padding=1, stride=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
)
def forward(self, x):
return self.features(x)
class MetricNet(nn.Module):
"""度量网络
"""
def __init__(self):
super(MetricNet, self).__init__()
self.features = nn.Sequential(
nn.Linear(in_features=6272, out_features=1024),
nn.ReLU(),
nn.Linear(in_features=1024, out_features=1024),
nn.ReLU(),
nn.Linear(in_features=1024, out_features=2),
# nn.Softmax(dim=1)
''' 这里原本应该接Softmax,但损失函数采用的是交叉熵损失,
而Pytorch中的torch.nn.CrossEntropyLoss()方法包括Softmax,
具体可参考文档https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=nn%20crossentropyloss#torch.nn.CrossEntropyLoss
'''
)
def forward(self, x):
return self.features(x)
class MatchNet(nn.Module):
def __init__(self):
super(MatchNet, self).__init__()
# 只添加一个特征提取网络
self.input_ = FeatureNet()
self.input_.apply(weights_init)
self.matric_network = MetricNet()
self.matric_network.apply(weights_init)
def forward(self, x):
"""x.shape = (2, C, H, W),即两个patch
"""
# 两个patch进入同一个FeatureNet,相当于two-tower sharing same parameters
feature1 = self.input_(x[0]).reshape((x[0].shape[0], -1)) #[256, 3136]
feature2 = self.input_(x[1]).reshape((x[1].shape[0], -1))
features = torch.cat((feature1, feature2), 1) #[256, 6272]
return self.matric_network(features)
def weights_init(m):
'''
自定义权重初始化
'''
if isinstance(m, nn.Conv2d):
nn.init.orthogonal_(m.weight.data, gain=0.6)
try:
nn.init.constant_(m.bias.data, 0.01)
except Exception:
pass
return
参考文献
以上是关于MatchNet论文复现过程记录的主要内容,如果未能解决你的问题,请参考以下文章