DETR 端到端目标检测

DETR(DEtection TRansformer)是FAIR于2020年提出的革命性目标检测框架,首次将Transformer成功应用于目标检测任务,实现了真正的端到端检测。本章详细介绍DETR的设计原理、架构实现和关键创新。

一、DETR的核心创新

1.1 传统目标检测的痛点

传统目标检测方法(如Faster R-CNN、YOLO)存在以下问题:

问题描述影响
依赖锚框需要预设大量锚框超参数多、难以调优
NMS后处理需要手动去除重复框推理速度慢
标签分配复杂一对多匹配训练不稳定
两阶段范式RPN + ROI Head流程复杂

1.2 DETR的解决方案

┌─────────────────────────────────────────────────────────────┐
│                    DETR核心创新                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   传统方法:                                                  │
│   特征图 ──▶ RPN ──▶ ROI Pooling ──▶ 分类+回归 ──▶ NMS      │
│                                                             │
│   DETR:                                                     │
│   特征图 ──▶ Transformer ──▶ 集合预测 ──▶ 直接输出          │
│                         │                                   │
│                    无需NMS!                                 │
│                                                             │
└─────────────────────────────────────────────────────────────┘

1.3 核心思想

DETR将目标检测重新定义为集合预测问题

  1. 端到端训练:直接从图像预测固定数量的目标
  2. 集合预测:一次性输出所有检测结果,无需后处理
  3. Transformer架构:利用全局注意力捕获目标间关系

二、DETR架构

2.1 整体架构

┌─────────────────────────────────────────────────────────────┐
│                    DETR 架构图                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   输入图像 (H × W × 3)                                      │
│         │                                                   │
│         ▼                                                   │
│   ┌─────────────────┐                                      │
│   │  CNN Backbone   │ ResNet-50/101                        │
│   │  (下采样4倍)     │ 输出: H/32 × W/32 × 2048             │
│   └────────┬────────┘                                      │
│            │                                               │
│            ▼                                               │
│   ┌─────────────────┐                                      │
│   │  Input Projection │ 1×1卷积降维                          │
│   │  (d=256)         │ 输出: H/32 × W/32 × 256             │
│   └────────┬────────┘                                      │
│            │                                               │
│            ▼                                               │
│   ┌─────────────────────────────────────────────────────┐  │
│   │              Transformer Encoder                     │  │
│   │   - 6层编码器                                        │  │
│   │   - 多头自注意力                                      │  │
│   │   - 前馈网络                                         │  │
│   └────────┬────────────────────────────────────────────┘  │
│            │                                               │
│            ▼                                               │
│   ┌─────────────────────────────────────────────────────┐  │
│   │              Transformer Decoder                     │  │
│   │   - 6层解码器                                        │  │
│   │   - N个对象查询 (N=100)                               │  │
│   │   - 交叉注意力 (查询←编码器)                           │  │
│   └────────┬────────────────────────────────────────────┘  │
│            │                                               │
│            ▼                                               │
│   ┌─────────────────────────────────────────────────────┐  │
│   │               预测头 (Prediction Heads)               │  │
│   │   - 分类: N × (C+1) 类别概率                          │  │
│   │   - 回归: N × 4 边界框                                │  │
│   └─────────────────────────────────────────────────────┘  │
│            │                                               │
│            ▼                                               │
│   ┌─────────────────────────────────────────────────────┐  │
│   │             集合预测损失 (Set Prediction Loss)        │  │
│   │   - Hungarian Matching                               │  │
│   │   - 分类损失 + L1损失 + GIoU损失                       │  │
│   └─────────────────────────────────────────────────────┘  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2.2 Backbone

class DETRBackbone(nn.Module):
    """CNN Backbone: ResNet-50/101"""
    
    def __init__(self, backbone='resnet50', pretrained=True):
        super().__init__()
        
        if backbone == 'resnet50':
            self.backbone = resnet50(pretrained=pretrained)
            self.num_features = 2048
        else:
            self.backbone = resnet101(pretrained=pretrained)
            self.num_features = 2048
        
    def forward(self, x):
        # ResNet特征提取
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        
        # 输出: [B, 2048, H/32, W/32]
        return x

2.3 Transformer编码器

class TransformerEncoder(nn.Module):
    """Transformer编码器"""
    
    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6,
                 dim_feedforward=2048, dropout=0.1):
        super().__init__()
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu',
            batch_first=True
        )
        
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_encoder_layers
        )
        
    def forward(self, src, mask=None):
        """
        Args:
            src: [B, C, H, W] 特征图
            mask: [B, H*W] 可选掩码
        """
        B, C, H, W = src.shape
        
        # 展平并转置: [B, H*W, C]
        src = src.flatten(2).permute(0, 2, 1)
        
        # 位置编码
        src = src + self.pos_encoding(src.shape)
        
        # 编码
        memory = self.encoder(src, src_key_padding_mask=mask)
        
        return memory  # [B, H*W, d_model]

