夯实基础系列:文本识别算法:RARE(Robust Scene Text Recognition with Automatic Rectification)核心代码
Posted Liekkas Kono
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了夯实基础系列:文本识别算法:RARE(Robust Scene Text Recognition with Automatic Rectification)核心代码相关的知识,希望对你有一定的参考价值。
引言
- RARE是基于Attention机制来实现端到端文本识别的算法,具有一定的经典性,对于理解基于注意力机制做文字识别具有很大学习意义。
- RARE 论文地址:https://arxiv.org/abs/1603.03915 (Shi et al., CVPR 2016)
基本原理
- 本文侧重点在于其中的Attention解码部分;论文的两个主要模块STN(Spatial Transformer Network,空间变换网络)和SRN(Sequence Recognition Network,序列识别网络)的其余细节暂不展开。
- 从论文给出的整体框架图来看,(不含STN矫正部分时)基本结构是CNN + BLSTM + Decoder模块。与CRNN相比,只是把CTC解码部分换成了基于Attention的解码器。
核心代码
- 代码参考PaddleOCR中RARE实现
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from typing import List
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class RARE(nn.Module):
    """Simplified RARE recognizer: CNN backbone + BiLSTM encoder + attention decoder.

    The STN rectification stage of the paper is not part of this module.
    The CNN collapses the feature-map height to 1, the resulting width-wise
    feature sequence is encoded by a bidirectional LSTM, and the attention
    head decodes it into per-step class probabilities.

    The shape comments below assume a 1 x 3 x 32 x 100 input.
    """

    def __init__(self) -> None:
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d((2, 2)),  # 1 x 64 x 16 x 50
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d((2, 2)),  # 1 x 128 x 8 x 25
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # Stride (2, 1) halves the height only; padding (1, 1) adds one
            # row/column, so H: (8 + 2 - 2) / 2 + 1 = 5, W: 25 + 1 = 26.
            nn.MaxPool2d((2, 2), (2, 1), (1, 1)),  # 1 x 256 x 5 x 26
            nn.Conv2d(256, 480, 3, 1, 1),
            nn.BatchNorm2d(480),
            nn.ReLU(True),
            # H: (5 - 2) // 2 + 1 = 2, W: (26 + 2 - 2) / 1 + 1 = 27.
            nn.MaxPool2d((2, 2), (2, 1), (0, 1)),  # 1 x 480 x 2 x 27
            nn.Conv2d(480, 480, 2, 1, 0),
            nn.BatchNorm2d(480),
            nn.ReLU(True),  # 1 x 480 x 1 x 26
        )
        # The sequence is laid out as (batch, width, channels), so the LSTM
        # must be batch_first; otherwise it would run over the batch axis as
        # if it were time and treat the width positions as independent samples.
        self.rnn = nn.LSTM(480, 96, bidirectional=True, batch_first=True)
        # 192 = 2 * 96 (bidirectional features); 38 output classes; 96 decoder
        # hidden units.
        self.attention = Attention(192, 38, 96)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the recognizer.

        Args:
            x: image batch of shape B x 3 x H x W; H must reduce to 1 through
                the CNN (true for H = 32).

        Returns:
            Softmax probabilities from the attention decoder,
            B x num_steps x num_classes.
        """
        x = self.cnn(x)
        x = x.squeeze(2)  # drop the height axis (size 1)
        x = x.permute(0, 2, 1)  # B x width x channels
        x, _ = self.rnn(x)  # B x 26 x 192
        x = self.attention(x)  # B x 25 x 38
        return x
class Attention(nn.Module):
    """Greedy attention decoder that runs for a fixed number of steps.

    Args:
        in_channels: feature size of the encoder sequence.
        out_channels: number of output classes; also the size of the one-hot
            vector of the previous prediction fed back at every step.
        hidden_size: size of the recurrent decoder state.
        num_steps: number of decoding steps (default 25).
    """

    def __init__(self, in_channels, out_channels, hidden_size, num_steps=25):
        super().__init__()
        self.num_steps = num_steps
        self.num_classes = out_channels
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.attention_cell = AttentionGRUCell(in_channels,
                                               hidden_size,
                                               out_channels)
        self.generator = nn.Linear(hidden_size, out_channels)

    def forward(self, x):
        """Decode an encoder sequence greedily.

        Args:
            x: encoder output, B x T x in_channels.

        Returns:
            Softmax probabilities of shape B x num_steps x out_channels.
        """
        batch_size = x.shape[0]
        # Allocate decoder state on the input's device/dtype so the module
        # also works on GPU or in half precision.
        hidden = x.new_zeros(batch_size, self.hidden_size)
        # Decoding starts from class 0 (the 'sos' entry in AttenLabelDecode).
        targets = torch.zeros(batch_size, dtype=torch.int64, device=x.device)
        step_logits = []
        for _ in range(self.num_steps):
            # F.one_hot returns int64; cast to the feature dtype before it is
            # concatenated with the float context vector.
            char_onehots = F.one_hot(
                targets, num_classes=self.num_classes).to(x.dtype)
            (outputs, hidden), _ = self.attention_cell(hidden, x,
                                                       char_onehots)
            logits = self.generator(outputs)
            step_logits.append(logits)
            # Greedy decoding: feed the most likely class back in.
            targets = logits.argmax(dim=1)
        probs = torch.stack(step_logits, dim=1)
        # Inference only: turn logits into probabilities.
        return F.softmax(probs, dim=2)
class AttentionGRUCell(nn.Module):
    """One decoding step: additive attention over the encoder sequence
    followed by a GRU update of the decoder state.

    Args:
        input_size: feature size of the encoder sequence ``batch_H``.
        hidden_size: size of the decoder hidden state.
        num_embedding: size of the previous-character one-hot vector.
    """

    def __init__(self, input_size, hidden_size, num_embedding) -> None:
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)
        # Kept as nn.GRU (rather than nn.GRUCell) so parameter names stay the
        # same; forward() runs it for exactly one time step.
        self.gru = nn.GRU(input_size=input_size + num_embedding,
                          hidden_size=hidden_size,)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        """Run one attention + GRU step.

        Args:
            prev_hidden: previous decoder state, B x hidden_size.
            batch_H: encoder sequence, B x T x input_size.
            char_onehots: one-hot of the previous prediction,
                B x num_embedding.

        Returns:
            ``((output, new_hidden), alpha)`` where ``output`` and
            ``new_hidden`` are both B x hidden_size (equal for a single-layer
            GRU run for one step) and ``alpha`` is the attention weight
            tensor of shape B x 1 x T.
        """
        # Attention formulation follows https://arxiv.org/pdf/1704.03549.pdf
        batch_H_proj = self.i2h(batch_H)  # B x T x hidden
        prev_hidden_proj = self.h2h(prev_hidden).unsqueeze(1)  # B x 1 x hidden
        e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))  # B x T x 1
        alpha = F.softmax(e, dim=1).permute(0, 2, 1)  # B x 1 x T
        context = torch.matmul(alpha, batch_H).squeeze(1)  # B x input_size
        concat_context = torch.cat(
            [context, char_onehots.to(context.dtype)], dim=1)
        # nn.GRU expects input (seq_len, batch, features) and h0
        # (num_layers, batch, hidden). Passing 2-D tensors would make PyTorch
        # treat them as an *unbatched* sequence of length B, which fails for
        # B > 1, so add the time and layer axes explicitly.
        output, new_hidden = self.gru(concat_context.unsqueeze(0),
                                      prev_hidden.unsqueeze(0))
        return (output.squeeze(0), new_hidden.squeeze(0)), alpha
class AttenLabelDecode():
    """Turn attention-decoder probabilities into (text, confidence) pairs.

    Label layout: index 0 is 'sos', indices 1-36 are '0'-'9' then 'a'-'z',
    and index 37 is 'eos'.
    """

    def __init__(self, ) -> None:
        self.beg_str = 'sos'
        self.end_str = 'eos'
        self.character_str = '0123456789abcdefghijklmnopqrstuvwxyz'
        dict_character = list(self.character_str)
        # Attribute name "charcter" (sic) kept for backward compatibility.
        self.charcter = self.add_special_char(dict_character)
        # Map each label string to its class index.
        self.dict = {char: i for i, char in enumerate(self.charcter)}

    def add_special_char(self, dict_character: List) -> List:
        """Return ``dict_character`` with 'sos' prepended and 'eos' appended."""
        return [self.beg_str] + dict_character + [self.end_str]

    def __call__(self, preds: torch.Tensor) -> List:
        """Greedily decode a batch of per-step probabilities.

        Args:
            preds: tensor of shape B x steps x num_classes.

        Returns:
            One ``(text, mean_confidence)`` tuple per batch item.
        """
        # .cpu() so tensors living on a GPU can be converted too.
        preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        return self.decode(preds_idx, preds_prob)

    def decode(self, txt_idx, txt_prob) -> List:
        """Convert index/probability arrays (B x steps) into text.

        Each sequence is read until its first 'eos'; 'sos' entries are
        skipped. The confidence is the mean probability of the kept
        characters, or 0.0 when no character is kept.
        """
        result = []
        beg_idx, end_idx = (int(i) for i in self.get_ignored_tokens())
        for seq_idx, seq_prob in zip(txt_idx, txt_prob):
            char_list, conf_list = [], []
            for cur_value, cur_prob in zip(seq_idx, seq_prob):
                cur_value = int(cur_value)
                # Test 'eos' before skipping 'sos': when both were checked via
                # one "ignored" membership test, the eos break was unreachable
                # and characters after eos leaked into the text.
                if cur_value == end_idx:
                    break
                if cur_value == beg_idx:
                    continue
                char_list.append(self.charcter[cur_value])
                conf_list.append(cur_prob)
            text = ''.join(char_list)
            # np.mean([]) is NaN (with a RuntimeWarning); report 0.0 instead.
            conf = np.mean(conf_list).tolist() if conf_list else 0.0
            result.append((text, conf))
        return result

    def get_ignored_tokens(self,):
        """Return the ('sos', 'eos') class indices as 0-d numpy arrays."""
        beg_idx = np.array(self.dict[self.beg_str])
        end_idx = np.array(self.dict[self.end_str])
        return beg_idx, end_idx
if __name__ == '__main__':
    # Demo: run an (untrained) model on a random 32 x 100 RGB image.
    model = RARE()
    # Inference mode: BatchNorm uses running statistics instead of the
    # statistics of this single random batch.
    model.eval()
    decoder = AttenLabelDecode()
    x = torch.randn((1, 3, 32, 100))
    # No autograd graph is needed for a forward-only demo.
    with torch.no_grad():
        y = model(x)
    print(y.shape)
    result = decoder(y)
    print(result)
以上是关于夯实基础系列:文本识别算法:RARE(Robust Scene Text Recognition with Automatic Rectification)核心代码的主要内容,如果未能解决你的问题,请参考以下文章