通俗易懂的Spatial Transformer Networks(STN)

Posted 修炼之路

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了通俗易懂的Spatial Transformer Networks(STN)相关的知识,希望对你有一定的参考价值。

导读

pytorch为了方便实现STN,里面封装了affine_gridgrid_sample两个高级API。对STN不太了解的同学可以参考这篇详细解读Spatial Transformer Networks(STN)

其实STN的作用是想让CNN具备平移、旋转、缩放、剪切不变性,虽然说CNN中的Pooling可以让网络具备一点平移不变性,但这毕竟是隐性的,如果能让网络直接具备这样的能力岂不是更好。

如果对图像处理有了解的同学也许听过仿射变换这个名词,我们只需要通过变换矩阵 θ \\theta θ(由6个参数组成)就能实现上面的这些功能,如果对仿射变换不了解的同学可以参考我的这篇一文搞懂仿射变换

STN也是因为受到这个启发而诞生的,那么我们如何将这种能力嵌入到CNN中呢?这便是STN需要解决的问题

STN简介

上面引用的文章中已经详细介绍了STN网络,我这里总结概括一下

  • Localisation net

Localisation net模块通过CNN提取图像的特征来预测变换矩阵 θ \\theta θ

  • Grid generator

Grid generator模块就是利用Localisation net模块回归出来的 θ \\theta θ参数来对图片中的位置进行变换,输入图片到输出图片之间的变换,需要特别注意的是这里指的是图片像素所对应的位置

例如:如果此时 θ \\theta θ参数功能是实现图片的平移变换(向右平移1,),输入图片上的坐标(1,1),那对应输出图片上的坐标的(2,1),也就是说输入图片上(1,1)对应的像素值等于输出图片上(2,1)对应的像素值。在变换的时候必然会遇到当输入图片的位置变换到输出图片上是如果位置出现小数怎么办?

  • Sampler

Sampler就是用来解决Grid generator模块变换出现小数位置的问题的。针对这种情况,STN采用的是双线性插值(Bilinear Interpolation),下面我们来介绍一下这个算法

上图中 ( x , y ) (x,y) (x,y)是变换后输出图像上的位置,带下标的坐标位置表示的是与 ( x , y ) (x,y) (x,y)在输入图像对应的四个相邻的坐标。上面的坐标满足下面的关系
x 1 − x 0 = 1 y 1 − y 0 = 1 x_1-x_0 = 1\\\\ y1-y_0 = 1 x1x0=1y1y0=1
根据双线性插值的原则距离相邻点近的坐标占的比重越大,所以 ( x , y ) (x,y) (x,y)对应的像素值为,我们用 f ( x , y ) f(x,y) f(x,y)表示点 ( x , y ) (x,y) (x,y)所对应的像素值
f ( x , y ) = ( x 1 − x ) ( y 1 − y ) f ( x 0 , y 0 ) + ( x − x 0 ) ( y 1 − y ) f ( x 1 , y 0 ) = + ( x − x 0 ) ( y − y 0 ) f ( x 1 , y 1 ) + ( x 1 − x ) ( y − y 0 ) f ( x 0 , y 1 ) \\beginaligned f(x,y) &= (x_1-x)(y1-y)f(x_0,y_0)+(x-x_0)(y_1-y)f(x_1,y_0)\\\\ &=+(x-x_0)(y-y_0)f(x_1,y_1)+(x_1-x)(y-y_0)f(x_0,y_1) \\endaligned f(x,y)=(x1x)(y1y)f(x0,y0)+(xx0)(y1y)f(x1,y0)=+(xx0)(yy0)f(x1,y1)+(x1x)(yy0)f(x0,y1)

STN层的实现

  • pytorch的实现

通过pytorchaffine_gridgrid_sample可以很容易实现STN的后两个模块

from torchvision import transforms
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

#读取图片
img = Image.open("img/test.jpg")
#将图片转换为torch tensor
img_tensor = transforms.ToTensor()(img)

#定义平移变换矩阵
#0.1表示将图片向左平移图片宽的百分比
#0.2表示将图片向上平移图片高的百分比
theta = torch.tensor([[1,0,0.1],[0,1,0.2]],
                     dtype=torch.float)
#根据变换矩阵来计算变换后图片的对应位置
grid = F.affine_grid(theta.unsqueeze(0),
               img_tensor.unsqueeze(0).size(),align_corners=True)