2.4 Transformer解码器

class TransformerDecoder(nn.Module):
    """Transformer解码器: N个对象查询"""
    
    def __init__(self, d_model=256, nhead=8, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, num_queries=100):
        super().__init__()
        
        self.num_queries = num_queries
        
        # 可学习的对象查询 (Query Embeddings)
        self.query_embed = nn.Embedding(num_queries, d_model)
        self.query_pos = nn.Embedding(num_queries, d_model)
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu',
            batch_first=True
        )
        
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_decoder_layers
        )
        
    def forward(self, memory, target=None):
        """
        Args:
            memory: [B, H*W, d_model] 编码器输出
            target: 可选的decoder输入 (训练时使用)
        """
        B = memory.shape[0]
        
        # 初始化查询
        if target is None:
            # 推理时:从零开始
            tgt = torch.zeros(B, self.num_queries, memory.shape[-1], 
                            device=memory.device)
        else:
            tgt = target
            
        # 查询嵌入
        query = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
        query_pos = self.query_pos.weight.unsqueeze(0).repeat(B, 1, 1)
        
        # 解码
        hs = self.decoder(
            tgt + query_pos,  # 添加位置编码
            memory,
            query_key_padding_mask=None,
            memory_key_padding_mask=None
        )
        
        # 输出: [B, num_queries, d_model]
        return hs

2.5 预测头与边界框预测

class MLP(nn.Module):
    """多层感知机预测头"""
    
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        
    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
 
 
class DETR(nn.Module):
    """完整DETR模型"""
    
    def __init__(self, num_classes=91, num_queries=100, d_model=256):
        super().__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries
        
        # Backbone
        self.backbone = DETRBackbone()
        
        # 输入投影
        self.input_proj = nn.Conv2d(2048, d_model, kernel_size=1)
        
        # 位置编码
        self.pos_encoder = PositionalEncoding(d_model, dropout=0.1)
        
        # Transformer
        self.transformer = Transformer(d_model=d_model)
        
        # 预测头
        self.class_embed = nn.Linear(d_model, num_classes + 1)  # +1 for no object
        self.bbox_embed = MLP(d_model, d_model, 4, num_layers=3)
        
        # 边界框初始化
        self._reset_parameters()
        
    def _reset_parameters(self):
        # 边界框预测初始化为小中心、大的高宽
        nn.init.constant_(self.bbox_embed.layers[-1].weight, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias, 0)
        
    def forward(self, images):
        # Backbone
        features = self.backbone(images)  # [B, 2048, H/32, W/32]
        
        # 投影
        src = self.input_proj(features)  # [B, d_model, H/32, W/32]
        
        # 位置编码
        src = self.pos_encoder(src)
        
        # Transformer
        hs = self.transformer(src)  # [B, num_queries, d_model]
        
        # 预测
        outputs_class = self.class_embed(hs)  # [B, num_queries, num_classes+1]
        outputs_coord = self.bbox_embed(hs).sigmoid()  # [B, num_queries, 4]
        
        # 边界框坐标转换 (cx, cy, w, h) -> (x, y, w, h)
        # DETR输出的是相对于图像尺寸的归一化坐标
        
        return {
            'pred_logits': outputs_class,
            'pred_boxes': outputs_coord
        }

三、集合预测损失

3.1 Hungarian Matching

DETR使用Hungarian算法进行最优匹配:

def hungarian_matching(pred_logits, pred_boxes, target_boxes, target_labels):
    """
    Hungarian Matching: 找到预测与GT的最优匹配
    
    Args:
        pred_logits: [B, num_queries, num_classes+1]
        pred_boxes: [B, num_queries, 4]
        target_boxes: [num_GT, 4]
        target_labels: [num_GT]
    """
    B, num_queries = pred_logits.shape[:2]
    
    # 计算分类损失矩阵
    # [B, num_queries, num_classes+1] -> [B, num_queries, num_classes]
    pred_probs = pred_logits[..., :-1].softmax(-1)  # 去掉no-object类
    
    # 计算bbox损失矩阵
    # 使用L1损失和GIoU损失的加权和
    cost_bbox = torch.cdist(pred_boxes, target_boxes, p=1)
    cost_giou = -box_iou(pred_boxes, target_boxes).log()
    
    # 组合损失
    cost_class = -pred_probs[:, :, target_labels]
    
    # [B, num_queries, num_GT]
    cost_matrix = cost_class + cost_bbox + cost_giou
    
    # Hungarian算法
    indices = []
    for b in range(B):
        indices_b = linear_sum_assignment(cost_matrix[b].cpu().numpy())
        indices.append((torch.tensor(indices_b[0]), torch.tensor(indices_b[1])))
    
    return indices  # [(pred_idx, target_idx), ...]

