RotNet: Self-Supervised Learning to Predict Image Rotation Angles

Posted by ZSYL


Paper Overview

RotNet: self-supervised learning by predicting image rotations

This paper was published at ICLR 2018 and has been cited more than 1,100 times. The idea behind it is simple: if someone has no concept of the objects depicted in an image, they cannot recognize the rotation that has been applied to it.

In this post we review unsupervised representation learning by predicting image rotations, from Université Paris-Est. RotNet learns image features by training ConvNets to recognize the 2D rotation that was applied to the input image. With this approach, an unsupervised pre-trained AlexNet model reaches 54.4% mAP, only 2.4 points lower than the supervised AlexNet.

Image Rotation Prediction Framework

Given four possible geometric transformations, namely rotations of 0, 90, 180, and 270 degrees, a ConvNet model F(·) is trained to recognize which of these rotations was applied to the input image.

F^y(X^y) is the probability that the model F(·) assigns to rotation transformation y. Its input is an image to which a rotation has already been applied, and the rotation with the highest predicted probability gives the model's estimate of the image's rotation angle.

To successfully predict an image's rotation, the ConvNet model must learn to localize the salient objects in the image, recognize their orientation and object type, and then relate the object orientation to that of the original image.
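As a concrete illustration of the pretext task (this helper is not from the paper's code), each training image can be expanded into its four rotated copies together with the corresponding rotation labels; the ConvNet is then trained as an ordinary 4-way classifier on these pairs.

import numpy as np

def rotate_four_ways(image):
    # image: a square H x W x C array; label k corresponds to a rotation of k * 90 degrees
    rotated = [np.rot90(image, k) for k in range(4)]
    labels = list(range(4))
    return rotated, labels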

Attention maps generated by trained AlexNet models for (a) recognizing objects (supervised) and (b) recognizing image rotations (self-supervised).

The attention maps above are computed from the activation magnitude of each spatial unit of a convolutional layer; they essentially show where the network focuses most of its attention in order to classify the input image.

As the figure shows, the supervised and self-supervised models appear to attend to roughly the same image regions.
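As a rough illustration (not the authors' exact procedure), such a map could be obtained from a convolutional feature map by aggregating the activation magnitude over channels and normalizing the result:

import numpy as np

def attention_map(feature_map):
    # feature_map: H x W x C activations from some convolutional layer
    attn = np.abs(feature_map).mean(axis=-1)   # activation magnitude per spatial unit
    attn = attn - attn.min()
    return attn / (attn.max() + 1e-8)          # normalize to [0, 1] for visualization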

A Rotate-and-Drag CAPTCHA Solution

Have you ever been stumped by a rotation CAPTCHA? That's right, today's topic is the rotation CAPTCHA.


When automating a login, image CAPTCHAs are a major hurdle.

With RotNet, however, this problem is readily solved.

Two Approaches

Image rotation prediction can be framed in two ways: regression and classification

  • Regression: predict a numeric value in the range 0-360°.
  • Classification: predict one of 360 classes; the class with the highest predicted probability is the model's output.

Define a convolutional neural network, train it on a set of rotated images, and use it to predict the rotation angle of an image.
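A minimal sketch of how a rotated training sample and its angle label could be generated with OpenCV (the helper name here is illustrative, not part of the RotNet repository):

import cv2
import numpy as np

def random_rotation_sample(image):
    # rotate the image around its centre by a random integer angle in [0, 360)
    angle = np.random.randint(0, 360)
    h, w = image.shape[:2]
    matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
    rotated = cv2.warpAffine(image, matrix, (w, h))
    # classification: the label is the angle itself (one of 360 classes);
    # regression: the label would instead be the angle normalized to [0, 1]
    return rotated, angle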

Big Data Application Competition

Big data application competition: computer vision is used in many AI applications, such as autonomous driving, visual navigation, object detection, and object recognition, and image-processing techniques such as random cropping, random rotation, and image blurring often help improve vision models. The importance of image techniques to computer vision is self-evident, so the task for this big data application competition is an image-straightening (rotation correction) challenge.

Convolutional Neural Network

Classification Code

# imports needed for this standalone snippet
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, Flatten, Dense

# input image dimensions (example values; set them to match your data)
img_rows, img_cols, img_channels = 224, 224, 3

# number of convolutional filters to use
nb_filters = 64
# size of pooling area for max pooling
pool_size = (2, 2)
# convolution kernel size
kernel_size = (3, 3)
# number of classes
nb_classes = 360

# model definition
input = Input(shape=(img_rows, img_cols, img_channels))
x = Conv2D(nb_filters, kernel_size, activation='relu')(input)
x = Conv2D(nb_filters, kernel_size, activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(0.25)(x)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.25)(x)
x = Dense(nb_classes, activation='softmax')(x)

model = Model(inputs=input, outputs=x)

model.summary()

Model Compilation

# model compilation
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=[angle_error])
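The angle_error metric is imported from the RotNet utils module and is not shown in this post. A minimal sketch of what such a metric could look like, assuming it reports the mean minimal angular difference (in the range 0-180°) between the true and predicted angle classes:

from keras import backend as K

def angle_difference(x, y):
    # smallest difference between two angles, in the range [0, 180]
    return 180 - K.abs(K.abs(x - y) - 180)

def angle_error(y_true, y_pred):
    # mean angular error between the one-hot ground-truth angles and the predicted classes
    true_angle = K.cast(K.argmax(y_true, axis=-1), K.floatx())
    pred_angle = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    return K.mean(angle_difference(true_angle, pred_angle))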

Training Parameters

# training parameters
batch_size = 128
nb_epoch = 50

Callbacks

# callbacks
checkpointer = ModelCheckpoint(
    filepath=os.path.join(output_folder, model_name + '.hdf5'),
    save_best_only=True
)
early_stopping = EarlyStopping(patience=2)
tensorboard = TensorBoard()

Model Training

# training loop
model.fit_generator(
    RotNetDataGenerator(
        X_train,
        batch_size=batch_size,
        preprocess_func=binarize_images,
        shuffle=True
    ),
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=nb_epoch,
    validation_data=RotNetDataGenerator(
        X_test,
        batch_size=batch_size,
        preprocess_func=binarize_images
    ),
    validation_steps=nb_test_samples // batch_size,
    verbose=1,
    callbacks=[checkpointer, early_stopping, tensorboard]
)

Complete Code

"""
@Author: ZS
@CSDN  : https://zsyll.blog.csdn.net/
@Time  : 2021/11/20 10:48
"""
from __future__ import print_function

import os
import sys

from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from keras.applications.resnet50 import ResNet50
from keras.applications.imagenet_utils import preprocess_input
from keras.models import Model
from keras.layers import Dense, Flatten
from keras.optimizers import SGD

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import angle_error, RotNetDataGenerator
from getImagePath import getPath

data_path = r'./data/image/'
train_filenames, test_filenames = getPath(data_path)

print(len(train_filenames), 'train samples')
print(len(test_filenames), 'test samples')

model_name = 'rotnet_resnet50'

# number of classes
nb_classes = 360
# input image shape
input_shape = (320, 320, 3)

# load the pre-trained base model
base_model = ResNet50(weights='imagenet', include_top=False,
                      input_shape=input_shape)

# add the classification layer
x = base_model.output
x = Flatten()(x)
final_output = Dense(nb_classes, activation='softmax', name='fc360')(x)

# build the new model
model = Model(inputs=base_model.input, outputs=final_output)

model.summary()

# model compilation
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(lr=0.01, momentum=0.9),
              metrics=[angle_error])

# training parameters
batch_size = 64
nb_epoch = 20

output_folder = 'models'
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# callbacks
monitor = 'val_angle_error'
checkpointer = ModelCheckpoint(
    filepath=os.path.join(output_folder, model_name + '.hdf5'),
    monitor=monitor,
    save_best_only=True
)

reduce_lr = ReduceLROnPlateau(monitor=monitor, patience=3)
early_stopping = EarlyStopping(monitor=monitor, patience=5)
tensorboard = TensorBoard()

# train the model
model.fit_generator(
    RotNetDataGenerator(
        train_filenames,
        input_shape=input_shape,
        batch_size=batch_size,
        preprocess_func=preprocess_input,
        crop_center=True,
        crop_largest_rect=True,
        shuffle=True
    ),
    steps_per_epoch=len(train_filenames) // batch_size,
    epochs=nb_epoch,
    validation_data=RotNetDataGenerator(
        test_filenames,
        input_shape=input_shape,
        batch_size=batch_size,
        preprocess_func=preprocess_input,
        crop_center=True,
        crop_largest_rect=True
    ),
    validation_steps=len(test_filenames) // batch_size,
    callbacks=[checkpointer, reduce_lr, early_stopping, tensorboard],
    workers=10
)

Model Inference

# imports: sys is required; import the rest as needed
from __future__ import print_function
import os
import sys
import random
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
import tensorflow.keras as keras



import matplotlib.pyplot as plt
from mykeras.applications.imagenet_utils import preprocess_input
from mykeras.models import load_model
from utils import display_examples, RotNetDataGenerator, angle_error
import warnings
warnings.filterwarnings("ignore")
from tensorflow.keras import layers

# main code; adapt to your needs
class FileSequence(keras.utils.Sequence):
    def __init__(self,filenames,batch_size,filefunc,fileargs=(),labels=None,labelfunc=None,labelargs=(),shuffle=False):
        if labels: assert len(filenames) == len(labels)
        self.filenames  = filenames
        self.batch_size = batch_size
        self.filefunc   = filefunc
        self.fileargs   = fileargs
        self.labels     = labels
        self.labelfunc  = labelfunc
        self.labelargs  = labelargs  
        if shuffle:
            idx_list = list(range(len(self.filenames)))
            random.shuffle(idx_list)
            self.filenames = [self.filenames[idx] for idx in idx_list]
            if self.labels: self.labels = [self.labels[idx] for idx in idx_list]

    def __len__(self):
        return int(np.ceil(len(self.filenames) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_filenames = self.filenames[idx * self.batch_size: (idx+1) * self.batch_size]
        
        files = []
        for filename in batch_filenames:
            # tf.print(filename)
            file = self.filefunc(filename,*self.fileargs)
            files.append(file)
        if self.labels:
            batch_labels = self.labels[idx * self.batch_size: (idx+1) * self.batch_size]
            if self.labelfunc:
                return np.array(files), self.labelfunc(batch_labels,*self.labelargs)
            else:
                return np.array(files), batch_labels
        else:
            return np.array(files)
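FileSequence is a generic keras.utils.Sequence that loads and preprocesses files batch by batch. A minimal usage sketch (the filename list and batch size are assumptions for illustration; filefunc is the preprocessing function defined further below):

# hypothetical usage
seq = FileSequence(
    filenames=test_filenames,   # assumed list of image paths
    batch_size=32,
    filefunc=filefunc,          # loads and resizes a single image (defined below)
    fileargs=("centre",),       # forwarded to filefunc as its mode argument
)
first_batch = seq[0]            # numpy array holding the first batch of images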

def fillWhite(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    assert (h < size) and (w < size)
    fillImg = np.zeros(shape=(size,size,c))
    if mode == "random":
        sh = random.randint(0,size-h)
        sw = random.randint(0,size-w)
        fillImg[sh:sh+h,sw:sw+w,...] = img
    elif mode == "centre" or mode == "center":
        fillImg[(size-h)//2:(size+h)//2,(size-w)//2:(size+w)//2,...] = img
    else:
        fillImg[:h,:w,...] = img
    return fillImg

def cropImg(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    assert (h >= size) and (w >= size)
    if mode == "random":
        sh = random.randint(0,h-size)
        sw = random.randint(0,w-size)
        cropImg = img[sh:sh+size,sw:sw+size,...]
    elif mode == "centre" or mode == "center":
        cropImg = img[(h-size)//2:(h+size)//2,(w-size)//2:(w+size)//2,...]
    else:
        cropImg = img[:size,:size,...]
    return cropImg

def fillCrop(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    assert ((h >= size) and (w < size)) or ((h < size) and (w >= size))
    fillcropImg = np.zeros(shape=(size,size,c))
    if mode == "random":
        if (h >= size) and (w < size):
            sh = random.randint(0,h-size)
            sw = random.randint(0,size-w)
            fillcropImg[:,sw:sw+w,:] = img[sh:sh+size,...]
        else:
            sh = random.randint(0,size-h)
            sw = random.randint(0,w-size)
            fillcropImg[sh:sh+h,...] = img[:,sw:sw+size,:]
    elif mode == "centre" or mode == "center":
        if (h >= size) and (w < size):
            fillcropImg[:,(size-w)//2:(size+w)//2,:] = img[(h-size)//2:(h+size)//2,...]
        else:
            fillcropImg[(size-h)//2:(size+h)//2,...] = img[:,(w-size)//2:(w+size)//2,:]
    else:
        if (h >= size) and (w < size):
            fillcropImg[:,:size,:] = img[:size,...]
        else:
            fillcropImg[:size,...] = img[:,:size,:]
    return fillcropImg

def resizeImg(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    if (h < size) and (w < size): return fillWhite(img,size,mode)
    elif (h >= size) and (w >= size): return cropImg(img,size,mode)
    else: return fillCrop(img,size,mode)
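A quick illustration of how resizeImg dispatches to the helpers above (the shapes are made up for the example): images smaller than the target on both sides are zero-padded, images larger on both sides are cropped, and mixed cases are handled by fillCrop.

small = np.zeros((100, 80, 3))
large = np.zeros((400, 300, 3))
print(resizeImg(small, 256, "centre").shape)   # (256, 256, 3), zero-padded via fillWhite
print(resizeImg(large, 256, "centre").shape)   # (256, 256, 3), centre-cropped via cropImg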

def filefunc(filename,mode):
    tf.print(filename)
    img = cv2.imread(filename)
    if not isinstance(img,np.ndarray):
        tf.print(filename)
    h, w, c = img.shape
    if (h >= 256) or (w >= 256):
        # assumed completion: the original snippet is cut off at this call
        img = resizeImg(img, 256, mode)
    return img
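The post does not show the actual inference call. A rough sketch of how the checkpoint saved during training could be loaded and used to predict an angle (the file paths are assumptions; angle_error has to be passed as a custom object because the model was compiled with it):

# hypothetical inference code
model = load_model('models/rotnet_resnet50.hdf5',
                   custom_objects={'angle_error': angle_error})

img = cv2.imread('test.jpg')                              # assumed test image
h, w = model.input_shape[1:3]
img = cv2.resize(img, (w, h))                             # match the model's expected input size
x = preprocess_input(np.expand_dims(img.astype(np.float32), axis=0))
angle = int(np.argmax(model.predict(x), axis=1)[0])       # predicted rotation angle in degrees
print('predicted rotation: %d degrees' % angle)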
