DETR 端到端目标检测
DETR(DEtection TRansformer)是FAIR于2020年提出的革命性目标检测框架,首次将Transformer成功应用于目标检测任务,实现了真正的端到端检测。本章详细介绍DETR的设计原理、架构实现和关键创新。
一、DETR的核心创新
1.1 传统目标检测的痛点
传统目标检测方法(如Faster R-CNN、YOLO)存在以下问题:
| 问题 | 描述 | 影响 |
|---|---|---|
| 依赖锚框 | 需要预设大量锚框 | 超参数多、难以调优 |
| NMS后处理 | 需要手动去除重复框 | 推理速度慢 |
| 标签分配复杂 | 一对多匹配 | 训练不稳定 |
| 两阶段范式 | RPN + ROI Head | 流程复杂 |
1.2 DETR的解决方案
┌─────────────────────────────────────────────────────────────┐
│ DETR核心创新 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 传统方法: │
│ 特征图 ──▶ RPN ──▶ ROI Pooling ──▶ 分类+回归 ──▶ NMS │
│ │
│ DETR: │
│ 特征图 ──▶ Transformer ──▶ 集合预测 ──▶ 直接输出 │
│ │ │
│ 无需NMS! │
│ │
└─────────────────────────────────────────────────────────────┘
1.3 核心思想
DETR将目标检测重新定义为集合预测问题:
- 端到端训练:直接从图像预测固定数量的目标
- 集合预测:一次性输出所有检测结果,无需后处理
- Transformer架构:利用全局注意力捕获目标间关系
二、DETR架构
2.1 整体架构
┌─────────────────────────────────────────────────────────────┐
│ DETR 架构图 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 输入图像 (H × W × 3) │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ CNN Backbone │ ResNet-50/101 │
│ │ (下采样4倍) │ 输出: H/32 × W/32 × 2048 │
│ └────────┬────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Input Projection │ 1×1卷积降维 │
│ │ (d=256) │ 输出: H/32 × W/32 × 256 │
│ └────────┬────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Transformer Encoder │ │
│ │ - 6层编码器 │ │
│ │ - 多头自注意力 │ │
│ │ - 前馈网络 │ │
│ └────────┬────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Transformer Decoder │ │
│ │ - 6层解码器 │ │
│ │ - N个对象查询 (N=100) │ │
│ │ - 交叉注意力 (查询←编码器) │ │
│ └────────┬────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 预测头 (Prediction Heads) │ │
│ │ - 分类: N × (C+1) 类别概率 │ │
│ │ - 回归: N × 4 边界框 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 集合预测损失 (Set Prediction Loss) │ │
│ │ - Hungarian Matching │ │
│ │ - 分类损失 + L1损失 + GIoU损失 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
2.2 Backbone
class DETRBackbone(nn.Module):
"""CNN Backbone: ResNet-50/101"""
def __init__(self, backbone='resnet50', pretrained=True):
super().__init__()
if backbone == 'resnet50':
self.backbone = resnet50(pretrained=pretrained)
self.num_features = 2048
else:
self.backbone = resnet101(pretrained=pretrained)
self.num_features = 2048
def forward(self, x):
# ResNet特征提取
x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
x = self.backbone.layer2(x)
x = self.backbone.layer3(x)
x = self.backbone.layer4(x)
# 输出: [B, 2048, H/32, W/32]
return x2.3 Transformer编码器
class TransformerEncoder(nn.Module):
"""Transformer编码器"""
def __init__(self, d_model=256, nhead=8, num_encoder_layers=6,
dim_feedforward=2048, dropout=0.1):
super().__init__()
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation='relu',
batch_first=True
)
self.encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=num_encoder_layers
)
def forward(self, src, mask=None):
"""
Args:
src: [B, C, H, W] 特征图
mask: [B, H*W] 可选掩码
"""
B, C, H, W = src.shape
# 展平并转置: [B, H*W, C]
src = src.flatten(2).permute(0, 2, 1)
# 位置编码
src = src + self.pos_encoding(src.shape)
# 编码
memory = self.encoder(src, src_key_padding_mask=mask)
return memory # [B, H*W, d_model]2.4 Transformer解码器
class TransformerDecoder(nn.Module):
"""Transformer解码器: N个对象查询"""
def __init__(self, d_model=256, nhead=8, num_decoder_layers=6,
dim_feedforward=2048, dropout=0.1, num_queries=100):
super().__init__()
self.num_queries = num_queries
# 可学习的对象查询 (Query Embeddings)
self.query_embed = nn.Embedding(num_queries, d_model)
self.query_pos = nn.Embedding(num_queries, d_model)
decoder_layer = nn.TransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation='relu',
batch_first=True
)
self.decoder = nn.TransformerDecoder(
decoder_layer,
num_layers=num_decoder_layers
)
def forward(self, memory, target=None):
"""
Args:
memory: [B, H*W, d_model] 编码器输出
target: 可选的decoder输入 (训练时使用)
"""
B = memory.shape[0]
# 初始化查询
if target is None:
# 推理时:从零开始
tgt = torch.zeros(B, self.num_queries, memory.shape[-1],
device=memory.device)
else:
tgt = target
# 查询嵌入
query = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
query_pos = self.query_pos.weight.unsqueeze(0).repeat(B, 1, 1)
# 解码
hs = self.decoder(
tgt + query_pos, # 添加位置编码
memory,
query_key_padding_mask=None,
memory_key_padding_mask=None
)
# 输出: [B, num_queries, d_model]
return hs2.5 预测头与边界框预测
class MLP(nn.Module):
"""多层感知机预测头"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
class DETR(nn.Module):
"""完整DETR模型"""
def __init__(self, num_classes=91, num_queries=100, d_model=256):
super().__init__()
self.num_classes = num_classes
self.num_queries = num_queries
# Backbone
self.backbone = DETRBackbone()
# 输入投影
self.input_proj = nn.Conv2d(2048, d_model, kernel_size=1)
# 位置编码
self.pos_encoder = PositionalEncoding(d_model, dropout=0.1)
# Transformer
self.transformer = Transformer(d_model=d_model)
# 预测头
self.class_embed = nn.Linear(d_model, num_classes + 1) # +1 for no object
self.bbox_embed = MLP(d_model, d_model, 4, num_layers=3)
# 边界框初始化
self._reset_parameters()
def _reset_parameters(self):
# 边界框预测初始化为小中心、大的高宽
nn.init.constant_(self.bbox_embed.layers[-1].weight, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias, 0)
def forward(self, images):
# Backbone
features = self.backbone(images) # [B, 2048, H/32, W/32]
# 投影
src = self.input_proj(features) # [B, d_model, H/32, W/32]
# 位置编码
src = self.pos_encoder(src)
# Transformer
hs = self.transformer(src) # [B, num_queries, d_model]
# 预测
outputs_class = self.class_embed(hs) # [B, num_queries, num_classes+1]
outputs_coord = self.bbox_embed(hs).sigmoid() # [B, num_queries, 4]
# 边界框坐标转换 (cx, cy, w, h) -> (x, y, w, h)
# DETR输出的是相对于图像尺寸的归一化坐标
return {
'pred_logits': outputs_class,
'pred_boxes': outputs_coord
}三、集合预测损失
3.1 Hungarian Matching
DETR使用Hungarian算法进行最优匹配:
def hungarian_matching(pred_logits, pred_boxes, target_boxes, target_labels):
"""
Hungarian Matching: 找到预测与GT的最优匹配
Args:
pred_logits: [B, num_queries, num_classes+1]
pred_boxes: [B, num_queries, 4]
target_boxes: [num_GT, 4]
target_labels: [num_GT]
"""
B, num_queries = pred_logits.shape[:2]
# 计算分类损失矩阵
# [B, num_queries, num_classes+1] -> [B, num_queries, num_classes]
pred_probs = pred_logits[..., :-1].softmax(-1) # 去掉no-object类
# 计算bbox损失矩阵
# 使用L1损失和GIoU损失的加权和
cost_bbox = torch.cdist(pred_boxes, target_boxes, p=1)
cost_giou = -box_iou(pred_boxes, target_boxes).log()
# 组合损失
cost_class = -pred_probs[:, :, target_labels]
# [B, num_queries, num_GT]
cost_matrix = cost_class + cost_bbox + cost_giou
# Hungarian算法
indices = []
for b in range(B):
indices_b = linear_sum_assignment(cost_matrix[b].cpu().numpy())
indices.append((torch.tensor(indices_b[0]), torch.tensor(indices_b[1])))
return indices # [(pred_idx, target_idx), ...]3.2 损失函数
def set_criterion(pred_logits, pred_boxes, targets):
"""
DETR损失函数: 分类损失 + 边界框损失
"""
indices = hungarian_matching(pred_logits, pred_boxes, targets)
idx = _get_src_permutation_idx(indices)
# 分类损失: 交叉熵
target_classes = torch.full_like(pred_logits, 0)
target_classes_o = torch.cat([t["labels"] for t in targets])
target_classes[idx] = target_classes_o
loss_ce = F.cross_entropy(
pred_logits.transpose(1, 2),
target_classes,
reduction='mean'
)
# L1损失
loss_bbox = F.l1_loss(
pred_boxes[idx],
torch.cat([t['boxes'] for t in targets], dim=0)[idx[1]],
reduction='mean'
)
# GIoU损失
loss_giou = 1 - torch.diag(box_iou(
pred_boxes[idx],
torch.cat([t['boxes'] for t in targets], dim=0)[idx[1]]
)).mean()
# 总损失
loss = loss_ce + 5 * loss_bbox + loss_giou
return {
'loss_ce': loss_ce,
'loss_bbox': loss_bbox,
'loss_giou': loss_giou,
'loss': loss
}四、实验结果
4.1 COCO目标检测
| 方法 | Backbone | AP | AP50 | AP75 | Params | FLOPs |
|---|---|---|---|---|---|---|
| Faster R-CNN | ResNet-50 | 42.0 | 62.1 | 45.5 | 41M | 180G |
| Faster R-CNN+ | ResNet-101 | 44.0 | 63.9 | 47.8 | 60M | 340G |
| DETR | ResNet-50 | 42.0 | 62.4 | 44.2 | 41M | 86G |
| DETR-DC5 | ResNet-50 | 43.3 | 63.1 | 45.9 | 41M | 187G |
| DETR-DC5 | ResNet-101 | 44.9 | 64.7 | 47.7 | 60M | 360G |
4.2 收敛曲线
┌─────────────────────────────────────────────────────────────┐
│ DETR vs Faster R-CNN 收敛对比 │
│ │
│ AP │
│ 50 ┤ ┌──── Faster R-CNN │
│ 45 ┤ ┌─┤ ┌──── DETR │
│ 40 ┤ ┌─┤ └──┐ └───┐ │
│ 35 ┤ │ └───┐ └──┐ │
│ 30 ┤ └─┘ │
│ └┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┤
│ 0 50 100 150 200 250 300 400 500 │
│ Epochs │
│ │
│ DETR需要更多epoch收敛,但最终性能相当 │
└─────────────────────────────────────────────────────────────┘
4.3 检测结果可视化
DETR的优势:
- 全局上下文:能检测大物体和罕见类别
- 无重复检测:不需要NMS
- 注意力可视化:可解释性强
五、关键洞察
5.1 集合预测的优势
- 避免NMS:直接输出唯一检测结果
- 避免锚框:不需要预设锚框
- 端到端:训练和推理一致
5.2 训练挑战
| 挑战 | 解决方案 |
|---|---|
| 收敛慢 | 长时间训练(500 epochs) |
| 小物体检测差 | 使用多尺度特征(DC5) |
| 边界框预测 | 使用GIoU损失 |
5.3 与传统方法的对比
| 方面 | Faster R-CNN | DETR |
|---|---|---|
| 流程 | 两阶段 | 端到端 |
| 后处理 | NMS | 无 |
| 锚框 | 需要 | 不需要 |
| 注意力 | 无 | 全局注意力 |
| 训练epoch | 12-36 | 300-500 |
六、DETR的后续改进
6.1 Deformable DETR
针对DETR收敛慢的问题:
# Deformable DETR核心改进
class DeformableAttention(nn.Module):
"""可变形注意力:只关注参考点周围的采样点"""
def __init__(self, d_model=256, n_heads=8, n_levels=4, n_points=4):
super().__init__()
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
def forward(self, query, reference_points, input_flatten):
# 计算采样偏移
offsets = self.sampling_offsets(query)
# 在参考点周围采样特定数量的点
# 大幅减少计算量,加速收敛6.2 其他改进方向
| 方法 | 改进点 |
|---|---|
| Conditional DETR | 改进注意力解耦,加速收敛 |
| TSP-RCNN | 结合传统R-CNN |
| UP-DETR | 无监督预训练 |
| DAB-DETR | 使用Anchor Boxes改进查询 |