引言

Vision Transformer(ViT)1 是将自然语言处理领域的 Transformer 架构成功迁移到计算机视觉领域的里程碑式工作。与传统卷积神经网络(CNN)相比,ViT 通过全局自注意力机制直接建模图像块之间的依赖关系,摆脱了对局部卷积操作的结构化归纳偏置。

本文深入解析 ViT 的核心设计,包括 Patch Embedding、位置编码、Transformer Encoder 结构、训练策略,以及代表性变体架构。


核心思想:从 NLP 到 CV 的迁移

Transformer 在 NLP 中的成功

Transformer 最初在《Attention Is All You Need》2 中提出,用于机器翻译任务。其核心组件包括:

  1. 多头自注意力(Multi-Head Self-Attention, MHSA):计算序列内所有位置之间的依赖关系
  2. 前馈神经网络(FFN):逐位置的非线性变换
  3. 残差连接与层归一化:稳定训练

迁移到视觉的挑战

将 Transformer 直接应用于图像面临一个关键问题:序列长度

输入形式NLPCV
典型长度512-4096 tokens224×224 → 50,176 pixels
处理方式直接处理需要分块(Patch)

一张 的图像若按像素展平,将产生超过 50,000 个 token,自注意力的计算复杂度为 ,这在计算上是不可行的。

ViT 的解决思路

ViT 的核心创新是将图像划分为固定大小的 Patch,将每个 Patch 视为一个 “token”:

┌─────────┬─────────┬─────────┐
│ Patch 1 │ Patch 2 │ Patch 3 │
├─────────┼─────────┼─────────┤
│ Patch 4 │ Patch 5 │ Patch 6 │    →    [x₁, x₂, ..., x₉, x_cls] → Transformer
├─────────┼─────────┼─────────┤
│ Patch 7 │ Patch 8 │ Patch 9 │
└─────────┴─────────┴─────────┘
     16×16 patches (with 224×224 input)

Patch Embedding:图像序列化

Patch 划分

给定输入图像 ,ViT 将其划分为 个 Patch:

其中 是 Patch 的大小(通常为 16)。以 ImageNet 标准配置为例:

配置输入尺寸Patch 大小Patch 数量展平后维度
ViT-B224×22416×16196768 (16×16×3)
ViT-L224×22416×161961024
ViT-H224×22416×161961280

线性投影

每个 Patch 随后通过一个线性投影层映射到隐空间:

其中:

  • :Patch 嵌入矩阵
  • :第 个 Patch 的位置编码
  • :隐层维度

PyTorch 实现

import torch
import torch.nn as nn
 
