MobileViT: A Lightweight Vision Transformer

MobileViT is a lightweight vision-Transformer architecture proposed by Apple. It aims to combine the expressive power of Transformers with the efficiency of MobileNet, producing vision models that run well on mobile devices. This chapter covers MobileViT's design rationale, core modules, and experimental results.

1. Design Motivation

1.1 Comparing ViT and MobileNet

| Property | Vision Transformer (ViT) | MobileNet |
| --- | --- | --- |
| Core operation | Self-attention | Depthwise separable convolution |
| Local modeling | Weak (needs more data) | Strong (convolution is inherently local) |
| Global modeling | Strong (attention mechanism) | Weak (requires stacking more layers) |
| Parameter count | Larger | Smaller |
| Compute cost | Higher | Lower |
| Mobile-friendliness | | |
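The parameter gap in the table can be checked directly: a standard k×k convolution has C_in·C_out·k² weights, while a depthwise-separable replacement has C_in·k² + C_in·C_out. A quick check with PyTorch (the channel sizes below are arbitrary examples):

```python
import torch.nn as nn

c_in, c_out, k = 64, 128, 3

# Standard convolution: one C_out x C_in x k x k weight tensor
standard = nn.Conv2d(c_in, c_out, k, padding=1, bias=False)

# Depthwise-separable: per-channel k x k filters, then a 1x1 pointwise mix
separable = nn.Sequential(
    nn.Conv2d(c_in, c_in, k, padding=1, groups=c_in, bias=False),  # depthwise
    nn.Conv2d(c_in, c_out, 1, bias=False),                         # pointwise
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard))   # 64*128*9 = 73728
print(count(separable))  # 64*9 + 64*128 = 8768
```

The separable version uses roughly 8× fewer parameters here, which is the source of MobileNet's efficiency advantage.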

1.2 The Core Idea of MobileViT

┌─────────────────────────────────────────────────────────────┐
│                 MobileViT Design Philosophy                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  MobileNet: strong local modeling   ViT: strong global      │
│              │                           modeling           │
│              ▼                              │               │
│   Depthwise separable conv                  ▼               │
│              │                      Transformer block       │
│              └──────────────┬───────────────┘               │
│                             ▼                               │
│                      MobileViT block                        │
│                             │                               │
│                             ▼                               │
│           Local + global representation learning            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. Core Modules of MobileViT

2.1 MobileViT Block

The MobileViT block is the model's core component, combining the strengths of convolutions and Transformers:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MobileViTBlock(nn.Module):
    """
    MobileViT block = local representation (convolution)
                    + global representation (Transformer)

    Input:  [B, C, H, W]
    Output: [B, C, H, W]
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 patch_size=2, d_model=96, transformer_depth=3,
                 num_heads=4, mlp_ratio=4.0):
        super().__init__()
        self.patch_size = patch_size
        self.d_model = d_model

        # 1. Local representation: n×n depthwise convolution, then a
        #    pointwise convolution that expands the channels to d_model
        self.local_rep = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size,
                      padding=kernel_size // 2, groups=in_channels),  # depthwise
            nn.Conv2d(in_channels, d_model, 1),  # pointwise
        )

        # 2. Global representation: a Transformer over patch sequences.
        #    Patch extraction uses F.unfold/F.fold in forward(), since
        #    H and W are not known at construction time.
        self.transformer = TransformerEncoder(  # assumed defined elsewhere
            d_model=d_model,
            depth=transformer_depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio
        )

        # 3. Fusion: 1×1 convolutions map d_model back to the output channels
        self.fusion = nn.Sequential(
            nn.Conv2d(d_model, in_channels, 1),  # pointwise
            nn.Conv2d(in_channels, out_channels, 1) if out_channels != in_channels else nn.Identity(),
        )

        # Residual connection (only when input and output channels match)
        self.use_residual = in_channels == out_channels

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        d = self.d_model

        # Step 1: local representation
        h_local = self.local_rep(x)  # [B, d_model, H, W]

        # Step 2: convert to patch sequences.
        # F.unfold flattens each p×p patch into a column:
        # [B, d_model * p^2, num_patches], dim 1 ordered (channel, row, col)
        h_unfold = F.unfold(h_local, kernel_size=p, stride=p)
        num_patches = h_unfold.shape[-1]

        h_unfold = h_unfold.reshape(B, d, p * p, num_patches)
        h_unfold = h_unfold.permute(0, 3, 2, 1)  # [B, num_patches, p^2, d_model]
        h_unfold = h_unfold.reshape(B * num_patches, p * p, d)

        # Step 3: global modeling with the Transformer
        h_global = self.transformer(h_unfold)  # [B*num_patches, p^2, d_model]

        # Step 4: restore the spatial layout (inverse of Step 2)
        h_global = h_global.reshape(B, num_patches, p * p, d)
        h_global = h_global.permute(0, 3, 2, 1)  # [B, d_model, p^2, num_patches]
        h_global = h_global.reshape(B, d * p * p, num_patches)
        h_global = F.fold(
            h_global,
            output_size=(H, W),
            kernel_size=p,
            stride=p
        )  # [B, d_model, H, W]

        # Step 5: fusion
        h = self.fusion(h_global)

        # Residual connection
        if self.use_residual:
            h = h + x

        return h
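The unfold/fold round trip in Steps 2 and 4 can be sanity-checked in isolation: with the stride equal to the kernel size the patches do not overlap, so `F.fold` exactly inverts `F.unfold` (the shapes below are arbitrary):

```python
import torch
import torch.nn.functional as F

B, C, H, W, p = 2, 8, 16, 16, 2
x = torch.randn(B, C, H, W)

# Each column holds one flattened p×p patch across all channels
cols = F.unfold(x, kernel_size=p, stride=p)   # [B, C*p*p, num_patches]
assert cols.shape == (B, C * p * p, (H // p) * (W // p))

# Non-overlapping patches: fold places each value back exactly once
y = F.fold(cols, output_size=(H, W), kernel_size=p, stride=p)
assert torch.allclose(x, y)
```

With overlapping patches `F.fold` would instead sum contributions, so this exact-inverse property holds only for the non-overlapping case used by MobileViT.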

2.2 A Simplified MobileViT Block

The implementation above is fairly involved; in practice a more concise version is often used:

class MobileViTBlockV2(nn.Module):
    """Simplified MobileViT block"""

    def __init__(self, in_channels, out_channels, patch_size=2, d_model=96):
        super().__init__()
        self.patch_size = patch_size
        self.d_model = d_model

        # Local convolution
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU6(),
            nn.Conv2d(in_channels, d_model, 1),
        )

        # Transformer encoder (nn.Transformer would also build an unused decoder)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=4, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=3)

        # Output projection
        self.proj = nn.Conv2d(d_model, out_channels, 1)

        # Positional embedding: one position per pixel inside a patch,
        # since each sequence fed to the Transformer has length patch_size^2
        self.pos_embed = nn.Parameter(torch.randn(1, patch_size ** 2, d_model))

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        d = self.d_model

        # Local convolution
        x = self.conv3x3(x)  # [B, d_model, H, W]

        # Convert to per-patch token sequences
        x = x.reshape(B, d, H // p, p, W // p, p)
        x = x.permute(0, 2, 4, 3, 5, 1)  # [B, H/p, W/p, p, p, d_model]
        x = x.reshape(B * (H // p) * (W // p), p * p, d)

        # Transformer
        x = x + self.pos_embed
        x = self.transformer(x)

        # Restore the spatial layout (inverse of the reshapes above)
        x = x.reshape(B, H // p, W // p, p, p, d)
        x = x.permute(0, 5, 1, 3, 2, 4)  # [B, d_model, H/p, p, W/p, p]
        x = x.reshape(B, d, H, W)

        # Project and return
        return self.proj(x)
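The reshape/permute sequence in this simplified block is equivalent to the `F.unfold` route used in the earlier block; the equivalence can be verified directly (sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

B, C, H, W, p = 2, 8, 16, 16, 2
x = torch.randn(B, C, H, W)

# Reshape-based patchification, as in the simplified block
seq = (x.reshape(B, C, H // p, p, W // p, p)
        .permute(0, 2, 4, 3, 5, 1)              # [B, H/p, W/p, p, p, C]
        .reshape(B * (H // p) * (W // p), p * p, C))

# Reference path via F.unfold: [B, C*p*p, L], dim 1 ordered (C, row, col)
ref = F.unfold(x, kernel_size=p, stride=p)
ref = ref.reshape(B, C, p * p, -1).permute(0, 3, 2, 1)  # [B, L, p*p, C]
ref = ref.reshape(B * (H // p) * (W // p), p * p, C)

assert torch.allclose(seq, ref)
```

The reshape route avoids the memory copy `F.unfold` performs for its column buffer, which is one reason practical implementations prefer it.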

3. The MobileViT Architecture

3.1 MobileViT-S Configuration

┌─────────────────────────────────────────────────────────────┐
│                 MobileViT-S Architecture                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Input: 256×256×3                                           │
│         │                                                   │
│         ▼                                                   │
│  Stage 0: Conv 3×3, stride=2                                │
│  Output: 128×128×32                                         │
│         │                                                   │
│         ▼                                                   │
│  Stage 1: MV2 Block ×2                                      │
│           └→ 128×128×32 → 128×128×64 (stride=1)             │
│  Output: 128×128×64                                         │
│         │                                                   │
│         ▼                                                   │
│  Stage 2: MV2 Block ×1 + MobileViT Block ×1                 │
│           └→ 128×128×64 → 64×64×96                          │
│  Output: 64×64×96                                           │
│         │                                                   │
│         ▼                                                   │
│  Stage 3: MV2 Block ×1 + MobileViT Block ×1                 │
│           └→ 64×64×96 → 32×32×128                           │
│  Output: 32×32×128                                          │
│         │                                                   │
│         ▼                                                   │
│  Stage 4: MV2 Block ×1 + MobileViT Block ×1                 │
│           └→ 32×32×128 → 16×16×192                          │
│  Output: 16×16×192                                          │
│         │                                                   │
│         ▼                                                   │
│  Stage 5: MV2 Block ×1                                      │
│  Output: 16×16×192 → 16×16×320 (expansion)                  │
│         │                                                   │
│         ▼                                                   │
│  Global pooling + Linear + Softmax                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

3.2 Full Implementation
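The model below relies on the MobileNetV2 inverted-residual block (`InvertedResidual`, the "MV2 Block" in the diagram), which is assumed to be defined elsewhere; a minimal sketch could look like this (normalization placement simplified):

```python
import torch
import torch.nn as nn

class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual: 1×1 expand → 3×3 depthwise → 1×1 project."""

    def __init__(self, in_channels, out_channels, stride=1, expand_ratio=4):
        super().__init__()
        hidden = in_channels * expand_ratio
        # Residual shortcut only when shapes are preserved
        self.use_residual = stride == 1 and in_channels == out_channels
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, bias=False),            # expand
            nn.BatchNorm2d(hidden),
            nn.ReLU6(),
            nn.Conv2d(hidden, hidden, 3, stride=stride, padding=1,
                      groups=hidden, bias=False),                     # depthwise
            nn.BatchNorm2d(hidden),
            nn.ReLU6(),
            nn.Conv2d(hidden, out_channels, 1, bias=False),           # linear projection
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        out = self.block(x)
        return x + out if self.use_residual else out

block = InvertedResidual(32, 64, stride=1, expand_ratio=4)
y = block(torch.randn(2, 32, 56, 56))
assert y.shape == (2, 64, 56, 56)
```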

class MobileViT(nn.Module):
    """Full MobileViT model"""

    def __init__(self, num_classes=1000):
        super().__init__()

        # Stage 0: stem convolution
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(),
        )

        # Stage 1 (InvertedResidual is the MobileNetV2 MV2 block)
        self.stage1 = nn.Sequential(
            InvertedResidual(32, 64, stride=1, expand_ratio=4),
            InvertedResidual(64, 64, stride=1, expand_ratio=4),
        )

        # Stage 2: first MobileViT block
        self.stage2 = nn.Sequential(
            InvertedResidual(64, 96, stride=2, expand_ratio=4),
            MobileViTBlock(96, 96, patch_size=2, d_model=96,
                           transformer_depth=2, num_heads=4),
        )

        # Stage 3
        self.stage3 = nn.Sequential(
            InvertedResidual(96, 128, stride=2, expand_ratio=4),
            MobileViTBlock(128, 128, patch_size=2, d_model=144,
                           transformer_depth=3, num_heads=4),
        )

        # Stage 4
        self.stage4 = nn.Sequential(
            InvertedResidual(128, 192, stride=2, expand_ratio=4),
            MobileViTBlock(192, 192, patch_size=2, d_model=192,
                           transformer_depth=3, num_heads=4),
        )

        # Stage 5
        self.stage5 = nn.Sequential(
            InvertedResidual(192, 320, stride=1, expand_ratio=4),
        )

        # Classification head
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(320, num_classes)
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.head(x)
        return x

4. Experimental Results

4.1 ImageNet Classification

| Model | Params | MACs | Top-1 | Top-5 |
| --- | --- | --- | --- | --- |
| MobileNetV2 | 3.4M | 300M | 72.0% | - |
| MobileNetV3-S | 5.4M | 219M | 75.2% | - |
| MobileViT-S | 5.6M | 1.9G | 78.4% | 94.0% |
| MobileViT-XS | 3.5M | 0.9G | 74.8% | - |
| EfficientNet-B0 | 5.3M | 390M | 77.1% | - |
| DeiT-T | 5.7M | 1.3G | 72.2% | 91.1% |

4.2 On-Device Inference Latency

| Model | iPhone 12 (ms) | Pixel 6 (ms) | Snapdragon 865 (ms) |
| --- | --- | --- | --- |
| MobileNetV3-S | 8.2 | 12.1 | 10.5 |
| MobileViT-S | 10.1 | 14.8 | 12.2 |
| EfficientNet-B0 | 25.3 | 32.1 | 28.5 |

4.3 Object Detection Performance

| Backbone | Detector | mAP | Params |
| --- | --- | --- | --- |
| MobileNetV3 | YOLOv4-Tiny | 24.3 | 5.8M |
| MobileViT-S | YOLOv4-Tiny | 28.1 | 8.2M |
| MobileNetV2 | RetinaNet | 31.4 | 8.5M |
| MobileViT-S | RetinaNet | 34.8 | 11.2M |

5. Comparison with Other Lightweight ViTs

5.1 Architecture Comparison

| Property | MobileViT | EfficientViT | Twins |
| --- | --- | --- | --- |
| Local modeling | MV2 + convolution | Multi-scale attention | LSA |
| Global modeling | Transformer | Transformer | GSA |
| Parameter sharing | | | |
| Positional encoding | Learnable | Learnable | CPE |
| Target scenario | Mobile devices | Efficient inference | Balanced performance |

5.2 Strengths of MobileViT

  1. Lightweight: designed for mobile devices, with small parameter counts and compute budgets
  2. Accurate: reaches 78%+ top-1 accuracy on ImageNet
  3. Versatile: applicable to classification, detection, segmentation, and other tasks
  4. Simple: the architecture is relatively simple and easy to implement and deploy

5.3 Limitations of MobileViT

  1. Limited Transformer depth: mobile resource constraints prevent stacking many layers
  2. Fixed patch size: the choice of patch size trades accuracy against efficiency
  3. Positional encoding: a simple positional encoding may be less effective than more elaborate schemes

6. Key Insights

6.1 The Importance of Local-Global Representation Learning

The core insight behind MobileViT: a good visual representation must capture both local and global information.

  • Local representation: captured by convolutions; suited to low-level features such as edges and textures
  • Global representation: captured by the Transformer; suited to high-level features such as object structure and relations

6.2 Trade-offs in Mobile-Oriented Design

┌─────────────────────────────────────────────────────────────┐
│             The Core Trade-off in Mobile Design             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   Model size ◄─────────────────────────────────► Accuracy   │
│     │                                               │       │
│     │                                               │       │
│     │    MobileViT                                  │       │
│     │         ◄──────── sweet spot ───────►         │       │
│     │                                               │       │
│     ▼                                               ▼       │
│   MobileNet                     Large ViTs (ViT-L, Swin)    │
│                                                             │
└─────────────────────────────────────────────────────────────┘

6.3 Combining with Knowledge Distillation

MobileViT is often trained together with knowledge distillation:

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss for training MobileViT"""

    def __init__(self, teacher, alpha=0.5, temperature=3.0):
        super().__init__()
        self.teacher = teacher
        self.teacher.eval()
        self.alpha = alpha
        self.temperature = temperature

    def forward(self, student_logits, inputs, targets):
        # Hard loss: standard cross-entropy against the labels
        hard_loss = F.cross_entropy(student_logits, targets)

        # Soft loss: KL divergence to the teacher's softened distribution
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=-1),
            F.softmax(teacher_logits / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)

        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss
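The hard/soft combination can be exercised in isolation on random logits, without any model (batch size, class count, alpha, and temperature below are arbitrary illustration values):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
alpha, T = 0.5, 3.0
student_logits = torch.randn(8, 10)    # batch of 8, 10 classes
teacher_logits = torch.randn(8, 10)    # stand-in for teacher(inputs)
targets = torch.randint(0, 10, (8,))

# Hard loss: cross-entropy against the labels
hard = F.cross_entropy(student_logits, targets)

# Soft loss: KL between temperature-softened distributions,
# rescaled by T^2 to keep gradient magnitudes comparable
soft = F.kl_div(
    F.log_softmax(student_logits / T, dim=-1),
    F.softmax(teacher_logits / T, dim=-1),
    reduction='batchmean',
) * (T ** 2)

loss = alpha * hard + (1 - alpha) * soft
assert torch.isfinite(loss) and loss.item() > 0
```

Note that `F.kl_div` expects log-probabilities as its first argument and probabilities as its second, which is why the student side goes through `log_softmax`.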

7. MobileViT v2 Improvements

MobileViT v2 further refines the architecture:

  1. Removing the residual inside attention: improves computational efficiency
  2. Linear-complexity self-attention: replaces standard attention with a linear-attention variant
  3. More efficient feature fusion: improved up- and down-sampling

class MobileViTv2Block(nn.Module):
    """Core MobileViT v2 module"""

    def __init__(self, in_channels, out_channels, d_model=128):
        super().__init__()

        # Linear-complexity attention (LinearAttention assumed defined elsewhere)
        self.linear_attention = LinearAttention(d_model)

        # Simplified projections
        self.proj_in = nn.Conv2d(in_channels, d_model, 1)
        self.proj_out = nn.Conv2d(d_model, out_channels, 1)

    def forward(self, x):
        x = self.proj_in(x)
        B, C, H, W = x.shape

        # Linear attention over the flattened spatial positions
        x = x.reshape(B, C, -1).transpose(1, 2)  # [B, HW, C]
        x = self.linear_attention(x)
        x = x.transpose(1, 2).reshape(B, C, H, W)

        return self.proj_out(x)
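The `LinearAttention` module above is left undefined. MobileViT v2's own "separable self-attention" is not reproduced here; as an illustrative stand-in, a generic kernelized linear attention (feature map φ(x) = ELU(x) + 1, in the style of Katharopoulos et al.) achieves the same O(N) scaling in sequence length:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearAttention(nn.Module):
    """Kernelized linear attention: O(N·d²) instead of O(N²·d).

    A generic sketch, not MobileViT v2's exact separable formulation.
    """

    def __init__(self, dim):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)

    def forward(self, x):                      # x: [B, N, d]
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = F.elu(q) + 1                       # positive feature maps
        k = F.elu(k) + 1
        # Aggregate keys and values once: kv is [B, d, d]
        kv = torch.einsum('bnd,bne->bde', k, v)
        # Per-token normalizer: q · sum_n phi(k_n)
        z = torch.einsum('bnd,bd->bn', q, k.sum(dim=1)).clamp(min=1e-6)
        out = torch.einsum('bnd,bde->bne', q, kv) / z.unsqueeze(-1)
        return self.out(out)

x = torch.randn(2, 64, 128)                    # [B, HW, d_model]
attn = LinearAttention(128)
y = attn(x)
assert y.shape == (2, 64, 128)
```

Because keys and values are summarized into a fixed d×d matrix before queries are applied, the cost no longer grows quadratically with the number of spatial positions, which is the property MobileViT v2 exploits.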

8. References