EfficientViT: Efficient Vision Transformers

EfficientViT is a family of efficient vision Transformer architectures proposed by Microsoft. Through multi-scale attention fusion and a lightweight design, it significantly reduces computational cost while maintaining high accuracy. This chapter covers its design rationale, architecture, and experimental results.

1. Design Motivation and Core Ideas

1.1 Problems with Existing ViTs

The standard Vision Transformer has three main problems:

Problem                        Description                              Impact
High computational complexity  MHSA complexity is O(N²·d)               Hard to process high-resolution images
Weak local modeling            No local interaction between patches     Unstable training, poor data efficiency
Large memory footprint         The N×N attention matrix must be stored  Difficult to deploy
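The complexity and memory rows above can be made concrete with a quick back-of-the-envelope estimate (16x16 patches and fp32 storage assumed; illustrative numbers, not measurements):

```python
# Memory footprint of a single N x N attention score matrix (fp32)
# as the input resolution grows, assuming 16x16 patches
for size in (224, 512, 1024):
    n = (size // 16) ** 2          # number of tokens
    mib = n * n * 4 / 2 ** 20      # fp32 score matrix, in MiB
    print(f"{size}x{size}: N = {n}, attention matrix ≈ {mib:.1f} MiB per head")
```

Doubling the side length quadruples N and multiplies the score-matrix memory by 16, which is why full attention becomes impractical at high resolution.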

1.2 EfficientViT's Solution

┌─────────────────────────────────────────────────────────────┐
│                EfficientViT Design Principles               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Multi-Scale Attention                                   │
│     - Fewer tokens inside each attention window             │
│     - Cross-window information exchange                     │
│                                                             │
│  2. Lightweight FFN                                         │
│     - Reduced FFN hidden dimension                          │
│     - Depthwise-separable convolutions                      │
│                                                             │
│  3. Hardware-Aware Design                                   │
│     - Accounts for memory-access patterns                   │
│     - Optimized compute-to-memory-access ratio              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. Core Modules

2.1 Multi-Scale Attention (MSA)

Limitations of conventional local attention

Standard local attention only considers tokens inside a fixed window:

# Conventional local attention: one fixed window size shared by all heads
def local_attention(x, window_size=7):
    B, N, C = x.shape  # assumes N is divisible by window_size
    # Complexity: O(N * window_size^2) — each token only interacts
    # with the tokens inside its own window
    x = x.view(B, N // window_size, window_size, C)
    attn = ((x @ x.transpose(-2, -1)) * C ** -0.5).softmax(dim=-1)
    return (attn @ x).reshape(B, N, C)

MSA design

EfficientViT introduces multi-scale attention: attention groups with different window sizes capture dependencies at different ranges:

class MultiScaleAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Split attention into groups, each with a different window size
        # Group 1: small window (3x3) — fine-grained local features
        self.small_window = 3
        self.qkv_small = nn.Linear(dim, dim * 3)

        # Group 2: medium window (7x7) — mid-range dependencies
        self.medium_window = 7
        self.qkv_medium = nn.Linear(dim, dim * 3)

        # Group 3: global attention over pooled keys/values
        self.qkv_large = nn.Linear(dim, dim * 3)

        # Cross-scale fusion
        self.fusion = nn.Linear(dim * 3, dim)

    def window_attention(self, qkv, H, W, win):
        # Non-overlapping window attention (single head per scale for brevity)
        B = qkv.shape[0]
        qkv = qkv.reshape(B, H, W, 3, self.dim).permute(0, 3, 4, 1, 2)  # B,3,C,H,W
        pad_h, pad_w = (win - H % win) % win, (win - W % win) % win
        qkv = F.pad(qkv, (0, pad_w, 0, pad_h))
        Hp, Wp = H + pad_h, W + pad_w
        qkv = qkv.reshape(B, 3, self.dim, Hp // win, win, Wp // win, win)
        q, k, v = qkv.permute(1, 0, 3, 5, 4, 6, 2).reshape(3, -1, win * win, self.dim)
        out = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1) @ v
        out = out.reshape(B, Hp // win, Wp // win, win, win, self.dim)
        out = out.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, self.dim)
        return out[:, :H, :W].reshape(B, H * W, self.dim)

    def global_attention(self, qkv, H, W):
        # Full-resolution queries attend to keys/values pooled to a 7x7 grid
        B, N, _ = qkv.shape
        q, k, v = qkv.chunk(3, dim=-1)
        k = F.adaptive_avg_pool2d(k.transpose(1, 2).reshape(B, self.dim, H, W), 7)
        v = F.adaptive_avg_pool2d(v.transpose(1, 2).reshape(B, self.dim, H, W), 7)
        k, v = k.flatten(2).transpose(1, 2), v.flatten(2).transpose(1, 2)
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        return attn @ v

    def forward(self, x, H, W):
        # x: (B, N, C) token sequence over an H x W grid

        # Small-window attention
        attn_small = self.window_attention(self.qkv_small(x), H, W, self.small_window)

        # Medium-window attention
        attn_medium = self.window_attention(self.qkv_medium(x), H, W, self.medium_window)

        # Global attention (with pooled keys/values)
        attn_large = self.global_attention(self.qkv_large(x), H, W)

        # Concatenate and fuse across scales
        fused = torch.cat([attn_small, attn_medium, attn_large], dim=-1)
        return self.fusion(fused)

Mathematical formulation

Let the input features be X ∈ ℝ^(N×C). The multi-scale attention output is:

    MSA(X) = Concat[ Attn_3(X), Attn_7(X), Attn_g(X) ] · W_fuse

where the attention at each scale is defined as:

    Attn_w(X) = Softmax( Q_w K_wᵀ / √d_h ) V_w,    Q_w = X W_w^Q,  K_w = X W_w^K,  V_w = X W_w^V

For the windowed scales (w = 3, 7) the Softmax runs only over tokens inside the same w×w window; for the global branch the keys and values are spatially pooled before attention.
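To make the per-scale formula concrete, here is a dependency-free single-head attention computed on two orthogonal tokens (the `softmax` and `attention` helpers are illustrative, not part of EfficientViT):

```python
import math

def softmax(row):
    m = max(row)                      # subtract max for numerical stability
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]

def attention(Q, K, V, d):
    # Attn = Softmax(Q K^T / sqrt(d)) V, written out with plain lists
    scores = [[sum(q[t] * k[t] for t in range(d)) / math.sqrt(d) for k in K]
              for q in Q]
    return [[sum(w * V[j][t] for j, w in enumerate(softmax(row)))
             for t in range(d)] for row in scores]

# Two orthogonal 2-d tokens: each attends mostly (not exclusively) to itself
X = [[1.0, 0.0], [0.0, 1.0]]
out = attention(X, X, X, 2)
```

Each output row is a convex combination of the value rows, so with one-hot values its entries sum to 1; a larger dot-product scale would sharpen the weighting.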

2.2 Lightweight FFN

The standard FFN structure:

# Standard FFN: the hidden dimension is typically 4x
ffn = nn.Sequential(
    nn.Linear(dim, dim * 4),
    nn.GELU(),
    nn.Linear(dim * 4, dim)
)

EfficientViT's lightweight FFN:

class LightweightFFN(nn.Module):
    def __init__(self, dim, expand_ratio=2):
        super().__init__()
        # Reduced hidden dimension (2x instead of the usual 4x)
        hidden_dim = dim * expand_ratio

        # Pointwise → depthwise (local modeling) → pointwise
        self.pw_conv1 = nn.Linear(dim, hidden_dim)
        self.dw_conv = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim)
        self.pw_conv2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU()

    def forward(self, x):
        B, N, C = x.shape
        H = W = int(math.sqrt(N))  # assumes a square token grid

        x = self.act(self.pw_conv1(x))
        x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
        x = self.act(self.dw_conv(x))
        x = x.permute(0, 2, 3, 1).reshape(B, N, -1)
        x = self.pw_conv2(x)
        return x
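As a rough sanity check on the parameter savings, one can count the weight matrices only (biases ignored; dim = 128 is an arbitrary example size):

```python
# Weight-parameter counts: standard 4x FFN vs the 2x lightweight FFN
# with a 3x3 depthwise convolution in between
dim = 128
standard = dim * (4 * dim) + (4 * dim) * dim          # two pointwise linears
hidden = 2 * dim
lightweight = dim * hidden + hidden * 3 * 3 + hidden * dim  # pw + dw + pw
print(standard, lightweight)
```

The depthwise 3x3 kernel adds only 9 weights per channel, so the lightweight variant ends up with roughly half the parameters of the standard FFN.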

2.3 Inverted Residual Block

class InvertedResidual(nn.Module):
    """Inverted residual block used by EfficientViT"""
    def __init__(self, dim, expand_ratio=4):
        super().__init__()
        hidden_dim = dim * expand_ratio

        # Expand → depthwise convolution → project
        self.conv1 = nn.Conv2d(dim, hidden_dim, 1)
        self.dw_conv = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim)
        self.conv2 = nn.Conv2d(hidden_dim, dim, 1)
        self.act = nn.GELU()

    def forward(self, x):
        # Residual connection around expand → depthwise → project
        return x + self.conv2(self.act(self.dw_conv(self.act(self.conv1(x)))))

3. Full Architecture

3.1 The EfficientViT-M Series

Config  Channels                 Depths           Params  FLOPs  ImageNet Top-1
M1      [16, 32, 64, 128, 256]   [1, 2, 2, 2, 1]  3.9M    0.4G   72.4%
M2      [24, 48, 80, 160, 320]   [1, 2, 3, 2, 1]  7.3M    0.9G   76.9%
M3      [32, 64, 112, 224, 384]  [2, 3, 4, 3, 2]  13.7M   2.0G   79.8%
M4      [32, 64, 128, 256, 512]  [2, 4, 5, 4, 2]  24.2M   3.9G   81.6%
M5      [48, 96, 192, 384, 640]  [3, 5, 6, 5, 3]  43.0M   8.0G   83.0%
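Reading the series another way, accuracy per unit of compute falls off sharply as the models grow (FLOPs and Top-1 numbers copied from the table above):

```python
# Top-1 accuracy points per GFLOP for the EfficientViT-M series
models = {"M1": (0.4, 72.4), "M2": (0.9, 76.9), "M3": (2.0, 79.8),
          "M4": (3.9, 81.6), "M5": (8.0, 83.0)}
for name, (gflops, top1) in models.items():
    print(f"{name}: {top1 / gflops:.1f} Top-1 points per GFLOP")
```

M1 delivers over 180 Top-1 points per GFLOP while M5 delivers about 10: the smallest models are by far the most compute-efficient, and scaling up buys accuracy at a steep cost.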

3.2 Hierarchical Structure

EfficientViT-M3 architecture:
┌─────────────────────────────────────────────────────────────┐
│ Stage 1: Patch Embedding                                    │
│ Input: 224x224x3 → 16x16 patches → 196 tokens (14x14)       │
│ Output: 196x32                                              │
├─────────────────────────────────────────────────────────────┤
│ Stage 2: EfficientViT Block ×1                              │
│ Channels: 32 → 48, Multi-Scale Attention                    │
│ Output: 196x48 (14x14 grid)                                 │
├─────────────────────────────────────────────────────────────┤
│ Stage 3: EfficientViT Block ×3                              │
│ Channels: 48 → 80, deeper MSA, 2x2 downsampling             │
│ Output: 49x80 (7x7 grid)                                    │
├─────────────────────────────────────────────────────────────┤
│ Stage 4: EfficientViT Block ×4                              │
│ Channels: 80 → 128, Lightweight FFN                         │
│ Output: 49x128 (7x7 grid)                                   │
├─────────────────────────────────────────────────────────────┤
│ Stage 5: EfficientViT Block ×3                              │
│ Channels: 128 → 224, Global Aggregation                     │
│ Output: 49x224 (7x7 grid)                                   │
├─────────────────────────────────────────────────────────────┤
│ Classification Head: Linear + GELU + Linear                 │
└─────────────────────────────────────────────────────────────┘
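The token counts in the diagram follow from simple arithmetic, assuming the 16x16 patching and a single 2x2 downsampling step shown above:

```python
# Token counts implied by the stage diagram
img, patch = 224, 16
grid = img // patch                 # 14x14 patch grid
tokens_early = grid * grid          # tokens in Stages 1-2
tokens_late = (grid // 2) ** 2      # tokens after 2x2 downsampling
print(tokens_early, tokens_late)
```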

3.3 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class EfficientViT(nn.Module):
    """Complete EfficientViT-M3 implementation"""
    
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dims=[48, 80, 128, 224], depths=[1, 3, 4, 3], num_heads=8):
        super().__init__()
        
        self.num_classes = num_classes
        self.num_features = embed_dims[-1]
        
        # Patch embedding: two stride-2 convolutions → H/4 x W/4 grid
        self.patch_embed = nn.Sequential(
            nn.Conv2d(in_chans, embed_dims[0]//2, 3, stride=2, padding=1),
            nn.BatchNorm2d(embed_dims[0]//2),
            nn.GELU(),
            nn.Conv2d(embed_dims[0]//2, embed_dims[0], 3, stride=2, padding=1),
            nn.BatchNorm2d(embed_dims[0]),
        )
        
        # Positional encoding (sized to match the H/4 x W/4 token grid)
        self.pos_embed = nn.Parameter(torch.zeros(1, (img_size // 4) ** 2, embed_dims[0]))
        
        # EfficientViT blocks; the first block of each later stage downsamples
        # spatially and projects to the new stage's channel width
        self.blocks = nn.ModuleList()
        for i, (dim, depth) in enumerate(zip(embed_dims, depths)):
            for j in range(depth):
                downsample = (i > 0 and j == 0)
                in_dim = embed_dims[i - 1] if downsample else dim
                self.blocks.append(
                    EfficientViTBlock(dim, num_heads, in_dim=in_dim, downsample=downsample)
                )
        
        # Classification head
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dims[-1]),
            nn.Linear(embed_dims[-1], num_classes)
        )
        
        self._init_weights()
        
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
                    
    def forward(self, x):
        x = self.patch_embed(x)  # B, C, H/4, W/4
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # B, N, C
        x = x + self.pos_embed
        
        for block in self.blocks:
            x, H, W = block(x, H, W)  # blocks may shrink the token grid
            
        x = x.mean(dim=1)  # global average pooling
        return self.head(x)
 
 
class EfficientViTBlock(nn.Module):
    """Core EfficientViT block"""
    
    def __init__(self, dim, num_heads=8, window_size=7, in_dim=None, downsample=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        
        # Optional 2x2 downsampling plus a channel projection at stage boundaries
        self.downsample = downsample
        in_dim = in_dim or dim
        self.proj = nn.Linear(in_dim, dim) if in_dim != dim else nn.Identity()
        
        # Multi-scale attention
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiScaleAttention(dim, num_heads)
        
        # Lightweight FFN
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = LightweightFFN(dim)
        
    def forward(self, x, H, W):
        # Downsample first so attention runs on the new, smaller grid
        if self.downsample:
            B, N, C = x.shape
            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
            x = F.avg_pool2d(x, 2)
            H, W = x.shape[2], x.shape[3]
            x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        
        # Multi-scale attention + residual
        x = x + self.attn(self.norm1(x), H, W)
        
        # FFN + residual
        x = x + self.ffn(self.norm2(x))
        
        return x, H, W
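The 2x2 average-pooling step that shrinks the token grid can be sketched without any framework (`avg_pool_tokens` is a hypothetical helper operating on scalar tokens):

```python
def avg_pool_tokens(tokens, H, W):
    # tokens: flat list of H*W scalar "tokens"; 2x2 average pooling
    # halves each spatial dimension, merging every 2x2 neighborhood
    out = []
    for i in range(0, H, 2):
        for j in range(0, W, 2):
            window = [tokens[(i + di) * W + (j + dj)]
                      for di in (0, 1) for dj in (0, 1)]
            out.append(sum(window) / 4)
    return out

pooled = avg_pool_tokens(list(range(16)), 4, 4)  # 16 tokens -> 4 tokens
```

Each pooled token summarizes a 2x2 neighborhood, which is exactly what happens (per channel) inside the downsampling blocks.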

4. Experimental Results

4.1 ImageNet Classification

Model            Params  FLOPs  Top-1  Throughput (img/s)
DeiT-S           22M     4.6G   79.8%  125
EfficientViT-M3  7.3M    0.9G   76.9%  520
MobileViT-S      6.4M    2.0G   78.1%  310
EfficientViT-M4  24.2M   3.9G   81.6%  215
Swin-T           28M     4.5G   81.3%  118

4.2 Object Detection (COCO)

Backbone         AP    AP50  AP75  Params
DeiT-S           43.2  65.4  46.8  22M
EfficientViT-M4  44.5  66.8  48.2  24M
Swin-T           44.5  66.4  48.1  28M

4.3 Semantic Segmentation (ADE20K)

Backbone         mIoU  Params  FLOPs
DeiT-S           45.8  22M     5.7G
EfficientViT-M4  47.2  24M     5.9G
Swin-T           48.1  28M     7.3G

5. Key Insights

5.1 Summary of Design Principles

  1. Multi-scale fusion: attention groups with different window sizes capture features at different ranges, improving expressiveness
  2. Lightweight FFN: fewer parameters with little loss of modeling capacity
  3. Hardware awareness: optimizing the compute-to-memory-access ratio improves real-world deployment efficiency
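The third principle is about arithmetic intensity: a matrix multiply reuses every loaded value many times, while an elementwise op touches each value only once. A rough fp32 estimate (illustrative sizes, not measured numbers):

```python
# FLOPs per byte moved: an (N x d) @ (d x d) matmul vs an elementwise add
N, d = 196, 128
matmul_flops = 2 * N * d * d
matmul_bytes = (N * d + d * d + N * d) * 4     # read X and W, write output
add_flops = N * d
add_bytes = 3 * N * d * 4                      # read two tensors, write one
print(matmul_flops / matmul_bytes, add_flops / add_bytes)
```

The matmul performs tens of FLOPs per byte while the add performs a small fraction of one, so memory-bound elementwise and reshaping operations, not raw FLOPs, often dominate real latency.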

5.2 Comparison with Other Lightweight ViTs

Feature              EfficientViT          MobileViT            Twins
Attention            Multi-scale MSA       Mobile block + MHSA  LSA + GSA
FFN design           Lightweight + DWConv  Mobile block         Standard
Positional encoding  Learnable             Fixed                CPE
Target scenario      Efficient deployment  Mobile devices       Balanced performance

5.3 Limitations

  • Multi-scale attention increases implementation complexity
  • Accuracy is slightly below Swin-T on some tasks
  • Training needs more epochs to converge

6. Reference Papers