Solid Foundations Series: Text Recognition Algorithm RARE (Robust Scene Text Recognition with Automatic Rectification): Core Code

Posted by Liekkas Kono

Introduction

  • RARE is a classic attention-based algorithm for end-to-end text recognition, and it is well worth studying if you want to understand how the attention mechanism is used for text recognition.
  • RARE paper: https://arxiv.org/abs/1603.03915

Basic Principles

  • This post focuses on the attention part; the paper's two main modules, STN (the rectification network) and SRN (the recognition network), are not covered here.
  • As the pipeline below shows, the basic structure is CNN + BiLSTM + decoder. Compared with CRNN, module for module, only the decoding part is different.
[Pipeline] input_image (114x34) → resize → 1x3x32x100 → CNN → 1x480x1x25 → RNN → 1x25x(96x2) → Attention Decode → 1x25x38 (num_classes) → Text
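
To make the shape flow concrete, here is a minimal trace through the three stages (my own sketch, not from the original post; it assumes the RARE class defined in the Core Code section below). Note that with the exact pooling parameters used there, the encoder sequence width actually comes out as 26 rather than the 25 shown in the diagram:

import torch

model = RARE()
x = torch.randn(1, 3, 32, 100)           # resized input image

feat = model.cnn(x)                       # 1 x 480 x 1 x 26
seq = feat.squeeze(2).permute(0, 2, 1)    # 1 x 26 x 480
rnn_out, _ = model.rnn(seq)               # 1 x 26 x 192 (BiLSTM, 2 x 96)
probs = model.attention(rnn_out)          # 1 x 25 x 38 (num_steps x num_classes)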

Core Code

# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class RARE(nn.Module):
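    """RARE backbone: CNN feature extractor + BiLSTM encoder + attention decoder."""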
    def __init__(self) -> None:
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.MaxPool2d((2, 2)),  # 1 x 64 x 16 x 50

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.MaxPool2d((2, 2)),  # 1 x 128 x 8 x 25

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.MaxPool2d((2, 2), (2, 1), (1, 1)),  # 1 x 256 x 5 x 26

            nn.Conv2d(256, 480, 3, 1, 1),
            nn.BatchNorm2d(480),
            nn.ReLU(True),

            nn.MaxPool2d((2, 2), (2, 1), (0, 1)),  # 1 x 480 x 2 x 27

            nn.Conv2d(480, 480, 2, 1, 0),
            nn.BatchNorm2d(480),
            nn.ReLU(True),  # 1 x 480 x 1 x 26
        )

        # BiLSTM over the width axis; batch_first matches the B x W x C layout below
        self.rnn = nn.LSTM(480, 96, bidirectional=True, batch_first=True)

        self.attention = Attention(192, 38, 96)  # 38 classes = 36 characters + sos + eos

    def forward(self, x: torch.Tensor):
        x = self.cnn(x)

        x = x.squeeze(2)
        x = x.permute(0, 2, 1)  # B x width x channels

        x, (h, c) = self.rnn(x)  # 1 x 26 x 192

        x = self.attention(x)  # 1 x 25 x 38

        return x


class Attention(nn.Module):
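    """Run num_steps decoding steps and return B x num_steps x num_classes probabilities."""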
    def __init__(self, in_channels, out_channels, hidden_size):
        super().__init__()
        self.num_steps = 25  # maximum decoding length
        self.num_classes = out_channels
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.attention_cell = AttentionGRUCell(in_channels,
                                               hidden_size,
                                               out_channels)
        self.generator = nn.Linear(hidden_size, out_channels)

    def forward(self, x):
        batch_size = x.shape[0]
        probs = None
        # initial decoder state and start token (index 0, i.e. 'sos')
        hidden = torch.zeros(batch_size, self.hidden_size, device=x.device)
        targets = torch.zeros(batch_size, dtype=torch.int64, device=x.device)
        for _ in range(self.num_steps):
            # one-hot embedding of the previously predicted character
            char_onehots = F.one_hot(targets, num_classes=self.num_classes)
            hidden, _ = self.attention_cell(hidden, x, char_onehots)
            probs_step = self.generator(hidden)
            if probs is None:
                probs = torch.unsqueeze(probs_step, dim=1)
            else:
                probs = torch.cat([probs, torch.unsqueeze(probs_step, dim=1)],
                                  dim=1)
            # greedy decoding: feed the argmax back in as the next input
            targets = probs_step.argmax(dim=1)

        # inference stage: per-step probabilities over the character classes
        probs = F.softmax(probs, dim=2)
        return probs


class AttentionGRUCell(nn.Module):
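    """One decoding step: additive attention over the encoder outputs, then a GRU cell update."""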
    def __init__(self, input_size, hidden_size, num_embedding) -> None:
        super().__init__()

        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

        # single-step GRU cell; nn.GRU would treat a 2-D input as an
        # unbatched sequence, which only works for batch size 1
        self.gru = nn.GRUCell(input_size + num_embedding, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        # additive attention, following https://arxiv.org/pdf/1704.03549.pdf
        batch_H_proj = self.i2h(batch_H)                                  # B x W x hidden
        prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden), dim=1)  # B x 1 x hidden

        # alignment scores: e = score(tanh(W*h + V*s_prev))
        res = torch.tanh(torch.add(batch_H_proj, prev_hidden_proj))
        e = self.score(res)                                               # B x W x 1

        # attention weights and the attended context vector
        alpha = F.softmax(e, dim=1)
        alpha = alpha.permute(0, 2, 1)                                    # B x 1 x W
        context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1)
        concat_context = torch.cat([context, char_onehots.float()], 1)

        cur_hidden = self.gru(concat_context, prev_hidden)
        return cur_hidden, alpha


class AttenLabelDecode():
    """Map the decoder's per-step class probabilities back to text."""

    def __init__(self) -> None:
        self.beg_str = 'sos'
        self.end_str = 'eos'
        self.character_str = '0123456789abcdefghijklmnopqrstuvwxyz'

        dict_character = list(self.character_str)
        self.character = self.add_special_char(dict_character)

        self.dict = {char: i for i, char in enumerate(self.character)}

    def add_special_char(self, dict_character: List):
        return [self.beg_str] + dict_character + [self.end_str]

    def __call__(self, preds: torch.Tensor) -> List:
        preds = preds.detach().numpy()

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob)
        return text

    def decode(self, txt_idx, txt_prob) -> List:
        result = []
        beg_idx, end_idx = self.get_ignored_tokens()
        batch_size = len(txt_idx)
        for batch_idx in range(batch_size):
            char_list, conf_list = [], []
            for idx in range(len(txt_idx[batch_idx])):
                cur_value = txt_idx[batch_idx][idx]
                cur_prob = txt_prob[batch_idx][idx]

                # stop decoding once the end token 'eos' is predicted
                if int(cur_value) == int(end_idx):
                    break

                # skip the start token 'sos'
                if int(cur_value) == int(beg_idx):
                    continue

                char_list.append(self.character[int(cur_value)])
                conf_list.append(cur_prob)

            text = ''.join(char_list)
            result.append((text, np.mean(conf_list).tolist()))
        return result

    def get_ignored_tokens(self):
        beg_idx = self.dict[self.beg_str]
        end_idx = self.dict[self.end_str]
        return beg_idx, end_idx


if __name__ == '__main__':
    model = RARE()

    decoder = AttenLabelDecode()

    x = torch.randn((1, 3, 32, 100))

    y = model(x)

    print(y.shape)

    result = decoder(y)
    print(result)
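
The forward pass above only implements greedy inference: each step feeds its own argmax back in as the next input. During training, attention decoders are normally run with teacher forcing instead, feeding the ground-truth character into every step. Below is a minimal sketch of what that could look like with this model (my own illustration, not part of the original code; the teacher_forcing_logits helper, its signature, and the target layout are hypothetical):

import torch
import torch.nn.functional as F


def teacher_forcing_logits(model: RARE, images: torch.Tensor,
                           targets: torch.Tensor) -> torch.Tensor:
    # targets: B x num_steps ground-truth character indices (hypothetical layout)
    feats = model.cnn(images).squeeze(2).permute(0, 2, 1)
    seq, _ = model.rnn(feats)

    att = model.attention
    batch_size = images.shape[0]
    hidden = torch.zeros(batch_size, att.hidden_size, device=images.device)
    step_input = torch.zeros(batch_size, dtype=torch.int64,
                             device=images.device)  # start token 'sos'

    logits = []
    for t in range(att.num_steps):
        char_onehots = F.one_hot(step_input, num_classes=att.num_classes)
        hidden, _ = att.attention_cell(hidden, seq, char_onehots)
        logits.append(att.generator(hidden))
        step_input = targets[:, t]  # feed the ground truth, not the argmax

    return torch.stack(logits, dim=1)  # B x num_steps x num_classes


# e.g. loss = F.cross_entropy(logits.transpose(1, 2), targets)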
