引言
Vision Transformer(ViT)1 是将自然语言处理领域的 Transformer 架构成功迁移到计算机视觉领域的里程碑式工作。与传统卷积神经网络(CNN)相比,ViT 通过全局自注意力机制直接建模图像块之间的依赖关系,摆脱了对局部卷积操作的结构化归纳偏置。
本文深入解析 ViT 的核心设计,包括 Patch Embedding、位置编码、Transformer Encoder 结构、训练策略,以及代表性变体架构。
核心思想:从 NLP 到 CV 的迁移
Transformer 在 NLP 中的成功
Transformer 最初在《Attention Is All You Need》2 中提出,用于机器翻译任务。其核心组件包括:
- 多头自注意力(Multi-Head Self-Attention, MHSA):计算序列内所有位置之间的依赖关系
- 前馈神经网络(FFN):逐位置的非线性变换
- 残差连接与层归一化:稳定训练
迁移到视觉的挑战
将 Transformer 直接应用于图像面临一个关键问题:序列长度。
| 输入形式 | NLP | CV |
|---|---|---|
| 典型长度 | 512-4096 tokens | 224×224 → 50,176 pixels |
| 处理方式 | 直接处理 | 需要分块(Patch) |
一张 的图像若按像素展平,将产生超过 50,000 个 token,自注意力的计算复杂度为 ,这在计算上是不可行的。
ViT 的解决思路
ViT 的核心创新是将图像划分为固定大小的 Patch,将每个 Patch 视为一个 “token”:
┌─────────┬─────────┬─────────┐
│ Patch 1 │ Patch 2 │ Patch 3 │
├─────────┼─────────┼─────────┤
│ Patch 4 │ Patch 5 │ Patch 6 │ → [x₁, x₂, ..., x₉, x_cls] → Transformer
├─────────┼─────────┼─────────┤
│ Patch 7 │ Patch 8 │ Patch 9 │
└─────────┴─────────┴─────────┘
16×16 patches (with 224×224 input)
Patch Embedding:图像序列化
Patch 划分
给定输入图像 ,ViT 将其划分为 个 Patch:
其中 是 Patch 的大小(通常为 16)。以 ImageNet 标准配置为例:
| 配置 | 输入尺寸 | Patch 大小 | Patch 数量 | 展平后维度 |
|---|---|---|---|---|
| ViT-B | 224×224 | 16×16 | 196 | 768 (16×16×3) |
| ViT-L | 224×224 | 16×16 | 196 | 1024 |
| ViT-H | 224×224 | 16×16 | 196 | 1280 |
线性投影
每个 Patch 随后通过一个线性投影层映射到隐空间:
其中:
- :Patch 嵌入矩阵
- :第 个 Patch 的位置编码
- :隐层维度
PyTorch 实现
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
"""图像转Patch嵌入 + 可学习位置编码"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# 线性投影层
self.proj = nn.Conv2d(
in_chans, embed_dim,
kernel_size=patch_size,
stride=patch_size
)
# 可学习位置编码
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches + 1, embed_dim) # +1 for [CLS]
)
# [CLS] token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
B = x.shape[0]
# Patch 嵌入: [B, C, H, W] → [B, D, H/P, W/P] → [B, N, D]
x = self.proj(x) # [B, D, H/P, W/P]
x = x.flatten(2).transpose(1, 2) # [B, N, D]
# 添加 [CLS] token
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1) # [B, N+1, D]
# 添加位置编码
x = x + self.pos_embed
return x位置编码
位置编码是 Transformer 建模序列顺序信息的关键组件。ViT 中常用的位置编码方案包括:
3.1 可学习位置编码
最简单直接的方法,将位置编码作为可学习的参数:
优点:简单、灵活
缺点:需要足够的训练数据来学习位置关系;对长度外推能力有限
3.2 正弦位置编码(Sinusoidal PE)
沿用原版 Transformer 的正弦编码:
特点:
- 无需学习,支持任意长度的序列
- 编码包含多频率成分
3.3 相对位置编码
为每个 Patch 对之间的相对偏移编码:
其中 表示两个位置之间的相对距离。
优势:
- 自然处理不同尺寸图像
- 对平移更鲁棒
3.4 2D 相对位置编码
针对图像的 2D 结构,分解水平和垂直方向:
其中 和 分别是两个位置的行、列索引。
位置编码对比
| 方案 | 参数量 | 外推能力 | 2D 感知 | 计算成本 |
|---|---|---|---|---|
| 可学习 | 差 | ❌ | 低 | |
| 正弦 | 0 | 良 | ❌ | 低 |
| 相对 | 优 | ✅ | 中 | |
| 2D 相对 | 优 | ✅ | 中 |
Transformer Encoder 结构
整体架构
ViT 的 Encoder 由 个相同的 Transformer Block 组成,每个 Block 包含:
┌─────────────────────────────────────────┐
│ Transformer Block │
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Multi-Head │ │ Feed-Forward │ │
│ │ Self-Attention│ → │ Network │ │
│ └───────┬──────┘ └───────┬──────┘ │
│ │ │ │
│ ┌───────▼──────┐ ┌───────▼──────┐ │
│ │ Add & Layer │ │ Add & Layer │ │
│ │ Norm │ │ Norm │ │
│ └──────────────┘ └──────────────┘ │
│ │
└─────────────────────────────────────────┘
多头自注意力(MHSA)
给定输入序列 ,MHSA 的计算过程为:
Step 1: 线性投影
其中 。
Step 2: 缩放点积注意力
Step 3: 多头组合
其中 。
计算复杂度分析:
| 操作 | 复杂度 | 说明 |
|---|---|---|
| QKV 投影 | 线性 | |
| 注意力矩阵 | 平方于序列长度 | |
| 输出投影 | 线性 |
前馈神经网络(FFN)
每个 Block 还包含一个两层的前馈网络:
其中 通常为 GELU 激活函数。
FFN 的扩展比例 (MLP 隐层维度与 的比值)是 ViT 的一个重要超参数:
| 配置 | FFN 扩展比 | MLP 隐层维度 |
|---|---|---|
| ViT-S | 4 | 2048 |
| ViT-B | 4 | 3072 |
| ViT-L | 4 | 4096 |
| ViT-G | 4 | 8192 |
完整 Block 实现
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim),
nn.Dropout(drop)
)
def forward(self, x):
# 自注意力 + 残差
x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
# FFN + 残差
x = x + self.mlp(self.norm2(x))
return x[CLS] Token 的作用与替代方案
[CLS] Token 的设计动机
ViT 沿用了 BERT 的 [CLS] Token 设计。在预训练阶段,[CLS] token 的最终表示用于图像分类:
直觉:由于自注意力的全局性质,[CLS] token 可以 “看到” 所有 Patch 的信息,从而聚合全局特征。
[CLS] Token 的理论分析
从信息流动的角度,[CLS] token 在第 层的表示为:
其中 是第 层的注意力权重累积矩阵。
这意味着 [CLS] token 的表示是所有 Patch 初始表示的加权平均,权重由注意力机制学习得到。
替代方案
全局平均池化(GAP)
直接对所有 Patch 表示做池化:
优势:无需额外的可学习 token
[GAP] + [CLS] 组合
一些变体同时使用 GAP 和 [CLS]:
注意力池化
通过可学习的注意力权重进行加权聚合:
混合架构:CNN Backbone + ViT
设计动机
纯 ViT 缺乏 CNN 的局部归纳偏置,导致:
- 数据效率低:需要大规模预训练
- 训练不稳定:小数据集容易过拟合
混合架构试图结合 CNN 的局部建模能力和 ViT 的全局建模能力。
典型架构
1. CNN 特征图 → ViT
使用 CNN(如 ResNet)的特征图替代 Patch Embedding:
class HybridViT(nn.Module):
def __init__(self, backbone, embed_dim):
super().__init__()
self.backbone = backbone # e.g., ResNet-50 (remove final FC)
self.proj = nn.Conv2d(backbone.feature_dim, embed_dim, 1)
self.transformer = TransformerEncoder(...)
def forward(self, x):
# CNN 提取特征
feat = self.backbone(x) # [B, C, H', W']
feat = self.proj(feat) # [B, D, H', W']
# 空间展平
B, D, H, W = feat.shape
feat = feat.flatten(2).transpose(1, 2) # [B, N, D]
# 添加 [CLS]
cls_tokens = ...
return self.transformer(...)2. 局部-全局混合
Stage 1: Conv → Conv → ...
Stage 2: Conv-Attention Hybrid → ...
Stage 3: Transformer → ...
3. 并行混合
如 ConViT5,同时使用卷积和注意力:
性能对比
| 架构 | ImageNet Top-1 | 参数量 | FLOPs |
|---|---|---|---|
| ViT-B/16 | 77.9% | 86M | 17.6B |
| ViT-L/16 | 76.5%* | 307M | - |
| Hybrid ViT-B | 84.5% | 86M | 17.6B |
| CoAtNet-0 | 85.1% | 25M | 4.6B |
*在 ImageNet-21K 上预训练
变体架构
DeiT:数据高效图像 Transformers
DeiT6 通过以下技术显著提升了 ViT 的数据效率:
1. 教师蒸馏
使用 ResNet 作为教师模型进行蒸馏:
其中 。
2. 数据增强策略
DeiT 使用了强大的数据增强:
- RandAugment
- CutMix
- Mixup
- Label Smoothing
3. 训练技巧
| 技巧 | 效果 |
|---|---|
| 温度 | 软化 logits |
| 强的正则化 | 防止过拟合 |
| 更多 epoch | 充分训练 |
性能对比
| 模型 | 数据集 | Top-1 | 训练时长 |
|---|---|---|---|
| ViT-B/16 | IN-21K→IN-1K | 85.3% | ~3 天 (TPU) |
| DeiT-B/16 | IN-1K | 83.4% | ~10 小时 (A100) |
| DeiT-B/16 ↓ | IN-1K | 81.8% | ~3 小时 |
CCT:Compact Vision Transformer
CCT7 通过以下设计减少 ViT 的计算量:
- Sequential Attention:替代全局注意力
- ConvStem:使用卷积作为初始嵌入
- 更小的 Patch:保持细粒度
其他变体
| 变体 | 核心改进 |
|---|---|
| ViT-FRCNN | Faster R-CNN 检测头 |
| T2T-ViT | Tokens-to-Token 渐进式聚合 |
| CeiT | CNN 增强的 Image Tokenizer |
| LocalViT | 局部注意力机制 |
训练策略与数据效率
预训练策略
| 策略 | 数据需求 | 计算需求 | 典型精度 |
|---|---|---|---|
| 从头训练(IN-1K) | 低 | 低 | ~80% |
| IN-21K 预训练 | 中 | 中 | ~85% |
| JFT-300M 预训练 | 高 | 高 | ~88%+ |
小数据集训练技巧
-
更强的正则化
- DropPath (Stochastic Depth)
- Label Smoothing
- Mixup/CutMix
-
更小的模型
- ViT-S (Small)
- ViT-T (Tiny)
-
更好的初始化
- 在大型数据集上预训练的特征初始化
- MLP-Stick 初始化
训练动态
ViT 的训练具有独特的动态特性:
训练早期 (0-20%):
- 注意力快速学习局部模式
- [CLS] token 逐渐聚合全局信息
训练中期 (20-70%):
- 特征质量稳步提升
- 损失函数快速下降
训练后期 (70-100%):
- 收敛到最优解
- 可能出现过拟合
PyTorch 完整实现
import torch
import torch.nn as nn
from torch.nn import Dropout, Linear, LayerNorm
class VisionTransformer(nn.Module):
"""完整的 Vision Transformer 实现"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
):
super().__init__()
self.num_classes = num_classes
self.embed_dim = embed_dim
# Patch Embedding
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim
)
num_patches = self.patch_embed.num_patches
# Transformer Encoder
self.blocks = nn.ModuleList([
TransformerBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate
)
for _ in range(depth)
])
self.norm = LayerNorm(embed_dim)
# 分类头
self.head = Linear(embed_dim, num_classes)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# Patch Embedding + Position Encoding
x = self.patch_embed(x)
# Transformer Encoder
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
# 分类
cls_token = x[:, 0]
return self.head(cls_token)与 CNN 的对比总结
| 特性 | ViT | CNN |
|---|---|---|
| 感受野 | 全局(第一层) | 局部(逐层扩大) |
| 归纳偏置 | 少(需学习) | 多(局部性、平移不变) |
| 数据效率 | 低(需大规模数据) | 高 |
| 参数量 | 较大 | 较小 |
| 长距离建模 | 强 | 弱(需大核/深层) |
| 可解释性 | 注意力可视化 | 卷积核可视化 |
| 硬件友好度 | 中等 | 高 |
参考
Footnotes
-
Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. https://arxiv.org/abs/2010.11929 ↩
-
Vaswani, A., et al. (2017). Attention Is All You Need. NeurIPS 2017. ↩
-
Dai, Z., et al. (2021). CoAtNet: Marrying Convolution and Attention for All Data. NeurIPS 2021. ↩
-
Wu, H., et al. (2021). CvT: Introducing Convolutions to Vision Transformers. ICCV 2021. ↩
-
d’Ascoli, S., et al. (2021). ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases. ICML 2021. ↩
-
Touvron, H., et al. (2021). Training data-efficient image transformers & distillation through attention. ICML 2021. ↩
-
Hassani, A., et al. (2021). Escaping the Big Data Paradigm with Compact Transformers. arXiv:2104.05704. ↩