A Deep Dive into Character Region Awareness for Text Detection (CRAFT)
Posted by SpikeKing
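This article analyzes the two most important parts of CRAFT: its network architecture and its training data.
See also: the CRAFT character detection algorithm, and the SynthText synthetic text dataset.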
CRAFT Network Architecture
Reference file: craft.py
Walkthrough:
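- Input x: [1, 3, 1280, 960], i.e. the input image size (N x C x H x W).
- The base network is vgg16_bn, which outputs 5 intermediate feature maps:
  - 0: 1x1024x80x60, downsampled 16x
  - 1: 1x512x80x60, downsampled 16x
  - 2: 1x512x160x120, downsampled 8x
  - 3: 1x256x320x240, downsampled 4x
  - 4: 1x128x640x480, downsampled 2x
- Feature maps 0 to 4 go from smallest to largest spatial size.
- 1st up: concatenate maps 0 and 1 into [1x1536x80x60], then reduce the channels to 256, giving [1x256x80x60]; this is upconv1.
- 2nd up:
  - Upsample [1x256x80x60] to [1x256x160x120] with F.interpolate (bilinear interpolation, not a transposed convolution);
  - Concatenate with map 2: [1x256x160x120] + [1x512x160x120] = [1x768x160x120], then reduce the channels to 128, giving [1x128x160x120]; this is upconv2.
- The 3rd and 4th up steps (upconv3, upconv4) work the same way as the 2nd; the final output is [1x32x640x480].
- Finally, the classification convolutions (conv_cls) map [1x32x640x480] to [1x2x640x480]. The 2 output channels are the character region score map and the character affinity score map.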
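The shape bookkeeping above can be checked with a few lines of plain Python. This is a hypothetical helper, not part of craft.py: it assumes the five VGG16-BN feature maps listed above (strides 16, 16, 8, 4, 2) and the output widths of upconv1 to upconv4 (256, 128, 64, 32) implied by the shapes in the list.

```python
# Hypothetical helper: trace tensor shapes through CRAFT's U-shaped decoder.
# Assumes the feature maps listed above and upconv output widths 256/128/64/32.


def trace_craft_shapes(h, w):
    # (channels, stride) for sources[0..4]
    sources = [(1024, 16), (512, 16), (512, 8), (256, 4), (128, 2)]
    up_out = [256, 128, 64, 32]  # output channels of upconv1..upconv4
    shapes = [(c, h // s, w // s) for c, s in sources]
    steps = []

    # 1st up: concat sources[0] and sources[1] (both at 1/16 scale)
    c, fh, fw = shapes[0]
    c_cat = c + shapes[1][0]
    steps.append(("upconv1", (c_cat, fh, fw), (up_out[0], fh, fw)))
    c = up_out[0]

    # 2nd to 4th up: resize to the next skip map, concat, reduce channels
    for i in range(2, 5):
        skip_c, fh, fw = shapes[i]
        c_cat = c + skip_c
        steps.append((f"upconv{i}", (c_cat, fh, fw), (up_out[i - 1], fh, fw)))
        c = up_out[i - 1]

    # conv_cls: 32 -> 2 channels (region score, affinity score)
    steps.append(("conv_cls", (c, fh, fw), (2, fh, fw)))
    return steps


for name, cat_shape, out_shape in trace_craft_shapes(1280, 960):
    print(f"{name:9s} in={cat_shape} out={out_shape}")
```

For a 1280x960 input this reproduces the channel counts above: 1536 to 256, 768 to 128, 384 to 64, 192 to 32, and finally 32 to 2 at 640x480.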
Source code:
def forward(self, x):
    """ Base network """
    # sources[0..4]: VGG16-BN feature maps at scales 1/16, 1/16, 1/8, 1/4, 1/2
    sources = self.basenet(x)

    """ U network """
    # 1st up: concat the two 1/16-scale maps, then reduce channels
    y = torch.cat([sources[0], sources[1]], dim=1)
    y = self.upconv1(y)

    # 2nd to 4th up: bilinear upsample to the next skip map, concat, reduce channels
    y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
    y = torch.cat([y, sources[2]], dim=1)
    y = self.upconv2(y)

    y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
    y = torch.cat([y, sources[3]], dim=1)
    y = self.upconv3(y)

    y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
    y = torch.cat([y, sources[4]], dim=1)
    feature = self.upconv4(y)

    # 2-channel output (region score, affinity score), permuted from NCHW to NHWC
    y = self.conv_cls(feature)
    return y.permute(0, 2, 3, 1), feature
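The permuted output is typically split into its two score maps before post-processing. The NumPy sketch below mimics that step with random placeholder data; it assumes, as in the CRAFT-pytorch inference code, that channel 0 is the region score and channel 1 the affinity score. The sample point is hypothetical.

```python
import numpy as np

# Placeholder for the conv_cls output: [N, 2, H/2, W/2] for a 1280x960 input.
rng = np.random.default_rng(0)
y_nchw = rng.random((1, 2, 640, 480), dtype=np.float32)

# NumPy equivalent of torch's permute(0, 2, 3, 1): NCHW -> NHWC
y = y_nchw.transpose(0, 2, 3, 1)

# Channel 0: character region score, channel 1: affinity score
score_text = y[0, :, :, 0]
score_link = y[0, :, :, 1]

# The score maps are half the input resolution (1280x960 -> 640x480),
# so a coordinate found on them is multiplied by 2 to map it back
# onto the network input.
ratio_net = 2
x_map, y_map = 100, 50  # hypothetical point on the score map
x_img, y_img = x_map * ratio_net, y_map * ratio_net
print(y.shape, score_text.shape, (x_img, y_img))
```

If the image was resized before inference, the coordinates also need the inverse of that resize ratio on top of the factor of 2.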
Network architecture diagram:
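ICDAR2015 Dataset - Training Data
Training samples are drawn from the ICDAR2015 dataset. The annotations include invalid entries, whose regions are set to 0 in the mask, while uncertain words get a low confidence. To inspect the data, change the loader settings: set batch_size to 1 and shuffle to False: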
real_data = ICDAR2015(net, ICDAR_2015_PATH, target_size=768)
real_data_loader = torch.utils.data.DataLoader(
    real_data,
    # batch_size=10,
    batch_size=1,  # for testing
    # shuffle=True,
    shuffle=False,  # for testing
    num_workers=0,
    drop_last=True,
    pin_memory=True)
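Loading the dataset takes a fairly long time, so add logging points to inspect the outputs and their format. The ground truth and confidence mask of an image are loaded in load_image_gt_and_confidencemask; for a new dataset, you need to override this function.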
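One way to add such logging points is a small decorator that reports the run time and the format of every returned item. This is a hypothetical helper, not part of the CRAFT code; describe() only assumes that array-like objects expose shape and dtype.

```python
import functools
import time


def describe(value):
    """Summarize a value: shape/dtype for array-like objects, length for lists/tuples."""
    if hasattr(value, "shape") and hasattr(value, "dtype"):
        return f"{type(value).__name__} shape={tuple(value.shape)} dtype={value.dtype}"
    if isinstance(value, (list, tuple)):
        return f"{type(value).__name__} len={len(value)}"
    return repr(value)


def trace_outputs(fn):
    """Hypothetical decorator: log run time and the format of each returned item."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        elapsed = time.perf_counter() - start
        items = result if isinstance(result, tuple) else (result,)
        print(f"{fn.__name__} took {elapsed:.3f}s")
        for i, item in enumerate(items):
            print(f"  [{i}] {describe(item)}")
        return result
    return wrapper
```

Assuming the dataset calls the method through self, it can be wrapped on the instance without editing the class, e.g. `real_data.load_image_gt_and_confidencemask = trace_outputs(real_data.load_image_gt_and_confidencemask)`.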
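Sample values for one image:
- image_name: img_329.jpg
- word_bboxes: 12x4x2, i.e. 12 word boxes of 4 (x, y) corners each, used to compute the character boxes.
- words: 12 transcriptions, ['LOVIS', 'DIAMONDS', '###', '15%', '###', 'LOVS', '###', 'SSHA', '###', '###', '###', '###']; only indices 0, 1, 3, 5 and 7 contain clearly legible text.
  - When a word is ### or empty, its region in confidence_mask is set to 0.
  - For the remaining words, a confidence is computed: [0.8, 1.0, 1.0, 0.5, 0.5].
  - Each valid word's region in confidence_mask is filled with its confidence value via cv2.fillPoly.
Output:
- character_bboxes: the character boxes of the 5 valid words
- new_words: ['LOVIS', 'DIAMONDS', '15%', 'LOVS', 'SSHA']
- confidence_mask: a confidence map that starts as all 1s, with some regions set to 0 or to values between 0 and 1
- confidences: the per-word confidence values, i.e. [0.8, 1.0, 1.0, 0.5, 0.5]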
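The confidence values come from the pseudo-ground-truth score defined in the CRAFT paper: conf(w) = (l(w) - min(l(w), |l(w) - l_c(w)|)) / l(w), where l(w) is the length of the transcription and l_c(w) is the number of characters the interim model splits the word box into. A plain-Python sketch follows; the per-word character counts are hypothetical, chosen only to be consistent with the values above.

```python
def word_confidence(word_len, num_found):
    """Pseudo-label confidence from the CRAFT paper:
    (l - min(l, |l - l_c|)) / l, with l the transcription length
    and l_c the number of characters found in the word box.
    """
    if word_len == 0:  # guard against division by zero
        return 0.0
    return (word_len - min(word_len, abs(word_len - num_found))) / word_len


# Hypothetical character counts for the 5 valid words of img_329.jpg
found = {"LOVIS": 4, "DIAMONDS": 8, "15%": 3, "LOVS": 2, "SSHA": 2}
print([word_confidence(len(w), n) for w, n in found.items()])
# → [0.8, 1.0, 1.0, 0.5, 0.5]
```

A perfect split gives 1.0, and the score drops linearly with the difference between the expected and found character counts, reaching 0 once that difference equals the word length.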
The source code:
def load_image_gt_and_confidencemask(self, index):
    '''
    Load the ground truth for the image at the given index.
    '''
    image_name = self.images_path[index]
    gt_path = os.path.join(self.gt_folder, "gt_%s.txt" % os.path.splitext(image_name)[0])
    word_bboxes, words = self.load_gt(gt_path)
    word_bboxes = np.float32(word_bboxes)
    image_path = os.path.join(self.img_folder, image_name)
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # random rescale for data augmentation
    image = random_scale(image, word_bboxes, self.target_size)
    # start from a mask of all 1s (full confidence everywhere)
    confidence_mask = np.ones((image.shape[0], image.shape[1]), np.float32)
    character_bboxes = []
    new_words = []
    confidences = []
    if len(word_bboxes) > 0:
        # first pass: zero out "don't care" regions (### or empty words)
        for i in range(len(word_bboxes)):
            if words[i] == '###' or len(words[i].strip()) == 0:
                cv2.fillPoly(confidence_mask, [np.int32(word_bboxes[i])], (0))
        # second pass: estimate pseudo character boxes for each valid word
        for i in range(len(word_bboxes)):
            if words[i] == '###' or len(words[i].strip()) == 0:
                continue
            pursedo_bboxes, bbox_region_scores, confidence = \
                self.inference_pursedo_bboxes(self.net, image, word_bboxes[i], words[i], viz=self.viz)
            confidences.append(confidence)
            # fill the word region with its confidence value
            cv2.fillPoly(confidence_mask, [np.int32(word_bboxes[i])], confidence)
            new_words.append(words[i])
            character_bboxes.append(pursedo_bboxes)
    return image, character_bboxes, new_words, confidence_mask, confidences
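The two-pass mask logic in load_image_gt_and_confidencemask can be reproduced without OpenCV. This is a simplified sketch under an explicit assumption: boxes are axis-aligned (x0, y0, x1, y1) rectangles, so plain array slicing stands in for cv2.fillPoly on 4-point polygons. build_confidence_mask is a hypothetical name.

```python
import numpy as np


def build_confidence_mask(shape, boxes, words, confidences):
    """Simplified two-pass confidence mask.

    Assumption: boxes are axis-aligned (x0, y0, x1, y1) rectangles, so
    slicing replaces cv2.fillPoly. `confidences` maps word index -> value.
    """
    mask = np.ones(shape, np.float32)  # start fully trusted
    ignore = [w == "###" or not w.strip() for w in words]
    # pass 1: zero out don't-care regions
    for (x0, y0, x1, y1), skip in zip(boxes, ignore):
        if skip:
            mask[y0:y1, x0:x1] = 0
    # pass 2: write each valid word's confidence over its region
    for i, ((x0, y0, x1, y1), skip) in enumerate(zip(boxes, ignore)):
        if not skip:
            mask[y0:y1, x0:x1] = confidences[i]
    return mask


boxes = [(0, 0, 4, 2), (4, 0, 8, 2), (0, 2, 8, 4)]
words = ["LOVIS", "###", "LOVS"]
mask = build_confidence_mask((6, 8), boxes, words, {0: 0.8, 2: 0.5})
print(mask)
```

Because valid words are written in the second pass, a valid word that overlaps a ### region keeps its own confidence rather than 0; pixels outside every box stay at 1.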
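Test image img_329.jpg, with its 12 word boxes:
The initial mask, all 1s:
The mask after zeroing out the negative (don't-care) regions:
The mask after adding the confidences; words with imperfect pseudo-labels get a value between 0 and 1:
The character boxes, character_bboxes: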
Using character_bboxes, compute the character region mask, region_scores. The code is in data_loader.py:
region_scores = self.gaussianTransformer.generate_region(region_scores.shape, character_bboxes)
affinity_scores, affinity_bboxes = \
    self.gaussianTransformer.generate_affinity(region_scores.shape, character_bboxes, words)
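In the CRAFT paper, the region score is built by warping a 2D isotropic Gaussian onto each character box with a perspective transform. The NumPy sketch below simplifies this to axis-aligned boxes with nearest-neighbour resampling; the function names, sigma_ratio=0.25, and taking the maximum where boxes overlap are all assumptions of this sketch, not the reimplementation's API.

```python
import numpy as np


def isotropic_gaussian(size, sigma_ratio=0.25):
    """Square 2D isotropic Gaussian with values peaking near 1.0 at the centre.

    sigma_ratio is an assumed value, not taken from the CRAFT code.
    """
    ax = np.arange(size, dtype=np.float32) - (size - 1) / 2.0
    xx, yy = np.meshgrid(ax, ax)
    sigma = size * sigma_ratio
    return np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))


def add_char_region(score_map, box, gaussian):
    """Stretch the Gaussian over an axis-aligned box (x0, y0, x1, y1).

    Nearest-neighbour resampling into a rectangle stands in for the
    perspective warp onto a (possibly rotated) character quadrilateral.
    """
    x0, y0, x1, y1 = box
    h, w = y1 - y0, x1 - x0
    gs = gaussian.shape[0]
    rows = np.arange(h) * gs // h
    cols = np.arange(w) * gs // w
    patch = gaussian[np.ix_(rows, cols)]
    region = score_map[y0:y1, x0:x1]  # a view into score_map
    np.maximum(region, patch, out=region)  # keep the larger value on overlap
    return score_map


gaussian = isotropic_gaussian(32)
region_scores = np.zeros((40, 60), np.float32)
for box in [(5, 10, 15, 30), (15, 10, 25, 30)]:  # two adjacent characters
    add_char_region(region_scores, box, gaussian)
print(region_scores.max(), region_scores[0, 0])
```

Each character thus gets a blob that is highest at its centre and fades toward its edges, which is what the region score map is trained to predict.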
Using character_bboxes and words, compute the affinity mask, affinity_scores:
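The CRAFT paper builds one affinity box per pair of adjacent characters: each character box is split by its diagonals into upper and lower triangles, and the centroids of those triangles from the two characters form the affinity box, which is then filled with a Gaussian like the region score. The plain-Python sketch below assumes generate_affinity follows that construction; the function names are hypothetical, and the corner mean is used as the box centre.

```python
def centroid(points):
    """Mean of a list of (x, y) points."""
    n = len(points)
    return (sum(p[0] for p in points) / n, sum(p[1] for p in points) / n)


def affinity_box(char_a, char_b):
    """Affinity quadrilateral between two adjacent character boxes.

    Each box is four (x, y) corners ordered top-left, top-right,
    bottom-right, bottom-left.
    """
    def tri_centers(box):
        tl, tr, br, bl = box
        c = centroid(box)  # corner mean: the diagonal intersection for parallelograms
        upper = centroid([tl, tr, c])
        lower = centroid([br, bl, c])
        return upper, lower

    up_a, low_a = tri_centers(char_a)
    up_b, low_b = tri_centers(char_b)
    return [up_a, up_b, low_b, low_a]  # tl, tr, br, bl of the affinity box


def word_affinities(char_boxes):
    """One affinity box per pair of neighbouring characters in a word."""
    return [affinity_box(a, b) for a, b in zip(char_boxes, char_boxes[1:])]


a = [(0, 0), (6, 0), (6, 6), (0, 6)]
b = [(6, 0), (12, 0), (12, 6), (6, 6)]
print(affinity_box(a, b))
# → [(3.0, 1.0), (9.0, 1.0), (9.0, 5.0), (3.0, 5.0)]
```

A word of n characters yields n - 1 affinity boxes, each straddling the gap between two neighbours; this is why generate_affinity needs words in addition to character_bboxes, so that boxes are only linked within the same word.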