# 端到端目标检测框架DETR

# 背景介绍

DETRFacebook AINicolas Carion等于202005月提交的论文中提出的。

论文地址: https://arxiv.org/abs/2005.12872 (opens new window)
开源代码: https://github.com/facebookresearch/detr (opens new window)

DETR(DEtection TRansformer)将目标检测问题看成是集合预测的问题,所谓集合预测set prediction是指一次输出一张图像中的所有待检测对象。

DETR使用transformer来做目标检测,直接预测检测框到检测框中心点归一化的距离。在模型训练时,Proposal Assignment使用的算法是一对一的匈牙利算法,通过query的方式获取最后的输出。以上介绍的策略,使得DETR实现了目标检测算法的端到端训练,不需要使用NMS和先验anchor

# 模型结构

从上面这个图可以看到DETR的架构相当简单,输入一张图像,直接输出的就是所有的检测框,不需要复杂的编解码,不需要NMS

# 模块解析

# 数据

官方源码中数据定义在CocoDetection类中,这个类继承自torchvision.datasets.CocoDetection只需要传入COCO格式数据集的图像和json标注文件即可,

COCO格式数据集文件夹路径:

.
├── annotations
│   ├── train.json
│   └── val.json
└── images
    ├── train
    └── val

其中,标签文件bounding box的格式为:

left top width height

CoCoDetection类中有一个self.prepare属性,这是一个函数,其中会将ltwh格式的检测框变换成x1y1x2y2格式的检测框。

DETR源码中使用的变换函数不是从torchvision中导入的,而是自定义的,可以看到在Normalize中,不仅处理了图像数据,还将检测框从x1y1x2y2格式变换成了cxcywh格式,并相对于图像的宽高进行了归一化,其值变换到了[0,1]

class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes
        return image, target

# 模型结构

DETR的模型结构其实很简单,先是将图像输入到几层卷积神经网络中得到特征图feature map,然后使用src = src.flatten(2).permute(2, 0, 1)将特征图WH维度拉平将图像变换成长度为L=W*H的序列数据。

根据序列的长度和每个Token的通道数生成位置编码。

feature map生成的序列和位置编码信息相加作为transformer的输入src

除了输入的特征序列之外,还输入了图像数据的掩码src_mask。原因是因为一个batch输入的图像宽高不一定相同,源码中的处理方式是取一个batch中尺寸最大的图像尺寸,其余图像往右下方向补0,最后变成尺寸一致的图像用于计算。这是为了避免padding-0参与计算,需要将src_mask输入到transformer中。

DETR使用的位置编码是针对图像的带mask的二维位置编码

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

其在x/y单个方向上使用位置编码的方法同标准的transformer,然后再将x,y上的两个位置编码分别进行了合并。

DETR源码中使用的transformertorch.nn.Transformer也不太一样

DETRtransformer中将位置编码信息输入到编码器和解码器的每一层,在encoder中将pos加在输入的feature上组成qk

class Encoder:
    ...
    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

decoder中将pos加在了encoder的输出memory作为k的值,query_postgt相加的值作为q来计算多头注意力:

class Decoder:
    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

DETR中实现的transformer中还将每层decoder输出都保存下来以计算检测框,用来辅助训练


class DETRTransformerDecoder():

    ...
    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)   

transformer输出的特征输入到计算评分和检测框的两支多层感知积网络中就能预测检测框了:

class DETR:
    ...
    def forward(self, x):
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] # shape: [BATCH, NUM_QUERY, D_MODEL]

        outputs_class = self.class_embed(x)
        outputs_coord = self.box_embed(x).sigmoid()

以上就是模型的整体结构。

模型输出的num_query个预测框和真值框之间的匹配通过匈牙利算法来实现。匈牙利算法会实现预测框和真值框的一对一匹配,避免了对同个对象生成重复的检测框。在使用anchor的检测算法中,为了减轻候选框中正样本和负样本不平衡的问题,通常会使用多个proposal box来预测一个对象,以提升算法的召回率,代价是预测推理时也会对一个对象生成多个预测框,需要使用NMS算法进行处理。

标签匹配使用的代价包括三部分,分别是分类代价,检测框回归相关的L1距离和GIoU

import torch

class HungarianMatcher(torch.nn.Module):
    ...
    @torch.no_grad()
    def forward(self, outputs, targets):
        ...
        cost_class = -out_prob[:, tgt_ids]
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        cost_giou = -giou(cxcywh2x1y1x2y2(out_bbox), cxcywh2x1y1x2y2(tgt_bbox))
        all_cost = self.cost_class * cost_class + \
                    self.cost_bbox * cost_bbox + \
                self.cost_giou * cost_giou

最后是模型训练时使用的损失函数,对于目标检测任务,DETRLoss包含2部分,分别是标签类别损失和检测框回归的L1损失和GIoU损失。


loss_ce = torch.nn.functional.cross_entropy(pred_logits.transpose(1, 2),target_classes_all, self.empty_weight)

loss_bbox = torch.nn.functional.l1_loss(src_boxes, target_boxes, reduction='none')
losses = {}

losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(giou(cxcywh2x1y1x2y2(src_boxes),
                            cxcywh2x1y1x2y2(target_boxes)))
losses['loss_giou'] = loss_giou.sum() / num_boxes

# 动手实现DETR

DETR的架构如此简洁,不需要太多的trick,参考DETR源码,很容易自己动手实现DETR目标检测算法。具体的实现见:

https://gitee.com/lx_r/object_detection_task/tree/main/detection/detr (opens new window)

运行程序会自动生成训练数据开始训练,若平台有GPU会自动调用GPU训练,如果没有GPU会使用CPU训练。

上面的实现中,与原始代码有些许不同:

  • 1)使用的是torch.nn中的transformerpos没有加到encoder的输出memory
  • 2)torch.nn中的transformer只给出了最后一层decoder上的输出,没有给出其他层decoder上的输出,所有没有使用辅助损失训练
  • 3)输入的是相同尺寸的方形图像,没有使用输入掩码
(adsbygoogle = window.adsbygoogle || []).push({});