# FCOS论文及源码解读
# 1.介绍
论文:《FCOS: Fully Convolutional One-Stage Object Detection》
是澳洲阿德莱德大学的Zhi Tian等最早于2019年04月提交的工作成果,发表在ICCV上。
FCOS
是全卷积实现的Anchor Free
的一阶目标检测器,避免了训练过程中Anchor
相关的计算,减少的训练时的计算量和内存占用,移除了anchor
相关的一系列超参数。
Anchor Based
方法的缺点:
- 检测性能对
anchor
的size
/aspect ratio
/数量
比较敏感。 - 实际对象的检测框大小分布较广泛,
anchor
不一定能覆盖 - 为了得到高召回率,
anchor based
的方法返回了非常多的anchor box
,如FPN
中,输入短边为800
的图像将总共生成大于180K
个Anchor Box
。超级多的anchor box
除了影响性能外,还导致了严重的类别不平衡问题,因为180K
个anchor box
中有大量的都是不包含对象的negative box
neat fully convolution pixel prediction framework
因anchor box
的存在不适用于object detection
。先前的预测方法在目标检测框重叠时检测效果不好,如DenseBox
。FCOS
在距离目标中心较远的位置产生很多低质量的预测边框,这会影响检测的效果,为了克服这个问题,引入了中心度
的概念,衡量预测框到距目标框的距离。FCOS
添加单层分支,与分类分支并行,以预测"Center-ness"
位置。
FCOS
的优点:
- 全
FCN
结构,与pixel prediction
任务的网络统一,更便于复用语义分割中的tricks
anchor free
框架,减少模型的设计参数。anchor free
框架,减少训练中的IoU计算和box match
.FCOS
可用于two-stage
检测网络中的RPN
- 易于拓展应用于其他视觉任务,如
Instance Segmentation
等
Anchor Free
目标检测器有YoloV1
,CornerNet
。
# 2.FCOS中使用的方法
# 2.1 网络结构
将目标检测任务形式化为pixel prediction
任务,使用多级预测提升召回率,解决重叠目标的模糊问题。
特征图上的点(x,y)
可以重新映射到输入原图上:
FCOS
直接将(x, y)
对应的检测框位置当作训练样本,具体是指,当(x,y)
映射到原图上落入任何ground-truth box
中,这个点就当成正样本,否则就是负样本。检测定位回归的变量是(x, y)
距检测框左/上/右/下四条边的距离。
当feature map
上的某个点(x,y)
同时落入多个bounding box
中时,这个点(x,y)
被当成ambiguous sample
, FPN
的multi level
机制可以用来解决这个问题。很显然,通过上述介绍可以知道,FCOS
中可能有多个feature map
上的点(x, y)
落入到同个ground truth box
中,故可以生成很多positive boxes
用于训练。
网络输出
训练了C
个二分类器,而非1
个多分类器,可以实现多标签预测。feature map
后接有4个卷积层的分类分支
和位置回归分支
,回归分支使用了exp(x)
,将回归预测变量x
变换到了(0, +\infinity)
上,输出的变量比anchor based
方法少9
倍。
损失函数
分类使用的是Focal Loss
解决类别不平衡问题,回归使用的是IoU
损失。
# 2.2FCOS中使用FPN
的多级预测
FCOS
中的两个问题:
- 1) 到最终输出的
feature map
使用大stride
如x16
会导致低BPR(best possible recall)
。Anchor Based
方法可以对positive box
使用低的IoU
阈值来补偿这个问题,而对于FCOS
,直观的一个猜想是,对于小物体,stride
过大时,可能导致feature map
上并没有一个点可以与小目标的中心对应,故BPR
应该会比较低。 - 2) 重叠的目标框引入了棘手的模糊性问题,
feature map
中的一个点(x, y)
如何确定应该用来回归重叠框中的哪一个呢?
使用FPN
的多级预测解决FCOS
中存在的这两个问题。
同FPN
,使用不同层级中的feature map
来预测不同大小的目标。限制不同level feature map
回归距离level i
的feature map
上每个点回归的最大距离,因此若level i
的feature map
上某个点回归的距离negative box
对于
前面介绍的因预测的是距离,故exp
函数,现在不同level
的回归距离范围是不同的,再使用相同的head
就不合理了,增加一个参数level
使用不同的
# 2.3FCOS中的中心度
可能有多个feature map
中的点对应同个物体检测框,而距离中心较远的feature map
点会导致引入很多低质量的物体检测框,影响检测的效果。
FCOS
引入了中心度centerness
来描述一个点距离目标框中心的远近,以过滤掉偏离中心点的低质量检测框。中心度是通过一个与分类分支并行的单层分支来预测的。对于某个位置的回归目标
sqrt
运算可以减缓中心度的衰减速度。centerness
取值范围0-1
,训练使用二分类交叉熵BCE
,加到前述的损失函数中一起训练,预测时,将centerness
与分类评分score
相乘后作为最后检测框的评分,再进行NMS
,因此很多偏离中心的框就能被过滤掉了。
# 3.mmdetection中FCOS
源码
FCOS
网络结构的定义如上图中二所示,定义的文件在mmdetection
工程``文件中。
看一下FCOS Head
(opens new window),从计算loss
是使用的self.get_targets
方法开始。
graph TD
A(feat_points) --> C(bbox_targets)
B(gt_boxes) --> C
D(center_sample_radius) --> e(inside_gt_bbox_mask)
A-->e
B-->e
f(regress_ranges)
C-->g
f-->g(inside_regress_range)
B-->h(areas)
g-->h
e-->h
h--min-->i(min_area_inds)
j(labels)
i --> k(labels)
j -->k
C-->L(bbox_targets)
i -->L
def _get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges,
num_points_per_lvl):
"""Compute regression and classification targets for a single image."""
num_points = points.size(0)
num_gts = gt_labels.size(0)
if num_gts == 0:
return gt_labels.new_full((num_points,), self.num_classes), \
gt_bboxes.new_zeros((num_points, 4))
areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
gt_bboxes[:, 3] - gt_bboxes[:, 1])
# TODO: figure out why these two are different
# areas = areas[None].expand(num_points, num_gts)
areas = areas[None].repeat(num_points, 1)
regress_ranges = regress_ranges[:, None, :].expand(
num_points, num_gts, 2)
gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
xs, ys = points[:, 0], points[:, 1]
xs = xs[:, None].expand(num_points, num_gts)
ys = ys[:, None].expand(num_points, num_gts)
left = xs - gt_bboxes[..., 0]
right = gt_bboxes[..., 2] - xs
top = ys - gt_bboxes[..., 1]
bottom = gt_bboxes[..., 3] - ys
bbox_targets = torch.stack((left, top, right, bottom), -1)
if self.center_sampling:
# condition1: inside a `center bbox`
radius = self.center_sample_radius
center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
center_gts = torch.zeros_like(gt_bboxes)
stride = center_xs.new_zeros(center_xs.shape)
# project the points on current lvl back to the `original` sizes
lvl_begin = 0
for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
lvl_end = lvl_begin + num_points_lvl
stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
lvl_begin = lvl_end
x_mins = center_xs - stride
y_mins = center_ys - stride
x_maxs = center_xs + stride
y_maxs = center_ys + stride
center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
x_mins, gt_bboxes[..., 0])
center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
y_mins, gt_bboxes[..., 1])
center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
gt_bboxes[..., 2], x_maxs)
center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
gt_bboxes[..., 3], y_maxs)
cb_dist_left = xs - center_gts[..., 0]
cb_dist_right = center_gts[..., 2] - xs
cb_dist_top = ys - center_gts[..., 1]
cb_dist_bottom = center_gts[..., 3] - ys
center_bbox = torch.stack(
(cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
else:
# condition1: inside a gt bbox
inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
# condition2: limit the regression range for each location
max_regress_distance = bbox_targets.max(-1)[0]
inside_regress_range = (
(max_regress_distance >= regress_ranges[..., 0])
& (max_regress_distance <= regress_ranges[..., 1]))
# if there are still more than one objects for a location,
# we choose the one with minimal area
areas[inside_gt_bbox_mask == 0] = INF
areas[inside_regress_range == 0] = INF
min_area, min_area_inds = areas.min(dim=1)
labels = gt_labels[min_area_inds]
labels[min_area == INF] = self.num_classes # set as BG
bbox_targets = bbox_targets[range(num_points), min_area_inds]
return labels, bbox_targets
中心度的计算:
def centerness_target(self, pos_bbox_targets):
"""Compute centerness targets.
Args:
pos_bbox_targets (Tensor): BBox targets of positive bboxes in shape
(num_pos, 4)
Returns:
Tensor: Centerness target.
"""
# only calculate pos centerness targets, otherwise there may be nan
left_right = pos_bbox_targets[:, [0, 2]]
top_bottom = pos_bbox_targets[:, [1, 3]]
if len(left_right) == 0:
centerness_targets = left_right[..., 0]
else:
centerness_targets = (
left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
return torch.sqrt(centerness_targets)
Focal Loss的定义:
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight