Paper | End-to-End Object Detection with Transformers
TL;DR
本文介绍了一种基于 Transformer 的端到端目标检测方法DETR。与传统的方法不同,DETR 直接预测对象的边界框和类别标签,避免了繁琐的后处理步骤。文章详细介绍了 DETR 的损失函数和模型结构,并通过实验证明了 DETR 在 COCO 数据集上与 Fastee R-CNN 相当甚至更好的性能。此外,DETR 的设计还可以扩展到其他复杂任务,如全景分割
1. Introduction
object detection 的目的是在图像中预测每个感兴趣对象的边界框和类别标签。现有的绝大部分方法通过使用大量的 proposals、anchors 或 window centers 来对问题进行定义,并间接地解决预测任务。这些方法严重依赖后处理步骤,比如非极大值抑制之类的方法。本文的方法是一个端到端的方法,通过直接预测来绕过很多繁琐的步骤
文章通过将 object dection 视为直接集合预测问题来简化训练流程,并采用基于 Transformer 的编/解码器架构进行预测。由于 Transformer 的自注意力机制明确地模拟了序列中元素之间的所有成对交互,使得这些架构特别适合集合预测的特定约束,比如删除重复的预测
上图是 DETR 一次前向过程的简化。DETR 通过 CNN 与 Transformer 架构相结合来并行预测最终的检测集合。在训练过程中,使用二分匹配唯一地将预测与 ground truth 框分配。没有匹配的预测应该产生“无对象”(
文章还在 COCO 上和 Faster R-CNN 进行对比实验,并发现 DETR 可以与之达到相当的结果,尤其是在大型物体上表现出更好的性能,这可能是由于 Transformer 的非局部性引起的。不过,DETR 在小物体上的性能较低,这也是后续可改进的方向
DETR 的设计可以轻松扩展到更复杂的任务。比如,在预训练的 DETR 之上训练的简单分割头优于全景分割(Panoptic Segmentation)上的 baseline 方法
2. Related work
3. The DETR model
3.1 Object detection set prediction loss
文章首先介绍的不是模型的结构,而是损失函数,可见这个损失函数对于整个算法的作用。先前的 object detection 方法会对同一个对象生成多个不同的候选框,需要使用非极大值抑制1选择最优的候选框,这样的后处理不但提高了算法的复杂度,还增加了模型调优的难度。而 DETR 是一个端到端方法,一次性就能得到正确的候选框和分类结果,这也要归功于 DETR 的损失函数设计
首先,DETR 的结果是由
这里
上式中
Bounding box loss:文章直接使用 box prediction。虽然这样简化了实现,但是产生缩放损失的问题,为了缓解这个现象,还要使用 IoU 损失。所以,最终的 box loss为:
3.2 DETR architecture
DETR 的结构简单直接,其实就是直接将 Transformer 编/解码器集成进来。先用一个 CNN 抽取特征(那个时候 ViT 还没出来),将特征展平并送入 Transformer Encoder-Decoder 与 object queries 进行交叉注意力操作,最后通过 FFN 就可得到最终预测的类别和物体框。注意在 decoder 中,所有
Auxiliary decoding losses:在训练过程中在解码器中使用辅助损失是有帮助的,特别是为了帮助模型输出每个类别的正确对象数量。作者在每个解码器层之后添加预测 FFN(Feed-Forward Network)和匈牙利损失。所有预测 FFN 共享参数。使用额外的共享层归一化器来将来自不同解码器层的预测 FFN 的输入进行归一化。文章在附录还贴心的给出了 DETR 的 Pytorch 一次前向的实现,虽然与最终的实现相比简化了很多,但仍可一窥 DETR 的简洁
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
super().__init__()
# We take only convolutional layers from ResNet-50 model
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
self.transformer = nn.Transformer(hidden_dim, nheads,
num_encoder_layers, num_decoder_layers)
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
self.linear_bbox = nn.Linear(hidden_dim, 4)
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, inputs):
x = self.backbone(inputs)
h = self.conv(x)
H, W = h.shape[-2:]
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
self.query_pos.unsqueeze(1))
return self.linear_class(h), self.linear_bbox(h).sigmoid()
detr = DETR(num_classes=91, hidden_dim=256, nheads=8,
num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)
4. Experiments
4.1 Comparison with Faster R-CNN
Training the baseline model for 300 epochs on 16 V100 GPUs takes 3 days, with 4 images per GPU (hence a total batch size of 64).