class PatchEmbed(nn.Module):
    """图像转Patch嵌入 + 可学习位置编码"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # 线性投影层
        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
        
        # 可学习位置编码
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim)  # +1 for [CLS]
        )
        
        # [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch 嵌入: [B, C, H, W] → [B, D, H/P, W/P] → [B, N, D]
        x = self.proj(x)  # [B, D, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)  # [B, N, D]
        
        # 添加 [CLS] token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # [B, N+1, D]
        
        # 添加位置编码
        x = x + self.pos_embed
        
        return x

位置编码

位置编码是 Transformer 建模序列顺序信息的关键组件。ViT 中常用的位置编码方案包括:

3.1 可学习位置编码

最简单直接的方法,将位置编码作为可学习的参数:

优点:简单、灵活
缺点:需要足够的训练数据来学习位置关系;对长度外推能力有限

3.2 正弦位置编码(Sinusoidal PE)

沿用原版 Transformer 的正弦编码:

特点

  • 无需学习,支持任意长度的序列
  • 编码包含多频率成分

3.3 相对位置编码

为每个 Patch 对之间的相对偏移编码:

其中 表示两个位置之间的相对距离。

优势

  • 自然处理不同尺寸图像
  • 对平移更鲁棒

3.4 2D 相对位置编码

针对图像的 2D 结构,分解水平和垂直方向:

其中 分别是两个位置的行、列索引。

位置编码对比

方案参数量外推能力2D 感知计算成本
可学习
正弦0
相对
2D 相对

Transformer Encoder 结构

整体架构

ViT 的 Encoder 由 个相同的 Transformer Block 组成,每个 Block 包含:

┌─────────────────────────────────────────┐
│              Transformer Block            │
│                                          │
│  ┌──────────────┐   ┌──────────────┐    │
│  │  Multi-Head   │   │  Feed-Forward │    │
│  │  Self-Attention│ → │  Network      │    │
│  └───────┬──────┘   └───────┬──────┘    │
│          │                  │            │
│  ┌───────▼──────┐   ┌───────▼──────┐    │
│  │  Add & Layer  │   │  Add & Layer  │    │
│  │  Norm         │   │  Norm         │    │
│  └──────────────┘   └──────────────┘    │
│                                          │
└─────────────────────────────────────────┘

多头自注意力(MHSA)

给定输入序列 ,MHSA 的计算过程为:

Step 1: 线性投影

其中

Step 2: 缩放点积注意力

Step 3: 多头组合

其中

计算复杂度分析

操作复杂度说明
QKV 投影线性
注意力矩阵平方于序列长度
输出投影线性

前馈神经网络(FFN)

每个 Block 还包含一个两层的前馈网络:

其中 通常为 GELU 激活函数。

FFN 的扩展比例 (MLP 隐层维度与 的比值)是 ViT 的一个重要超参数:

配置FFN 扩展比MLP 隐层维度
ViT-S42048
ViT-B43072
ViT-L44096
ViT-G48192

完整 Block 实现

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )
    
    def forward(self, x):
        # 自注意力 + 残差
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        # FFN + 残差
        x = x + self.mlp(self.norm2(x))
        return x

[CLS] Token 的作用与替代方案

[CLS] Token 的设计动机

ViT 沿用了 BERT 的 [CLS] Token 设计。在预训练阶段,[CLS] token 的最终表示用于图像分类:

直觉:由于自注意力的全局性质,[CLS] token 可以 “看到” 所有 Patch 的信息,从而聚合全局特征。

[CLS] Token 的理论分析

从信息流动的角度,[CLS] token 在第 层的表示为:

其中 是第 层的注意力权重累积矩阵。

这意味着 [CLS] token 的表示是所有 Patch 初始表示的加权平均,权重由注意力机制学习得到。

替代方案

全局平均池化(GAP)

直接对所有 Patch 表示做池化:

优势:无需额外的可学习 token

[GAP] + [CLS] 组合

一些变体同时使用 GAP 和 [CLS]:

注意力池化

通过可学习的注意力权重进行加权聚合:


混合架构:CNN Backbone + ViT

设计动机

纯 ViT 缺乏 CNN 的局部归纳偏置,导致:

  1. 数据效率低:需要大规模预训练
  2. 训练不稳定:小数据集容易过拟合

混合架构试图结合 CNN 的局部建模能力和 ViT 的全局建模能力。

典型架构

1. CNN 特征图 → ViT

使用 CNN(如 ResNet)的特征图替代 Patch Embedding:

class HybridViT(nn.Module):
    def __init__(self, backbone, embed_dim):
        super().__init__()
        self.backbone = backbone  # e.g., ResNet-50 (remove final FC)
        self.proj = nn.Conv2d(backbone.feature_dim, embed_dim, 1)
        self.transformer = TransformerEncoder(...)
    
    def forward(self, x):
        # CNN 提取特征
        feat = self.backbone(x)  # [B, C, H', W']
        feat = self.proj(feat)   # [B, D, H', W']
        
        # 空间展平
        B, D, H, W = feat.shape
        feat = feat.flatten(2).transpose(1, 2)  # [B, N, D]
        
        # 添加 [CLS]
        cls_tokens = ...
        return self.transformer(...)

2. 局部-全局混合

如 CoAtNet3、CVT4 等,采用渐进式混合:

Stage 1: Conv → Conv → ...
Stage 2: Conv-Attention Hybrid → ...
Stage 3: Transformer → ...

3. 并行混合

如 ConViT5,同时使用卷积和注意力:

性能对比

架构ImageNet Top-1参数量FLOPs
ViT-B/1677.9%86M17.6B
ViT-L/1676.5%*307M-
Hybrid ViT-B84.5%86M17.6B
CoAtNet-085.1%25M4.6B

*在 ImageNet-21K 上预训练


变体架构

DeiT:数据高效图像 Transformers

DeiT6 通过以下技术显著提升了 ViT 的数据效率:

1. 教师蒸馏

使用 ResNet 作为教师模型进行蒸馏:

其中

2. 数据增强策略

DeiT 使用了强大的数据增强:

  • RandAugment
  • CutMix
  • Mixup
  • Label Smoothing

3. 训练技巧

技巧效果
温度 软化 logits
强的正则化防止过拟合
更多 epoch充分训练

性能对比

模型数据集Top-1训练时长
ViT-B/16IN-21K→IN-1K85.3%~3 天 (TPU)
DeiT-B/16IN-1K83.4%~10 小时 (A100)
DeiT-B/16 ↓IN-1K81.8%~3 小时

CCT:Compact Vision Transformer

CCT7 通过以下设计减少 ViT 的计算量:

  1. Sequential Attention:替代全局注意力
  2. ConvStem:使用卷积作为初始嵌入
  3. 更小的 Patch:保持细粒度

其他变体

变体核心改进
ViT-FRCNNFaster R-CNN 检测头
T2T-ViTTokens-to-Token 渐进式聚合
CeiTCNN 增强的 Image Tokenizer
LocalViT局部注意力机制

训练策略与数据效率

预训练策略

策略数据需求计算需求典型精度
从头训练(IN-1K)~80%
IN-21K 预训练~85%
JFT-300M 预训练~88%+

小数据集训练技巧

  1. 更强的正则化

    • DropPath (Stochastic Depth)
    • Label Smoothing
    • Mixup/CutMix
  2. 更小的模型

    • ViT-S (Small)
    • ViT-T (Tiny)
  3. 更好的初始化

    • 在大型数据集上预训练的特征初始化
    • MLP-Stick 初始化

训练动态

ViT 的训练具有独特的动态特性:

训练早期 (0-20%):
- 注意力快速学习局部模式
- [CLS] token 逐渐聚合全局信息

训练中期 (20-70%):
- 特征质量稳步提升
- 损失函数快速下降

训练后期 (70-100%):
- 收敛到最优解
- 可能出现过拟合

PyTorch 完整实现

import torch
import torch.nn as nn
from torch.nn import Dropout, Linear, LayerNorm
 
class VisionTransformer(nn.Module):
    """完整的 Vision Transformer 实现"""
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_rate=0.,
        attn_drop_rate=0.,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        
        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches
        
        # Transformer Encoder
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate
            )
            for _ in range(depth)
        ])
        self.norm = LayerNorm(embed_dim)
        
        # 分类头
        self.head = Linear(embed_dim, num_classes)
        
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
    
    def forward(self, x):
        # Patch Embedding + Position Encoding
        x = self.patch_embed(x)
        
        # Transformer Encoder
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        
        # 分类
        cls_token = x[:, 0]
        return self.head(cls_token)

与 CNN 的对比总结

特性ViTCNN
感受野全局(第一层)局部(逐层扩大)
归纳偏置少(需学习)多(局部性、平移不变)
数据效率低(需大规模数据)
参数量较大较小
长距离建模弱(需大核/深层)
可解释性注意力可视化卷积核可视化
硬件友好度中等

参考

Footnotes

  1. Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. https://arxiv.org/abs/2010.11929

  2. Vaswani, A., et al. (2017). Attention Is All You Need. NeurIPS 2017.

  3. Dai, Z., et al. (2021). CoAtNet: Marrying Convolution and Attention for All Data. NeurIPS 2021.

  4. Wu, H., et al. (2021). CvT: Introducing Convolutions to Vision Transformers. ICCV 2021.

  5. d’Ascoli, S., et al. (2021). ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases. ICML 2021.

  6. Touvron, H., et al. (2021). Training data-efficient image transformers & distillation through attention. ICML 2021.

  7. Hassani, A., et al. (2021). Escaping the Big Data Paradigm with Compact Transformers. arXiv:2104.05704.