#默认使用双向性插值,可以通过mode参数设置
output = F.grid_sample(img_tensor.unsqueeze(0),
			   grid,align_corners=True)

plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.array(img))
plt.title("original image")

plt.subplot(1,2,2)
plt.imshow(output[0].numpy().transpose(1,2,0))
plt.title("stn transform image")

plt.show()

  • numpy的实现

我们通过numpy来实现STN的后两个模块,来帮助大家更好的理解STN

class Grid_sample(object):
    def affine_grid(self,theta,img_size):
        if len(img_size) != 2:
            assert("img_size size must is 2")
        num_batch = np.shape(theta)[0]
        img_w,img_h = img_size
        #将图片位置归一化到(-1,1)
        x = np.linspace(-1.0,1.0,img_w)
        y = np.linspace(-1.0,1.0,img_h)

        #组合x和y获取到图片的位置坐标
        x_t,y_t = np.meshgrid(x,y)
        x_t_flat = np.reshape(x_t,[-1])
        y_t_flat = np.reshape(y_t,[-1])

        #创建一个图片的位置数组
        ones = np.ones_like(x_t_flat)
        sampling_grid = np.stack([x_t_flat,y_t_flat,ones])
        sampling_grid = np.expand_dims(sampling_grid,axis=0)
        sampling_grid = np.tile(sampling_grid,
                                np.stack([num_batch,1,1]))

        #计算变换后的图片位置
        batch_grids = np.matmul(theta,sampling_grid)
        batch_grids = np.reshape(batch_grids,
                                 [num_batch,2,img_h,img_w])

        return batch_grids


    def bilinear_sampler(self,img,batch_grids):
        if (batch_grids.shape) != 4:
            assert("batch_grids shape is must equal 4")
        #获取变换后图片位置的x和y轴的坐标位置
        x = batch_grids[:, 0, :, :]
        y = batch_grids[:, 1, :, :]

        img_w,img_h = img.shape[:2]
        max_x = img_w - 1
        max_y = img_h - 1

        #将变换后的坐标位置固定到(0,w/h-1)
        x = 0.5 * ((x+1.0)*(max_x-1))
        y = 0.5 * ((y+1.0)*(max_y-1))

        #将坐标位置取整,便于从输入图片中获取位置对应的像素值
        x0 = np.floor(x).astype(np.int)
        x1 = x0 + 1
        y0 = np.floor(y).astype(np.int)
        y1 = y0 + 1

        #防止坐标越界
        x0 = np.clip(x0,0,max_x)
        x1 = np.clip(x1,0,max_x)
        y0 = np.clip(y0,0,max_y)
        y1 = np.clip(y1,0,max_y)

        #根据坐标位置,取像素值
        Ia = img[y0,x0,:]
        Ib = img[y1,x0,:]
        Ic = img[y0,x1,:]
        Id = img[y1,x1,:]

        wa = np.expand_dims((x1-x)*(y1-y),axis=3)
        wb = np.expand_dims((x1-x)*(y-y0),axis=3)
        wc = np.expand_dims((x-x0)*(y1-y),axis=3)
        wd = np.expand_dims((x-x0)*(y-y0),axis=3)

        #利用双线性插值计算变换后的像素值
        out = wa*Ia + wb*Ib + wc*Ic + wd*Id

        return out


grid_sampler = Grid_sample()
img = np.array(Image.open("img/test.jpg"))
img_h,img_w = img.shape[:2]
theta = np.array([[[1, 0, 0.1], [0, 1, 0.2]]],dtype=np.float)
theta = np.expand_dims(theta,axis=0)

batch_grids = grid_sampler.affine_grid(theta,(img_w,img_h))
out = grid_sampler.bilinear_sampler(img,batch_grids)

plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(np.array(img))
plt.title("original image")

plt.subplot(1, 2, 2)
plt.imshow(out[0].astype(np.uint8))
plt.title("stn transform image")

plt.show()


下一篇文章我们介绍如何将STN模块插入到CNN中

以上是关于通俗易懂的Spatial Transformer Networks(STN)的主要内容,如果未能解决你的问题,请参考以下文章

论文笔记Spatial Transformer Networks

stn,spatial transformer network总结

Spatial Transformer Networks(STN)-论文笔记

Spatial Transformer Networks(STN)-代码实现

Spatial Transformer Networks(STN)-代码实现

论文笔记-5Spatial Transformer Networks(STN)