EfficientViT 高效视觉Transformer
EfficientViT是微软提出的一系列高效视觉Transformer架构,通过多尺度注意力融合和轻量级设计,在保持高精度的同时显著降低计算成本。本章详细介绍其设计原理、架构实现和实验结果。
一、设计动机与核心思想
1.1 现有ViT的问题
标准Vision Transformer存在三个主要问题:
| 问题 | 描述 | 影响 |
|---|---|---|
| 高计算复杂度 | MHSA复杂度为 | 难以处理高分辨率图像 |
| 弱局部建模 | Patch间缺乏局部交互 | 训练不稳定、数据效率低 |
| 内存占用大 | 注意力矩阵存储 | 部署困难 |
1.2 EfficientViT的解决方案
┌─────────────────────────────────────────────────────────────┐
│ EfficientViT 设计原则 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 多尺度注意力 (Multi-Scale Attention) │
│ - 减少注意力窗口内的token数量 │
│ - 跨窗口信息交互 │
│ │
│ 2. 轻量级FFN │
│ - 减少FFN中间维度 │
│ - 引入深度可分离卷积 │
│ │
│ 3. 硬件感知设计 │
│ - 考虑内存访问模式 │
│ - 优化计算访存比 │
│ │
└─────────────────────────────────────────────────────────────┘
二、核心模块
2.1 多尺度注意力(Multi-Scale Attention, MSA)
传统局部注意力的局限
标准局部注意力只考虑固定窗口内的token:
# 传统局部注意力:固定窗口
def local_attention(x, window_size=7):
B, N, C = x.shape
# 复杂度: O(N * window_size^2)
# 每个token只与窗口内token交互MSA设计
EfficientViT提出多尺度注意力,通过不同窗口大小的注意力头捕获不同范围的依赖:
class MultiScaleAttention(nn.Module):
def __init__(self, dim, num_heads=8):
super().__init__()
self.num_heads = num_heads
# 将注意力头分为多组,每组使用不同窗口大小
self.head_dim = dim // num_heads
# 组1: 小窗口 (3x3) - 捕获细粒度局部特征
self.small_window = 3
self.qkv_small = nn.Linear(dim, dim)
# 组2: 中窗口 (7x7) - 捕获中等范围依赖
self.medium_window = 7
self.qkv_medium = nn.Linear(dim, dim)
# 组3: 大窗口 (H/2 x W/2) - 捕获全局信息
self.qkv_large = nn.Linear(dim, dim)
# 跨尺度融合
self.fusion = nn.Linear(dim * 3, dim)
def forward(self, x):
B, H, W, C = x.shape
N = H * W
# 重新排列为token序列
x = x.view(B, N, C)
# 小窗口注意力
attn_small = self.local_attention(x, self.small_window)
# 中窗口注意力
attn_medium = self.local_attention(x, self.medium_window)
# 大窗口注意力 (全局池化)
attn_large = self.global_attention(x)
# 拼接融合
fused = torch.cat([attn_small, attn_medium, attn_large], dim=-1)
return self.fusion(fused)数学推导
设输入特征为 ,多尺度注意力输出:
其中各尺度注意力定义为:
2.2 轻量级FFN(Lightweight FFN)
标准FFN结构:
# 标准FFN: 中间维度通常为4x
ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)EfficientViT的轻量级FFN:
class LightweightFFN(nn.Module):
def __init__(self, dim, expand_ratio=2):
super().__init__()
# 减少中间维度
hidden_dim = dim * expand_ratio
# 使用深度可分离卷积增强局部建模
self.pw_conv1 = nn.Linear(dim, hidden_dim)
self.dw_conv = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim)
self.pw_conv2 = nn.Linear(hidden_dim, dim)
def forward(self, x):
B, N, C = x.shape
H = W = int(math.sqrt(N))
x = self.pw_conv1(x)
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
x = self.dw_conv(x)
x = x.permute(0, 2, 3, 1).reshape(B, N, -1)
x = self.pw_conv2(x)
return x2.3 倒置残差块(Inverted Residual Block)
class InvertedResidual(nn.Module):
"""EfficientViT使用倒置残差块"""
def __init__(self, dim, expand_ratio=4):
super().__init__()
hidden_dim = dim * expand_ratio
# 扩展 → 深度可分离卷积 → 投影
self.conv1 = nn.Conv2d(dim, hidden_dim, 1)
self.dw_conv = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim)
self.conv2 = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
return x + self.conv2(self.dw_conv(self.conv1(x)))三、完整架构
3.1 EfficientViT-M系列
| 配置 | 通道数 | 深度 | 参数量 | FLOPs | ImageNet Top-1 |
|---|---|---|---|---|---|
| M1 | [16, 32, 64, 128, 256] | [1, 2, 2, 2, 1] | 3.9M | 0.4G | 72.4% |
| M2 | [24, 48, 80, 160, 320] | [1, 2, 3, 2, 1] | 7.3M | 0.9G | 76.9% |
| M3 | [32, 64, 112, 224, 384] | [2, 3, 4, 3, 2] | 13.7M | 2.0G | 79.8% |
| M4 | [32, 64, 128, 256, 512] | [2, 4, 5, 4, 2] | 24.2M | 3.9G | 81.6% |
| M5 | [48, 96, 192, 384, 640] | [3, 5, 6, 5, 3] | 43.0M | 8.0G | 83.0% |
3.2 层次化结构
EfficientViT-M3 架构:
┌─────────────────────────────────────────────────────────────┐
│ Stage 1: Patch Embedding │
│ Input: 224x224x3 → Patch 16x16 → 196 tokens │
│ Output: 196x32 │
├─────────────────────────────────────────────────────────────┤
│ Stage 2: EfficientViT Block ×1 │
│ Channels: 32 → 48, Multi-Scale Attention │
│ Output: 196x48, H/8 x W/8 │
├─────────────────────────────────────────────────────────────┤
│ Stage 3: EfficientViT Block ×3 │
│ Channels: 48 → 80, Deeper MSA │
│ Output: 49x80, H/16 x W/16 │
├─────────────────────────────────────────────────────────────┤
│ Stage 4: EfficientViT Block ×4 │
│ Channels: 80 → 128, Lightweight FFN │
│ Output: 49x128, H/16 x W/16 │
├─────────────────────────────────────────────────────────────┤
│ Stage 5: EfficientViT Block ×3 │
│ Channels: 128 → 224, Global Aggregation │
│ Output: 224x224 │
├─────────────────────────────────────────────────────────────┤
│ Classification Head: Linear + GELU + Linear │
└─────────────────────────────────────────────────────────────┘
3.3 PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class EfficientViT(nn.Module):
"""EfficientViT-M3 完整实现"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
embed_dims=[48, 80, 128, 224], depths=[1, 3, 4, 3], num_heads=8):
super().__init__()
self.num_classes = num_classes
self.num_features = embed_dims[-1]
# Patch Embedding
self.patch_embed = nn.Sequential(
nn.Conv2d(in_chans, embed_dims[0]//2, 3, stride=2, padding=1),
nn.BatchNorm2d(embed_dims[0]//2),
nn.GELU(),
nn.Conv2d(embed_dims[0]//2, embed_dims[0], 3, stride=2, padding=1),
nn.BatchNorm2d(embed_dims[0]),
)
# 位置编码
self.pos_embed = nn.Parameter(torch.zeros(1, (img_size // 16) ** 2, embed_dims[0]))
# EfficientViT Blocks
self.blocks = nn.ModuleList()
for i, (dim, depth) in enumerate(zip(embed_dims, depths)):
for j in range(depth):
downsample = (i > 0 and j == 0)
self.blocks.append(
EfficientViTBlock(dim, num_heads, downsample=downsample)
)
# 分类头
self.head = nn.Sequential(
nn.LayerNorm(embed_dims[-1]),
nn.Linear(embed_dims[-1], num_classes)
)
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.patch_embed(x) # B, C, H/4, W/4
B, C, H, W = x.shape
x = x.flatten(2).transpose(1, 2) # B, N, C
x = x + self.pos_embed
for block in self.blocks:
x = block(x, H, W)
x = x.mean(dim=1) # 全局平均池化
return self.head(x)
class EfficientViTBlock(nn.Module):
"""EfficientViT核心模块"""
def __init__(self, dim, num_heads=8, window_size=7, downsample=False):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.downsample = downsample
# 多尺度注意力
self.norm1 = nn.LayerNorm(dim)
self.attn = MultiScaleAttention(dim, num_heads)
# 倒置残差块
self.norm2 = nn.LayerNorm(dim)
self.ffn = LightweightFFN(dim)
# 下采样(可选)
if downsample:
self.downsample = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * 2),
nn.GELU(),
nn.Linear(dim * 2, dim * 2),
)
else:
self.downsample = None
def forward(self, x, H, W):
# 多尺度注意力 + 残差
x = x + self.attn(self.norm1(x), H, W)
# FFN + 残差
x = x + self.ffn(self.norm2(x))
# 下采样
if self.downsample is not None:
B, N, C = x.shape
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
x = F.avg_pool2d(x, 2)
x = x.permute(0, 2, 3, 1).reshape(B, -1, C)
return x四、实验结果
4.1 ImageNet分类
| 模型 | 参数量 | FLOPs | Top-1 | Throughput (img/s) |
|---|---|---|---|---|
| DeiT-S | 22M | 4.6G | 79.8% | 125 |
| EfficientViT-M3 | 7.3M | 0.9G | 76.9% | 520 |
| MobileViT-S | 6.4M | 2.0G | 78.1% | 310 |
| EfficientViT-M4 | 24.2M | 3.9G | 81.6% | 215 |
| Swin-T | 28M | 4.5G | 81.3% | 118 |
4.2 目标检测(COCO)
| Backbone | AP | AP50 | AP75 | 参数量 |
|---|---|---|---|---|
| DeiT-S | 43.2 | 65.4 | 46.8 | 22M |
| EfficientViT-M4 | 44.5 | 66.8 | 48.2 | 24M |
| Swin-T | 44.5 | 66.4 | 48.1 | 28M |
4.3 语义分割(ADE20K)
| Backbone | mIoU | 参数量 | FLOPs |
|---|---|---|---|
| DeiT-S | 45.8 | 22M | 5.7G |
| EfficientViT-M4 | 47.2 | 24M | 5.9G |
| Swin-T | 48.1 | 28M | 7.3G |
五、关键洞察
5.1 设计原则总结
- 多尺度融合:不同窗口大小的注意力捕获不同范围的特征,提高表达能力
- 轻量级FFN:减少参数量的同时保持建模能力
- 硬件感知:优化计算访存比,提高实际部署效率
5.2 与其他轻量级ViT的对比
| 特性 | EfficientViT | MobileViT | Twins |
|---|---|---|---|
| 注意力机制 | 多尺度MSA | Mobile块+MHSA | LSA+GSA |
| FFN设计 | 轻量级+DWConv | Mobile块 | 标准 |
| 位置编码 | 可学习 | 固定 | CPE |
| 适用场景 | 高效部署 | 移动端 | 平衡性能 |
5.3 局限性
- 多尺度注意力增加了实现复杂度
- 在某些任务上性能略低于Swin-T
- 训练收敛需要更多epoch