MobileViT: A Lightweight Vision Transformer
MobileViT is a lightweight vision Transformer architecture proposed by Apple. It aims to combine the expressive power of Transformers with the efficiency of MobileNet, yielding a mobile-friendly vision model. This chapter covers MobileViT's design rationale, core modules, and experimental results.
1. Design Motivation
1.1 ViT vs. MobileNet
| Property | Vision Transformer (ViT) | MobileNet |
|---|---|---|
| Core operation | self-attention | depthwise separable convolution |
| Local modeling | weak (needs more training data) | strong (convolution is inherently local) |
| Global modeling | strong (attention mechanism) | weak (requires stacking more layers) |
| Parameter count | larger | smaller |
| Compute cost | higher | lower |
| Mobile friendliness | low | high |
1.2 The Core Idea of MobileViT
┌─────────────────────────────────────────────────────────────┐
│                 MobileViT design philosophy                  │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  MobileNet: strong local          ViT: strong global         │
│             modeling                   modeling              │
│                │                          │                  │
│                ▼                          ▼                  │
│  depthwise separable conv         Transformer block          │
│                │                          │                  │
│                └──────────┬───────────────┘                  │
│                           ▼                                  │
│                    MobileViT block                           │
│                           │                                  │
│                           ▼                                  │
│        local + global representation learning                │
│                                                              │
└─────────────────────────────────────────────────────────────┘
2. MobileViT Core Modules
2.1 MobileViT Block
The MobileViT block is the model's core component, combining the strengths of convolution and the Transformer:
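Both code listings in this section assume the standard PyTorch imports and a small TransformerEncoder helper that the text does not define. A minimal pre-norm sketch, assuming a standard ViT-style encoder stack rather than any particular reference implementation, is:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerEncoder(nn.Module):
    """Minimal pre-norm Transformer encoder with the interface used below."""
    def __init__(self, d_model, depth, num_heads, mlp_ratio=4.0):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=int(d_model * mlp_ratio),
            batch_first=True,
            norm_first=True,  # pre-norm, as in most ViT-style encoders
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x):  # x: [B, seq_len, d_model]
        return self.encoder(x)
```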
```python
class MobileViTBlock(nn.Module):
    """
    MobileViT block = local representation (convolution)
                    + global representation (Transformer)
    Input:  [B, C, H, W]
    Output: [B, C', H, W]
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 patch_size=2, d_model=96, transformer_depth=3,
                 num_heads=4, mlp_ratio=4.0):
        super().__init__()
        self.patch_size = patch_size
        self.d_model = d_model
        # 1. Local representation: n×n depthwise separable convolution,
        #    expanding the channel dimension to d_model
        self.local_rep = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size,
                      padding=kernel_size // 2, groups=in_channels),  # depthwise
            nn.Conv2d(in_channels, d_model, 1),                       # pointwise
        )
        # 2. Global representation: Transformer over unfolded patches.
        #    Unfolding/folding is done with F.unfold/F.fold in forward,
        #    since the fold output size depends on the input resolution.
        self.transformer = TransformerEncoder(
            d_model=d_model,
            depth=transformer_depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
        )
        # 3. Fusion: 1×1 convolutions map d_model back to the output channels
        self.fusion = nn.Sequential(
            nn.Conv2d(d_model, in_channels, 1),  # pointwise
            nn.Conv2d(in_channels, out_channels, 1) if out_channels != in_channels else nn.Identity(),
        )
        # Residual connection (only when input/output channels match)
        self.use_residual = in_channels == out_channels

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        d = self.d_model
        # Step 1: local representation
        h_local = self.local_rep(x)  # [B, d, H, W]
        # Step 2: unfold into non-overlapping p×p patches
        # [B, d, H, W] -> [B, d * p^2, N], where N = (H/p) * (W/p)
        h = F.unfold(h_local, kernel_size=p, stride=p)
        num_patches = h.shape[-1]
        # Rearrange so each of the p^2 pixel positions forms a sequence
        # ACROSS patches; attending along that sequence is what makes
        # the Transformer step global
        h = h.reshape(B, d, p * p, num_patches)
        h = h.permute(0, 2, 3, 1)                 # [B, p^2, N, d]
        h = h.reshape(B * p * p, num_patches, d)  # [B*p^2, N, d]
        # Step 3: global Transformer modeling
        h = self.transformer(h)                   # [B*p^2, N, d]
        # Step 4: fold back to the spatial layout
        h = h.reshape(B, p * p, num_patches, d)
        h = h.permute(0, 3, 1, 2)                 # [B, d, p^2, N]
        h = h.reshape(B, d * p * p, num_patches)
        h = F.fold(h, output_size=(H, W), kernel_size=p, stride=p)
        # Step 5: fuse back to the output channel count
        h = self.fusion(h)
        # Residual connection
        if self.use_residual:
            h = h + x
        return h
```
2.2 A Simplified MobileViT Block
The implementation above is fairly involved; practical code often uses a more compact variant, named SimpleMobileViTBlock here to avoid confusion with the MobileViT v2 block in Section 7:
```python
class SimpleMobileViTBlock(nn.Module):
    """Compact MobileViT block: one sequence per in-patch pixel position,
    with attention running across patches."""
    def __init__(self, in_channels, out_channels, patch_size=2,
                 d_model=96, img_size=224):
        super().__init__()
        self.patch_size = patch_size
        self.d_model = d_model
        # Local convolution (depthwise + pointwise)
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU6(),
            nn.Conv2d(in_channels, d_model, 1),
        )
        # Transformer encoder (a full encoder-decoder nn.Transformer is
        # unnecessary for this encode-only use)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4,
                                           batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=3)
        # Output projection
        self.proj = nn.Conv2d(d_model, out_channels, 1)
        # Learnable positional embedding, one entry per patch position.
        # Note: the original MobileViT omits positional embeddings entirely;
        # this fixed-size table also ties the block to img_size×img_size inputs.
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, d_model))

    def forward(self, x):
        B, C, H, W = x.shape
        ph = pw = self.patch_size
        d = self.d_model
        # Local convolution
        x = self.conv3x3(x)  # [B, d_model, H, W]
        # Unfold to sequences of length (H/p)*(W/p): attention across patches
        x = x.reshape(B, d, H // ph, ph, W // pw, pw)
        x = x.permute(0, 3, 5, 2, 4, 1)   # [B, ph, pw, H/p, W/p, d_model]
        x = x.reshape(B * ph * pw, (H // ph) * (W // pw), d)
        # Transformer with positional embedding
        assert x.shape[1] == self.pos_embed.shape[1], \
            "feature map size must match img_size given at construction"
        x = x + self.pos_embed
        x = self.transformer(x)
        # Fold back to [B, d_model, H, W]
        x = x.reshape(B, ph, pw, H // ph, W // pw, d)
        x = x.permute(0, 5, 3, 1, 4, 2)   # [B, d_model, H/p, ph, W/p, pw]
        x = x.reshape(B, d, H, W)
        # Project and return
        return self.proj(x)
```
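A quick shape check of the simplified block (sizes here are illustrative; img_size must match the incoming feature-map resolution because of the positional table):
```python
block = SimpleMobileViTBlock(in_channels=64, out_channels=96,
                             patch_size=2, d_model=96, img_size=32)
x = torch.randn(2, 64, 32, 32)   # [B, C, H, W]
print(block(x).shape)            # torch.Size([2, 96, 32, 32])
```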
3. MobileViT Architecture
3.1 MobileViT-Small Configuration
┌─────────────────────────────────────────────────────────────┐
│                   MobileViT-S architecture                   │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  Input: 256×256×3                                            │
│        │                                                     │
│        ▼                                                     │
│  Stage 0: Conv 3×3, stride=2                                 │
│        Output: 128×128×32                                    │
│        │                                                     │
│        ▼                                                     │
│  Stage 1: MV2 block ×2                                       │
│        └→ 128×128×32 → 128×128×64 (stride=1)                 │
│        Output: 128×128×64                                    │
│        │                                                     │
│        ▼                                                     │
│  Stage 2: MV2 block ×1 + MobileViT block ×1                  │
│        └→ 128×128×64 → 64×64×96                              │
│        Output: 64×64×96                                      │
│        │                                                     │
│        ▼                                                     │
│  Stage 3: MV2 block ×1 + MobileViT block ×1                  │
│        └→ 64×64×96 → 32×32×128                               │
│        Output: 32×32×128                                     │
│        │                                                     │
│        ▼                                                     │
│  Stage 4: MV2 block ×1 + MobileViT block ×1                  │
│        └→ 32×32×128 → 16×16×192                              │
│        Output: 16×16×192                                     │
│        │                                                     │
│        ▼                                                     │
│  Stage 5: MV2 block ×1                                       │
│        Output: 16×16×192 → 16×16×320 (expansion)             │
│        │                                                     │
│        ▼                                                     │
│  Global average pooling + Linear + Softmax                   │
│                                                              │
└─────────────────────────────────────────────────────────────┘
3.2 Complete Implementation
```python
from torchvision.models.mobilenetv2 import InvertedResidual
# torchvision's MV2 inverted residual shares the (inp, oup, stride,
# expand_ratio) signature assumed here; any equivalent MV2 block works

class MobileViT(nn.Module):
    """Full MobileViT model (the MobileViT-S-style configuration above)."""
    def __init__(self, num_classes=1000):
        super().__init__()
        # Stage 0: stem convolution
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(),
        )
        # Stage 1
        self.stage1 = nn.Sequential(
            InvertedResidual(32, 64, stride=1, expand_ratio=4),
            InvertedResidual(64, 64, stride=1, expand_ratio=4),
        )
        # Stage 2: first MobileViT block
        self.stage2 = nn.Sequential(
            InvertedResidual(64, 96, stride=2, expand_ratio=4),
            MobileViTBlock(96, 96, patch_size=2, d_model=96,
                           transformer_depth=2, num_heads=4),
        )
        # Stage 3
        self.stage3 = nn.Sequential(
            InvertedResidual(96, 128, stride=2, expand_ratio=4),
            MobileViTBlock(128, 128, patch_size=2, d_model=144,
                           transformer_depth=3, num_heads=4),
        )
        # Stage 4
        self.stage4 = nn.Sequential(
            InvertedResidual(128, 192, stride=2, expand_ratio=4),
            MobileViTBlock(192, 192, patch_size=2, d_model=192,
                           transformer_depth=3, num_heads=4),
        )
        # Stage 5
        self.stage5 = nn.Sequential(
            InvertedResidual(192, 320, stride=1, expand_ratio=4),
        )
        # Classification head
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(320, num_classes)
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.head(x)
        return x
```
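A quick end-to-end sanity check (illustrative; the exact parameter count depends on the InvertedResidual implementation used):
```python
model = MobileViT(num_classes=1000)
x = torch.randn(1, 3, 256, 256)  # input resolution from the diagram above
logits = model(x)
print(logits.shape)              # torch.Size([1, 1000])
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```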
4. Experimental Results
4.1 ImageNet Classification
| Model | Params | MACs | Top-1 | Top-5 |
|---|---|---|---|---|
| MobileNetV2 | 3.4M | 300M | 72.0% | - |
| MobileNetV3-L | 5.4M | 219M | 75.2% | - |
| MobileViT-S | 5.6M | 1.9G | 78.4% | 94.0% |
| MobileViT-XS | 2.3M | 0.9G | 74.8% | - |
| EfficientNet-B0 | 5.3M | 390M | 77.1% | - |
| DeiT-T | 5.7M | 1.3G | 72.2% | 91.1% |
4.2 On-Device Inference Latency
| Model | iPhone 12 (ms) | Pixel 6 (ms) | Snapdragon 865 (ms) |
|---|---|---|---|
| MobileNetV3-S | 8.2 | 12.1 | 10.5 |
| MobileViT-S | 10.1 | 14.8 | 12.2 |
| EfficientNet-B0 | 25.3 | 32.1 | 28.5 |
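These latency figures depend heavily on the benchmarking setup. As a rough way to time the model on your own hardware (a plain CPU loop; numbers will not be comparable to the table):
```python
import time

model = MobileViT(num_classes=1000).eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    for _ in range(10):   # warm-up iterations
        model(x)
    t0 = time.perf_counter()
    for _ in range(50):
        model(x)
    dt = (time.perf_counter() - t0) / 50
print(f"{dt * 1e3:.1f} ms / image on this machine")
```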
4.3 Object Detection Performance
| Backbone | Detector | mAP | Params |
|---|---|---|---|
| MobileNetV3 | YOLOv4-Tiny | 24.3 | 5.8M |
| MobileViT-S | YOLOv4-Tiny | 28.1 | 8.2M |
| MobileNetV2 | RetinaNet | 31.4 | 8.5M |
| MobileViT-S | RetinaNet | 34.8 | 11.2M |
5. Comparison with Other Lightweight ViTs
5.1 Architecture Comparison
| Property | MobileViT | EfficientViT | Twins |
|---|---|---|---|
| Local modeling | MV2 + convolution | multi-scale attention | LSA (locally-grouped self-attention) |
| Global modeling | Transformer | Transformer | GSA (global sub-sampled attention) |
| Parameter sharing | none | none | none |
| Positional encoding | none (implicit via convolutions) | learnable | CPE (conditional positional encoding) |
| Target scenario | mobile deployment | efficient inference | balanced performance |
5.2 Strengths of MobileViT
- Lightweight: designed for mobile deployment, with modest parameter count and compute
- Accurate: reaches 78%+ top-1 accuracy on ImageNet (MobileViT-S)
- General: usable as a backbone for classification, detection, segmentation, and more
- Simple: a relatively straightforward architecture that is easy to implement and deploy
5.3 Limitations of MobileViT
- Limited Transformer depth: mobile resource budgets prevent stacking many layers
- Fixed patch size: the choice of patch size is a hard accuracy/efficiency trade-off
- Positional information: MobileViT forgoes explicit positional encodings, relying on convolutions and the fold/unfold structure to preserve spatial order; this can be a limitation for tasks sensitive to absolute position
6. Key Insights
6.1 The Importance of Local-Global Representation Learning
MobileViT's core insight: a good visual representation must capture both local and global information, a contrast the small check after this list makes concrete.
- Local representations: captured by convolutions; suited to low-level features such as edges and texture
- Global representations: captured by the Transformer; suited to high-level features such as object structure and relations
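A small illustrative check (toy sizes, not from the paper): the gradient of one output position reveals the receptive field of each operation.
```python
# One 3×3 convolution: the center output depends on a 3×3 neighborhood
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
x = torch.zeros(1, 1, 7, 7, requires_grad=True)
conv(x)[0, 0, 3, 3].backward()
print((x.grad[0, 0] != 0).sum().item())  # 9: local support

# One self-attention layer: every token influences every output token
attn = nn.MultiheadAttention(embed_dim=8, num_heads=1, batch_first=True)
t = torch.zeros(1, 49, 8, requires_grad=True)  # a 7×7 grid as 49 tokens
y, _ = attn(t, t, t)
y[0, 24].sum().backward()
print((t.grad[0].abs().sum(dim=-1) != 0).sum().item())  # 49: global support
```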
6.2 The Trade-off in Mobile Design
┌─────────────────────────────────────────────────────────────┐
│              The core trade-off in mobile design             │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  parameter count ◄──────────────────────────────► accuracy  │
│      │                                               │       │
│      │               MobileViT                       │       │
│      │         ◄──── balance point ────►             │       │
│      │                                               │       │
│      ▼                                               ▼       │
│  MobileNet                        large ViTs (ViT-L, Swin)   │
│                                                              │
└─────────────────────────────────────────────────────────────┘
6.3 Combining MobileViT with Knowledge Distillation
MobileViT is often trained with knowledge distillation, as in the loss below:
```python
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss for training MobileViT against a teacher."""
    def __init__(self, teacher, alpha=0.5, temperature=3.0):
        super().__init__()
        self.teacher = teacher.eval()  # teacher is frozen in eval mode
        self.alpha = alpha
        self.temperature = temperature

    def forward(self, inputs, student_logits, targets):
        # Hard loss: cross-entropy against the ground-truth labels
        hard_loss = F.cross_entropy(student_logits, targets)
        # Soft loss: KL divergence against the teacher's softened outputs
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=-1),
            F.softmax(teacher_logits / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss
```
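A hedged usage sketch (teacher_model, the batch, and the labels below are placeholders for illustration):
```python
teacher_model = MobileViT(num_classes=1000)   # stand-in for a stronger teacher
student = MobileViT(num_classes=1000)
criterion = DistillationLoss(teacher_model, alpha=0.5, temperature=3.0)

images = torch.randn(8, 3, 256, 256)          # placeholder batch
labels = torch.randint(0, 1000, (8,))
loss = criterion(images, student(images), labels)
loss.backward()
```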
7. MobileViT v2 Improvements
MobileViT v2 further refines the architecture:
- Removes the skip connection and fusion step inside the block: improves computational efficiency
- Linear-complexity self-attention: replaces standard multi-head attention with a separable, linear-time attention
- More efficient feature fusion: improved up- and down-sampling
```python
class MobileViTv2Block(nn.Module):
    """Core MobileViT v2 module (simplified: the local convolution step is omitted)."""
    def __init__(self, in_channels, out_channels, d_model=128):
        super().__init__()
        # Linear-complexity attention (a sketch of this module follows below)
        self.linear_attention = LinearAttention(d_model)
        # Simplified input/output projections
        self.proj_in = nn.Conv2d(in_channels, d_model, 1)
        self.proj_out = nn.Conv2d(d_model, out_channels, 1)

    def forward(self, x):
        x = self.proj_in(x)
        B, C, H, W = x.shape
        # Flatten the spatial grid into a token sequence for attention
        x = x.reshape(B, C, -1).transpose(1, 2)  # [B, HW, C]
        x = self.linear_attention(x)
        x = x.transpose(1, 2).reshape(B, C, H, W)
        return self.proj_out(x)
```
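The LinearAttention module above is not defined in this chapter. Below is a minimal sketch of the separable self-attention idea from MobileViT v2: each token receives a scalar context score, a softmax over tokens pools the keys into a single global context vector, and that vector modulates the values, making the cost linear in sequence length. Treat it as an illustration of the mechanism, not the reference implementation:
```python
class LinearAttention(nn.Module):
    """Separable self-attention sketch (after MobileViT v2)."""
    def __init__(self, d_model):
        super().__init__()
        self.to_score = nn.Linear(d_model, 1)        # I branch: context scores
        self.to_key = nn.Linear(d_model, d_model)    # K branch
        self.to_value = nn.Linear(d_model, d_model)  # V branch
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):  # x: [B, N, d_model]
        scores = torch.softmax(self.to_score(x), dim=1)               # [B, N, 1]
        context = (scores * self.to_key(x)).sum(dim=1, keepdim=True)  # [B, 1, d]
        out = torch.relu(self.to_value(x)) * context                  # broadcast over N
        return self.out(out)
```
With this sketch in scope, MobileViTv2Block runs end-to-end on a [B, C, H, W] tensor.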