3.2 损失函数

def set_criterion(pred_logits, pred_boxes, targets):
    """
    DETR损失函数: 分类损失 + 边界框损失
    """
    indices = hungarian_matching(pred_logits, pred_boxes, targets)
    
    idx = _get_src_permutation_idx(indices)
    
    # 分类损失: 交叉熵
    target_classes = torch.full_like(pred_logits, 0)
    target_classes_o = torch.cat([t["labels"] for t in targets])
    target_classes[idx] = target_classes_o
    
    loss_ce = F.cross_entropy(
        pred_logits.transpose(1, 2), 
        target_classes, 
        reduction='mean'
    )
    
    # L1损失
    loss_bbox = F.l1_loss(
        pred_boxes[idx], 
        torch.cat([t['boxes'] for t in targets], dim=0)[idx[1]], 
        reduction='mean'
    )
    
    # GIoU损失
    loss_giou = 1 - torch.diag(box_iou(
        pred_boxes[idx],
        torch.cat([t['boxes'] for t in targets], dim=0)[idx[1]]
    )).mean()
    
    # 总损失
    loss = loss_ce + 5 * loss_bbox + loss_giou
    
    return {
        'loss_ce': loss_ce,
        'loss_bbox': loss_bbox,
        'loss_giou': loss_giou,
        'loss': loss
    }

四、实验结果

4.1 COCO目标检测

方法BackboneAPAP50AP75ParamsFLOPs
Faster R-CNNResNet-5042.062.145.541M180G
Faster R-CNN+ResNet-10144.063.947.860M340G
DETRResNet-5042.062.444.241M86G
DETR-DC5ResNet-5043.363.145.941M187G
DETR-DC5ResNet-10144.964.747.760M360G

4.2 收敛曲线

┌─────────────────────────────────────────────────────────────┐
│              DETR vs Faster R-CNN 收敛对比                  │
│                                                             │
│  AP                                                         │
│  50 ┤     ┌──── Faster R-CNN                               │
│  45 ┤   ┌─┤    ┌──── DETR                                   │
│  40 ┤ ┌─┤ └──┐ └───┐                                       │
│  35 ┤ │ └───┐ └──┐                                         │
│  30 ┤ └─┘                                              │
│     └┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┤
│      0    50   100  150  200  250  300  400  500          │
│                        Epochs                               │
│                                                             │
│  DETR需要更多epoch收敛,但最终性能相当                       │
└─────────────────────────────────────────────────────────────┘

4.3 检测结果可视化

DETR的优势:

  1. 全局上下文:能检测大物体和罕见类别
  2. 无重复检测:不需要NMS
  3. 注意力可视化:可解释性强

五、关键洞察

5.1 集合预测的优势

  1. 避免NMS:直接输出唯一检测结果
  2. 避免锚框:不需要预设锚框
  3. 端到端:训练和推理一致

5.2 训练挑战

挑战解决方案
收敛慢长时间训练(500 epochs)
小物体检测差使用多尺度特征(DC5)
边界框预测使用GIoU损失

5.3 与传统方法的对比

方面Faster R-CNNDETR
流程两阶段端到端
后处理NMS
锚框需要不需要
注意力全局注意力
训练epoch12-36300-500

六、DETR的后续改进

6.1 Deformable DETR

针对DETR收敛慢的问题:

# Deformable DETR核心改进
class DeformableAttention(nn.Module):
    """可变形注意力:只关注参考点周围的采样点"""
    
    def __init__(self, d_model=256, n_heads=8, n_levels=4, n_points=4):
        super().__init__()
        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        
    def forward(self, query, reference_points, input_flatten):
        # 计算采样偏移
        offsets = self.sampling_offsets(query)
        
        # 在参考点周围采样特定数量的点
        # 大幅减少计算量,加速收敛

6.2 其他改进方向

方法改进点
Conditional DETR改进注意力解耦,加速收敛
TSP-RCNN结合传统R-CNN
UP-DETR无监督预训练
DAB-DETR使用Anchor Boxes改进查询

七、参考论文