# 基于分割的文本检测算法--DBNet
# 1.基本情况
论文:Real-time Scene Text Detection with Differentiable Binarization (opens new window)
代码:DB (opens new window)
2019年11月华中科技大学的Xiang Bai
等提出的方法。
基于分割的文本检测方法对分割结果的概率图进行二值化后处理,然后来提取文本区域,可以检测任意形状的文本区域。但基于分割的文本检测算法一般都需要复杂的后处理,影响推理的性能。
上图中,蓝色的路径表示传统的基于分割的文本检测,完整流程包括得到分割概率图,使用阈值二值化,然后通过像素聚类等手段得到最终的文本检测结果,红色路径是作者提出的新的方法,同时输出分割概率图和进行二值化使用的阈值图,之后使用一个可微分的二值化操作求得二值化的图像,其中虚线表示操作只发生在预测阶段,实线表示在训练和预测阶段都会发生。阅读源码可以发现,与上图中描述不同,训练阶段的二值化结果是通过可微分的二值化操作得到的,预测阶段的二值化结果仍然使用的是固定阈值来计算的。
在这篇论文中,作者主要的创新点就是提出了可微分二值化运算(Differentiable Binarization, DB),DB的引入使得在训练时可以将二值化操作放入模型中,从而实现模型的端到端训练,简化后处理,加快运算速度。
# 2.主要工作
# 2.1 模型架构
从上图中可以看到网络使用了全卷积结构,将多个尺度的特征图使用FPN直接进行融合,经过上采样得到同样大小的特征图进行concatenate
拼接,经过两个分支,一个输出分割概率图,一个输出阈值图,使用这两个结果,输入到DB
运算中得到近似二值图,对二值图处理得到文本区域。
# 2.2 二值化
记backbone
提取的特征图为
标准二值化:给定表示分割结果的概率图
上式中
可微分二值化:从公式可以看出标准二值化是不可微的,因此使用标准二值化在网络的训练中不能直接对其进行优化。作者提出了可微分二值化运算:
上图(a)中表示的是标准二值化
为什么可微分二值化有助于网络的学习呢?
定义
对
上图中(b)和(c)分别表示
# 2.3 自适应阈值
从上面的介绍中可看到自适应阈值图和文本边框有些相似,但自适应阈值图通过有监督或无监督的训练都能获得。
上图中(a)原图,(b)是分割结果的概率图,(c)是无监督得到的阈值图,(d)是有监督训练得到的阈值图
# 2.4 标签生成
文本的阈值图中的边框不应该是文本的一部分,在进行网络训练之前需先将文本标签轮廓进行扩展,扩展的距离通过以下公式计算:
pyclipper.PyclipperOffset()
,具体可参考 (opens new window)。
因为训练时对标签的处理,在推理时还需要将扩展的部分缩减掉。
# 2.5 模型优化
总的损失函数表示为:
其中,
对于概率图和二值化图使用的是二分类交叉熵损失函数:
其中Hard Negative Mining
算法得到的采样样本,正负样本比例为
# 3.源码
第一部分介绍时提到,DB
只应用在训练时,推理时使用的还是常规的固定阈值二值化方法,从模型的forward
函数可以看到:
decoders/seg_detector.py
class SegDetector(nn.Module):
...
def forward(self, features, gt=None, masks=None, training=False):
c2, c3, c4, c5 = features
in5 = self.in5(c5)
in4 = self.in4(c4)
in3 = self.in3(c3)
in2 = self.in2(c2)
out4 = self.up5(in5) + in4 # 1/16
out3 = self.up4(out4) + in3 # 1/8
out2 = self.up3(out3) + in2 # 1/4
p5 = self.out5(in5)
p4 = self.out4(out4)
p3 = self.out3(out3)
p2 = self.out2(out2)
fuse = torch.cat((p5, p4, p3, p2), 1)
# this is the pred module, not binarization module;
# We do not correct the name due to the trained model.
binary = self.binarize(fuse)
if self.training:
result = OrderedDict(binary=binary)
else:
return binary
if self.adaptive and self.training:
if self.serial:
fuse = torch.cat(
(fuse, nn.functional.interpolate(
binary, fuse.shape[2:])), 1)
thresh = self.thresh(fuse)
thresh_binary = self.step_function(binary, thresh)
result.update(thresh=thresh, thresh_binary=thresh_binary)
return result
def step_function(self, x, y):
return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))