Understanding Character Region Awareness for Text Detection (CRAFT) in Depth

Posted by SpikeKing

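This article analyzes the two most important parts of CRAFT: its network architecture and its training data.

See also: the CRAFT character detection algorithm and the SynthText synthetic text dataset.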


CRAFT Network Architecture

Reference file: craft.py

Logic:

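  • Input x: [1, 3, 1280, 960], i.e. the input image size (N, C, H, W)
  • basenet is vgg16_bn, which outputs 5 intermediate feature maps:
    • 0: 1x1024x80x60, i.e. downsampled 16x
    • 1: 1x512x80x60, i.e. downsampled 16x
    • 2: 1x512x160x120, i.e. downsampled 8x
    • 3: 1x256x320x240, i.e. downsampled 4x
    • 4: 1x128x640x480, i.e. downsampled 2x
    • Feature maps 0 to 4 are ordered from smallest to largest spatial size
  • 1st up: concatenate maps 0 and 1 into [1x1536x80x60], then reduce the channels to 256, giving [1x256x80x60] (upconv1);
  • 2nd up:
    • upsample [1x256x80x60] -> [1x256x160x120] with bilinear interpolation (F.interpolate; this is not a transposed convolution);
    • concatenate with map 2: [1x256x160x120] + [1x512x160x120] = [1x768x160x120], then reduce the channels to 128, giving [1x128x160x120] (upconv2)
  • The 3rd and 4th ups work the same way as the 2nd (upconv3, upconv4); the final output is [1x32x640x480]
  • Finally, conv_cls classifies each pixel: [1x32x640x480] -> [1x2x640x480], i.e. at half the input resolution. The 2 output channels are the character region score map and the character affinity map.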
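The shapes in the list above follow from the backbone strides alone. A minimal sketch, assuming the strides stated in the list (16, 16, 8, 4, 2) and an input whose height and width are divisible by 16; craft_feature_shapes and craft_output_shape are hypothetical helper names, not part of craft.py:

```python
# Hypothetical helpers: reproduce the feature-map shapes listed above from
# the input size, assuming the per-map strides stated in the text.
def craft_feature_shapes(height, width):
    # (channels, stride) for sources[0..4], taken from the list above
    layout = [(1024, 16), (512, 16), (512, 8), (256, 4), (128, 2)]
    shapes = []
    for channels, stride in layout:
        shapes.append((1, channels, height // stride, width // stride))
    return shapes


def craft_output_shape(height, width):
    # The decoder upsamples back to the size of sources[4] (stride 2),
    # and conv_cls keeps that size while producing 2 channels.
    return (1, 2, height // 2, width // 2)


if __name__ == "__main__":
    for i, shape in enumerate(craft_feature_shapes(1280, 960)):
        print(i, shape)
    print("output", craft_output_shape(1280, 960))
```

For the 1280x960 input used above, the last line printed is `output (1, 2, 640, 480)`, so both score maps are predicted at half the input resolution.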

Source code:

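def forward(self, x):
    """ Base network """
    # x: [N, 3, H, W]; sources[0..4] are the five VGG16-BN feature maps listed above
    sources = self.basenet(x)

    """ U network """
    # 1st up: concatenate the two stride-16 maps (1024 + 512 = 1536 channels), reduce to 256
    y = torch.cat([sources[0], sources[1]], dim=1)
    y = self.upconv1(y)

    # 2nd up: bilinear upsampling to the size of sources[2], concatenate, reduce to 128
    y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
    y = torch.cat([y, sources[2]], dim=1)
    y = self.upconv2(y)

    # 3rd up: upsample to the size of sources[3], concatenate, upconv3
    y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
    y = torch.cat([y, sources[3]], dim=1)
    y = self.upconv3(y)

    # 4th up: upsample to the size of sources[4] (stride 2), concatenate, reduce to 32
    y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
    y = torch.cat([y, sources[4]], dim=1)
    feature = self.upconv4(y)

    # conv_cls maps 32 channels to 2: region score and affinity score
    y = self.conv_cls(feature)

    # [N, 2, H/2, W/2] -> [N, H/2, W/2, 2]
    return y.permute(0, 2, 3, 1), feature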
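The channel arithmetic of the U network above can be traced without running the model. This sketch rests on an assumption not shown in this article: that each upconvN is a double_conv(in_ch, mid_ch, out_ch) block whose first 1x1 convolution takes in_ch + mid_ch channels, with the (in_ch, mid_ch, out_ch) values used in the clovaai CRAFT-pytorch craft.py. UPCONV_CFG, SKIP_CHANNELS and trace_channels are hypothetical names.

```python
# Assumed double_conv configurations: (in_ch, mid_ch, out_ch).
# Assumption: the first conv of each block expects in_ch + mid_ch input channels.
UPCONV_CFG = {
    "upconv1": (1024, 512, 256),
    "upconv2": (512, 256, 128),
    "upconv3": (256, 128, 64),
    "upconv4": (128, 64, 32),
}

# Channels of the map concatenated at each step: sources[1], [2], [3], [4]
SKIP_CHANNELS = [512, 512, 256, 128]


def trace_channels():
    """Return (block, concatenated channels, mid channels, output channels) per step."""
    steps = []
    y = 1024  # channels of sources[0]
    for (name, (in_ch, mid_ch, out_ch)), skip in zip(UPCONV_CFG.items(), SKIP_CHANNELS):
        concat = y + skip  # torch.cat([...], dim=1) adds channel counts
        # the concatenated tensor must match what the block's first conv expects
        assert concat == in_ch + mid_ch, name
        y = out_ch
        steps.append((name, concat, mid_ch, out_ch))
    return steps


if __name__ == "__main__":
    for step in trace_channels():
        print(step)
```

The concatenated widths come out as 1536, 768, 384 and 192, ending in the 32-channel feature that conv_cls consumes; the first two match the numbers in the list above.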

Network architecture diagram:
[Figure: CRAFT network architecture]


ICDAR2015 Dataset - Training Data

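When sampling from the ICDAR2015 dataset, note that the data contains invalid regions; in the mask these are set to 0 or to a low confidence. To inspect individual samples, change the data loading so that batch_size is 1 and shuffle is False: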

real_data = ICDAR2015(net, ICDAR_2015_PATH, target_size=768)
real_data_loader = torch.utils.data.DataLoader(
    real_data,
    # batch_size=10,
    batch_size=1,  # for testing
    # shuffle=True,
    shuffle=False,  # for testing
    num_workers=0,
    drop_last=True,
    pin_memory=True)

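Loading the dataset takes a long time, so add print points to inspect the outputs and their formats. The ground truth and confidence mask of an image are loaded in load_image_gt_and_confidencemask; for a new dataset, this function needs to be overridden. For img_329.jpg: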

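  • image_name: img_329.jpg
  • word_bboxes: 12x4x2, i.e. 12 boxes of 4 (x, y) points each, used to derive the character boxes
  • words, 12 entries: ['LOVIS', 'DIAMONDS', '###', '15%', '###', 'LOVS', '###', 'SSHA', '###', '###', '###', '###']; only indices 0, 1, 3, 5 and 7 carry actual transcriptions
    • For '###' or empty words, confidence_mask is set to 0 inside the word box
    • For the remaining words, the computed confidences are [0.8, 1.0, 1.0, 0.5, 0.5]
    • confidence_mask is set to each word's confidence inside its box via cv2.fillPoly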
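The confidence values depend on how many character boxes the interim model finds inside each word. The CRAFT paper defines the pseudo-label confidence of a word w as s_conf(w) = (l(w) - min(l(w), |l(w) - l^c(w)|)) / l(w), where l(w) is the transcription length and l^c(w) is the number of split character boxes; if the score falls below 0.5, the paper instead divides the word box into equal-width characters and sets the confidence to 0.5. A minimal sketch of that rule (word_confidence is a hypothetical name; in the loader the confidence is returned by inference_pursedo_bboxes, whose internals are not shown here):

```python
def word_confidence(word_length, num_split_chars):
    """Pseudo-label confidence of one word, following the CRAFT paper's formula.

    word_length: l(w), number of characters in the transcription
    num_split_chars: l^c(w), number of character boxes the interim model found
    Returns (confidence, use_equal_width_split).
    """
    if word_length == 0:
        # empty words never reach this point in the loader; they get mask value 0
        return 0.0, False
    error = min(word_length, abs(word_length - num_split_chars))
    conf = (word_length - error) / word_length
    if conf < 0.5:
        # paper: discard the split, use equal-width character boxes, confidence 0.5
        return 0.5, True
    return conf, False


if __name__ == "__main__":
    print(word_confidence(5, 4))  # 'LOVIS' with one character box missing
    print(word_confidence(8, 8))  # 'DIAMONDS' split perfectly
```

For 'LOVIS' (5 letters), finding either 4 or 6 character boxes yields 0.8. The article does not show the split counts, so the 0.5 for 'LOVS' and 'SSHA' may be either an exact score of 0.5 or the fallback.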

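Outputs:

  • character_bboxes: the character boxes of the 5 kept words
  • new_words: ['LOVIS', 'DIAMONDS', '15%', 'LOVS', 'SSHA']
  • confidence_mask: a confidence map that is 1 everywhere, except 0 inside ignored word boxes and the word's confidence (between 0 and 1) inside each kept word box
  • confidences: the confidence values, i.e. [0.8, 1.0, 1.0, 0.5, 0.5]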

The source code:

def load_image_gt_and_confidencemask(self, index):
    '''
    Load the ground truth for the image at the given index
    '''
    image_name = self.images_path[index]
    gt_path = os.path.join(self.gt_folder, "gt_%s.txt" % os.path.splitext(image_name)[0])
    word_bboxes, words = self.load_gt(gt_path)
    word_bboxes = np.float32(word_bboxes)

    image_path = os.path.join(self.img_folder, image_name)
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image = random_scale(image, word_bboxes, self.target_size)

    # start with full confidence everywhere
    confidence_mask = np.ones((image.shape[0], image.shape[1]), np.float32)

    character_bboxes = []
    new_words = []
    confidences = []
    if len(word_bboxes) > 0:
        # pass 1: zero out ignored ('###') or empty words
        for i in range(len(word_bboxes)):
            if words[i] == '###' or len(words[i].strip()) == 0:
                cv2.fillPoly(confidence_mask, [np.int32(word_bboxes[i])], (0))
        # pass 2: split each valid word into pseudo character boxes with the
        # interim network and fill its word box with the resulting confidence
        for i in range(len(word_bboxes)):
            if words[i] == '###' or len(words[i].strip()) == 0:
                continue
            pursedo_bboxes, bbox_region_scores, confidence = \
                self.inference_pursedo_bboxes(self.net, image, word_bboxes[i], words[i], viz=self.viz)
            confidences.append(confidence)
            cv2.fillPoly(confidence_mask, [np.int32(word_bboxes[i])], confidence)
            new_words.append(words[i])
            character_bboxes.append(pursedo_bboxes)
    return image, character_bboxes, new_words, confidence_mask, confidences
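The order of the two loops matters: ignored regions are zeroed first, so a valid word that overlaps an ignored one still gets its own confidence. A simplified NumPy sketch of that logic, assuming axis-aligned rectangles (x0, y0, x1, y1) filled by array slicing in place of the quadrilaterals that cv2.fillPoly rasterizes; build_confidence_mask is a hypothetical name:

```python
import numpy as np


def build_confidence_mask(shape, boxes, words, confidences):
    """Two-pass confidence mask, simplified to axis-aligned boxes (an assumption).

    boxes: list of (x0, y0, x1, y1) with exclusive upper bounds
    confidences: one value per kept word, in the order the words appear
    """
    mask = np.ones(shape, np.float32)
    # Pass 1: zero out "don't care" regions ('###' or blank transcriptions)
    for (x0, y0, x1, y1), word in zip(boxes, words):
        if word == '###' or not word.strip():
            mask[y0:y1, x0:x1] = 0
    # Pass 2: write each kept word's confidence; this overwrites zeros from
    # pass 1 wherever a valid word overlaps an ignored one
    kept = [(box, word) for box, word in zip(boxes, words)
            if word != '###' and word.strip()]
    for ((x0, y0, x1, y1), _), conf in zip(kept, confidences):
        mask[y0:y1, x0:x1] = conf
    return mask


if __name__ == "__main__":
    m = build_confidence_mask((10, 10), [(0, 0, 4, 4), (3, 3, 8, 8)], ['###', 'AB'], [0.5])
    print(m[0, 0], m[3, 3], m[9, 9])
```

Here pixel (3, 3) lies in both boxes and ends up with the kept word's confidence rather than 0.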

Test image img_329.jpg with its 12 word boxes:
[Figure: word boxes]
The initial mask is all 1s:
[Figure: Mask-1]
The mask after setting the ignored ('###') regions to 0:
[Figure: Mask-0]
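The mask after filling in the word confidences; words whose pseudo character boxes are imperfect get a value between 0 and 1:
[Figure: Mask-01]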
The character boxes (character_bboxes):
[Figure: Character Bboxes]
From character_bboxes, the character mask region_scores is computed:
[Figure: Region Scores]
This is done in data_loader.py:

region_scores = self.gaussianTransformer.generate_region(region_scores.shape, character_bboxes)

affinity_scores, affinity_bboxes = \
    self.gaussianTransformer.generate_affinity(region_scores.shape, character_bboxes, words)
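The CRAFT paper renders region scores by warping a 2D isotropic Gaussian onto each character box with a perspective transform. The sketch below is a simplified stand-in, not gaussianTransformer.generate_region: it assumes axis-aligned integer boxes, evaluates a Gaussian directly in image space with a hypothetical sigma_ratio parameter, and takes the element-wise maximum where boxes overlap. render_region_scores and add_char_gaussian are hypothetical names.

```python
import numpy as np


def add_char_gaussian(score_map, box, sigma_ratio=0.25):
    """Draw one character's Gaussian into score_map (simplified, no perspective warp).

    box: (x0, y0, x1, y1) integer pixel bounds, exclusive upper bounds
    sigma_ratio: hypothetical spread relative to the box width/height
    """
    x0, y0, x1, y1 = box
    # centre of the pixels covered by the box
    cx, cy = (x0 + x1 - 1) / 2.0, (y0 + y1 - 1) / 2.0
    sx = max((x1 - x0) * sigma_ratio, 1e-6)
    sy = max((y1 - y0) * sigma_ratio, 1e-6)
    # clip the drawing window to the map while keeping the original centre
    h, w = score_map.shape
    gx0, gy0, gx1, gy1 = max(x0, 0), max(y0, 0), min(x1, w), min(y1, h)
    if gx0 >= gx1 or gy0 >= gy1:
        return score_map
    ys, xs = np.mgrid[gy0:gy1, gx0:gx1]
    g = np.exp(-(((xs - cx) / sx) ** 2 + ((ys - cy) / sy) ** 2) / 2.0)
    region = score_map[gy0:gy1, gx0:gx1]  # view into score_map
    np.maximum(region, g, out=region)     # keep the stronger response on overlap
    return score_map


def render_region_scores(shape, char_boxes):
    """Region score map for a flat list of character boxes."""
    score = np.zeros(shape, np.float32)
    for box in char_boxes:
        add_char_gaussian(score, box)
    return score


if __name__ == "__main__":
    m = render_region_scores((10, 10), [(0, 0, 5, 5)])
    print(np.round(m[:5, :5], 2))
```

Each character contributes a peak of 1.0 at its centre that decays toward the box edges, which is what the Region Scores figure shows per character.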

From character_bboxes and words, the affinity mask is computed:
[Figure: Affinity Scores]
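Affinity boxes are derived purely from the geometry of neighbouring character boxes. Following the description in the CRAFT paper, the diagonals of each character box split it into an upper and a lower triangle, and the affinity box connects the centroids of those triangles in two adjacent characters. Pairs are formed only within a word, presumably why generate_affinity also takes words. A minimal NumPy sketch, assuming corners ordered top-left, top-right, bottom-right, bottom-left and using the mean of the four corners as the box centre (exact for parallelograms); affinity_box and word_affinity_boxes are hypothetical names:

```python
import numpy as np


def affinity_box(char_a, char_b):
    """Affinity quadrilateral between two adjacent character boxes.

    Each box is 4 (x, y) points: top-left, top-right, bottom-right, bottom-left.
    """
    a = np.asarray(char_a, np.float32)
    b = np.asarray(char_b, np.float32)
    ca, cb = a.mean(axis=0), b.mean(axis=0)  # box centres (assumed = diagonal crossing)
    top_a = (a[0] + a[1] + ca) / 3.0         # centroid of a's upper triangle
    bot_a = (a[2] + a[3] + ca) / 3.0         # centroid of a's lower triangle
    top_b = (b[0] + b[1] + cb) / 3.0
    bot_b = (b[2] + b[3] + cb) / 3.0
    # returned in the same clockwise order as the character boxes
    return np.stack([top_a, top_b, bot_b, bot_a])


def word_affinity_boxes(char_boxes_per_word):
    """One affinity box per pair of consecutive characters, never across words."""
    boxes = []
    for chars in char_boxes_per_word:
        for left, right in zip(chars[:-1], chars[1:]):
            boxes.append(affinity_box(left, right))
    return boxes


if __name__ == "__main__":
    A = [(0, 0), (10, 0), (10, 10), (0, 10)]
    B = [(10, 0), (20, 0), (20, 10), (10, 10)]
    print(affinity_box(A, B))
```

For two side-by-side 10x10 characters, the affinity box spans x from 5 to 15 (the two character centres) and y from 10/6 to 50/6, straddling the gap between the letters.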

