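ViT-5: A Modern Vision Transformer for the Mid-2020s

Overview

ViT-5 (Vision Transformers for The Mid-2020s) is a research effort that systematically modernizes the original Vision Transformer architecture. Rather than proposing a new architectural paradigm, it draws on the research of the past five years and refines each ViT component individually (component-wise).

Core paper: arXiv:2602.08071 (see footnote 1)

Code: GitHub - ViT-5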


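1. Background and Motivation

The Evolution of the Vision Transformer

Since its introduction in 2020, the Vision Transformer (ViT) has become a foundational architecture in computer vision. Many aspects of the original design, however, now lag behind more recent research:

Timeline:
2020 ── ViT ── Original architecture
2021 ── DeiT ── Data-efficient Image Transformer: knowledge distillation, better data efficiency
2022 ── DeiT-II/III ── Improved training recipes
2023 ── DeiT-IV ── Continued refinement
2021-2024 ── Variants ── Swin (2021), ConvNeXt (2022), and others
2026 ── ViT-5 ── Systematic modernization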

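Fragmentation of Existing Improvements

Over the past few years the research community has proposed many ways to improve ViT:

| Area | Method | Years published |
|---|---|---|
| Normalization | LayerNorm → RMSNorm/GN | 2021-2023 |
| Activation | GELU → SiLU/Swish | 2021-2022 |
| Positional encoding | Learnable → RoPE/ALiBi | 2022-2023 |
| Gating | None → gated FFN | 2022-2023 |
| Token design | Single CLS → learnable tokens | 2021-2023 |

Problem: these improvements are scattered across different papers, with no unified integration and no systematic evaluation.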

Goals of ViT-5

“While preserving the canonical Attention-FFN structure, we conduct a component-wise refinement”

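The core ideas behind ViT-5:

  1. Keep ViT simple: no complex mechanisms (such as Swin's shifted windows)
  2. Integrate systematically: bring the improvements of the past five years together in one design
  3. Optimize per component: refine each component individually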

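2. Core Improvements: Component-Wise Modernization

2.1 Modernizing the Normalization Layer

The Problem with the Original ViT

# The original ViT uses standard LayerNorm
import torch.nn as nn

d_model = 768  # ViT-Base width, used here for illustration
original_vit_norm = nn.LayerNorm(d_model)
# Drawback: LayerNorm computes both a mean and a variance per token,
# which costs more than an RMS-only normalization.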

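The ViT-5 Improvement

import torch
import torch.nn as nn

class ViT5NormConfig:
    """Normalization options considered for ViT-5."""

    # Option 1: RMSNorm (cheaper)
    # Drops the mean computation and normalizes by the RMS only
    @staticmethod
    def rms_norm(x, normalized_shape=None, weight=None, eps=1e-6):
        """
        RMSNorm: Root Mean Square Layer Normalization

        Advantages:
        - No mean subtraction, so less computation per token
        - The source cites a 30-40% speedup of the normalization layer;
          the actual gain depends on hardware and kernel implementation
        - Comparable accuracy on vision tasks

        normalized_shape is unused; normalization runs over the last dim.
        """
        rms = torch.sqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + eps)
        x_norm = x / rms
        if weight is not None:
            x_norm = weight * x_norm
        return x_norm

    # Option 2: GroupNorm (more stable with small batches)
    # Normalizes within groups of channels
    @staticmethod
    def group_norm(x, num_groups, num_channels, eps=1e-6):
        """
        GroupNorm: group normalization for inputs shaped [B, C, ...]

        Advantages:
        - Insensitive to batch size
        - More stable training
        - Well suited to vision tasks
        """
        assert num_channels % num_groups == 0
        orig_shape = x.shape  # remember [B, C, ...] so it can be restored
        x = x.reshape(orig_shape[0], num_groups, num_channels // num_groups, -1)
        mean = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], unbiased=False, keepdim=True)
        x = (x - mean) / torch.sqrt(var + eps)
        return x.reshape(orig_shape)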
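
The difference between the two normalizations can be checked numerically. The following is a minimal, dependency-free sketch (plain Python lists rather than tensors, no learnable scale) and is not taken from the ViT-5 code. It illustrates that RMSNorm rescales a vector to unit RMS without recentring it, and that it coincides with LayerNorm when the input already has zero mean.

```python
import math

def rms_norm(x, eps=1e-6):
    """RMSNorm over one feature vector: x / sqrt(mean(x^2) + eps)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def layer_norm(x, eps=1e-6):
    """LayerNorm without affine parameters: (x - mean) / sqrt(var + eps)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

x = [1.0, 2.0, 3.0, 4.0]
y_rms = rms_norm(x)
y_ln = layer_norm(x)

# RMSNorm rescales but does not recentre: its output keeps a non-zero mean.
mean_rms = sum(y_rms) / len(y_rms)
mean_ln = sum(y_ln) / len(y_ln)

# On zero-mean input the two normalizations give the same result.
x_centered = [-1.5, -0.5, 0.5, 1.5]
same = all(
    abs(a - b) < 1e-9
    for a, b in zip(rms_norm(x_centered), layer_norm(x_centered))
)
```

The saving comes from skipping the mean subtraction; whether that matters in practice depends on how the normalization kernel is implemented.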

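2.2 Modernizing the Activation Function

Original ViT: GELU

# The original ViT uses GELU
class OriginalActivation:
    gelu = nn.GELU()
    # GELU: x * Phi(x), where Phi is the standard normal CDF
    # Relatively expensive to evaluate because it involves the erf function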

ViT-5: More Modern Activations

class ViT5Activation(nn.Module):
    """
    Modern activation functions used by ViT-5,
    chosen on the basis of empirical studies from the past five years.
    """

    def __init__(self, activation_type='silu'):
        super().__init__()
        self.activation_type = activation_type

        if activation_type == 'silu':
            # SiLU/Swish: x * sigmoid(x)
            # Advantage: self-gating, more stable training
            self.act = nn.SiLU()
        elif activation_type == 'gelu_tanh':
            # Tanh approximation of GELU (faster)
            self.act = nn.GELU(approximate='tanh')
        else:
            self.act = nn.GELU()

    def forward(self, x):
        return self.act(x)

# Gated activation
class GatedActivation(nn.Module):
    """
    Gated activation: f(x) = gate(x) * main(x)

    Advantages:
    - More controllable information flow
    - Less representation collapse
    - Better gradient flow
    """
    def __init__(self, d_model, d_ff, activation='silu'):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff)
        self.value_proj = nn.Linear(d_model, d_ff)
        self.out_proj = nn.Linear(d_ff, d_model)
        self.act = nn.SiLU() if activation == 'silu' else nn.GELU()

    def forward(self, x):
        gate = self.act(self.gate_proj(x))   # activated gate branch
        value = self.value_proj(x)           # linear value branch
        return self.out_proj(gate * value)   # elementwise gating, then project back
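
A gated FFN has three projection matrices where a standard FFN has two, so keeping d_ff = 4 * d_model increases the parameter count by half. The sketch below works through that arithmetic. The 8/3 hidden-width rule it uses is common practice in gated-FFN models and is an assumption here, since the source does not state which width ViT-5 uses.

```python
def ffn_params(d_model, d_ff):
    """Weights in a standard two-layer FFN (biases ignored)."""
    return 2 * d_model * d_ff

def gated_ffn_params(d_model, d_ff):
    """Weights in a gated FFN with gate, value, and output projections."""
    return 3 * d_model * d_ff

d_model = 768
baseline = ffn_params(d_model, 4 * d_model)            # 8 * d_model^2
naive_gated = gated_ffn_params(d_model, 4 * d_model)   # 12 * d_model^2
matched_d_ff = (8 * d_model) // 3                      # 2048 when d_model = 768
matched = gated_ffn_params(d_model, matched_d_ff)      # same budget as baseline
```

With d_model = 768 the matched width is exactly 2048, so the gated FFN fits the same weight budget as the standard 4x FFN.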

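2.3 Modernizing the Positional Encoding

Learnable Positional Embedding in the Original ViT

# The original ViT uses a learnable positional embedding
class OriginalViTPositionalEncoding(nn.Module):
    def __init__(self, seq_len, d_model):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model))
        # Drawback: the table has a fixed length and generalizes poorly
        # to sequence lengths not seen during training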

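The ViT-5 Improvement

class ViT5PositionalEncoding(nn.Module):
    """
    ViT-5 positional encoding: a choice among several modern techniques
    """

    def __init__(self, d_model, max_len=2048, rope_type='rotary'):
        super().__init__()
        self.rope_type = rope_type

        if rope_type == 'rotary':
            # RoPE: rotary positional embedding
            # Advantage: no extra parameters, can extrapolate to longer sequences
            self.rope = RotaryPositionalEmbedding(d_model, max_len)
        elif rope_type == 'alibi':
            # ALiBi: Attention with Linear Biases
            # Advantage: no positional embedding, naturally supports longer sequences
            # ALiBiAttentionBias is not defined in this document; it is assumed
            # to be a helper that exposes get_bias(seq_len)
            self.alibi = ALiBiAttentionBias(d_model)
        elif rope_type == 'learned':
            # Learnable embedding plus an interpolation strategy
            self.pos_embed = nn.Parameter(torch.zeros(1, max_len, d_model))

    def forward(self, x, seq_len=None):
        seq_len = x.shape[1] if seq_len is None else seq_len
        if self.rope_type == 'rotary':
            # RoPE acts on queries and keys inside attention, so return the
            # cached cos/sin tables for the attention layer to apply
            return self.rope.cos_cached[:seq_len], self.rope.sin_cached[:seq_len]
        elif self.rope_type == 'alibi':
            return self.alibi.get_bias(seq_len)
        else:
            return self.pos_embed[:, :seq_len]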
 
 
class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary positional embedding (RoPE)

    Core idea: encode position as a rotation applied to queries and keys.
    Expects q and k of shape [batch, seq_len, dim] with an even dim.
    """
    def __init__(self, dim, max_len=2048):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute the cos/sin tables for every position up to max_len
        t = torch.arange(max_len).type_as(self.inv_freq)
        freqs = torch.einsum('n,i->ni', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    @staticmethod
    def rotate_half(x):
        """Split x into two halves and rotate them: (x1, x2) -> (-x2, x1)."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k):
        # Apply the rotation to queries and keys
        q_embed = (q * self.cos_cached[:q.shape[1]]) + \
                  (self.rotate_half(q) * self.sin_cached[:q.shape[1]])
        k_embed = (k * self.cos_cached[:k.shape[1]]) + \
                  (self.rotate_half(k) * self.sin_cached[:k.shape[1]])
        return q_embed, k_embed
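
The property that makes RoPE attractive is that the attention score between a rotated query and a rotated key depends only on their relative offset. The sketch below checks this in plain Python. It follows the same half-split layout as rotate_half above, applied to a single vector rather than a batched tensor; the helper names and the example vectors are illustrative only.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate one vector for position pos, pairing element i with i + half."""
    dim = len(vec)
    half = dim // 2
    out = [0.0] * dim
    for i in range(half):
        theta = pos * base ** (-2.0 * i / dim)  # same frequencies as inv_freq
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

score_a = dot(rope(q, 5), rope(k, 2))       # positions 5 and 2, offset 3
score_b = dot(rope(q, 103), rope(k, 100))   # same offset, shifted by 100
norm_before = dot(q, q)
norm_after = dot(rope(q, 7), rope(q, 7))    # rotations preserve the norm
```

Here score_a and score_b agree up to floating-point error, which is why RoPE needs no learned table tied to absolute positions.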

2.4 Gating

class ViT5AttentionGate(nn.Module):
    """
    Attention gate: tighter control over information flow
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        # n_heads is accepted for interface symmetry but not used here
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Sigmoid()  # gate values lie in (0, 1)
        )

    def forward(self, attn_output, hidden_state):
        """
        Return the attention output scaled elementwise by the gate
        """
        gate_values = self.gate(hidden_state)
        return gate_values * attn_output

class ViT5FFNGate(nn.Module):
    """
    FFN gate: reduces representation collapse
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.SiLU(),
            nn.Linear(d_ff, d_model),
            nn.Sigmoid()
        )

    def forward(self, ffn_output, hidden_state):
        return self.gate(hidden_state) * ffn_output
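
One consequence of a sigmoid gate is worth spelling out: at initialization the gate sits near 0.5, so the gated branch starts at roughly half strength. The estimate below is a rough sketch for the single-linear attention gate. It assumes the std=0.02 weight initialization and zero bias used by _init_weights in Section 3, and inputs with roughly unit variance per feature (which is what a normalization layer produces).

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

d_model = 768
weight_std = 0.02  # std used by _init_weights in Section 3

# The pre-activation is a sum of d_model independent terms, each with
# variance weight_std^2 under the unit-variance input assumption.
preact_std = weight_std * math.sqrt(d_model)

# Gate values one standard deviation either side of the mean.
gate_low = sigmoid(-preact_std)
gate_high = sigmoid(preact_std)
gate_mean = sigmoid(0.0)
```

Under these assumptions the pre-activation has a standard deviation of about 0.55, so most initial gate values fall between roughly 0.37 and 0.63.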

2.5 Learnable Token Design

class ViT5LearnableTokens(nn.Module):
    """
    Learnable token design in ViT-5:
    an improvement that replaces the single CLS token
    """
    def __init__(self, n_tokens, d_model):
        super().__init__()
        # Several learnable tokens
        self.cls_tokens = nn.Parameter(torch.randn(1, n_tokens, d_model))
        # Optional distillation token (defined here but not used in forward)
        self.dist_token = nn.Parameter(torch.randn(1, 1, d_model))

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model] - patch embeddings
        Returns:
            tokens: [batch, n_tokens + seq_len, d_model]
        """
        batch_size = x.shape[0]

        # Broadcast the learnable tokens across the batch
        cls_tokens = self.cls_tokens.expand(batch_size, -1, -1)

        # Prepend them to the patch sequence
        x = torch.cat([
            cls_tokens,  # [batch, n_tokens, d_model]
            x,           # [batch, seq_len, d_model]
        ], dim=1)

        return x

    def aggregate(self, x, n_tokens):
        """
        Aggregate the outputs of the learnable tokens
        """
        cls_out = x[:, :n_tokens]  # [batch, n_tokens, d_model]
        # Could also use attention-weighted pooling or another scheme
        return cls_out.mean(dim=1)  # simple average

3. The Complete ViT-5 Architecture

import torch
import torch.nn as nn
from timm.layers import DropPath  # stochastic depth; assumes timm is installed

# Names used below but not defined elsewhere in this document (assumptions):
RMSNorm = nn.RMSNorm        # built into PyTorch >= 2.4
GatedFFN = GatedActivation  # the gated FFN from Section 2.2


class TokenGroupNorm(nn.GroupNorm):
    """GroupNorm for [B, N, D] tokens (nn.GroupNorm expects channels at dim 1)."""

    def forward(self, x):
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


class ViT5Block(nn.Module):
    """ViT-5 Transformer block"""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int = None,
        dropout: float = 0.0,
        drop_path: float = 0.0,
        activation: str = 'silu',
        norm_type: str = 'rmsnorm',
        use_gate: bool = True,
    ):
        super().__init__()
        d_ff = d_ff or 4 * d_model

        # Normalization
        if norm_type == 'rmsnorm':
            self.norm1 = RMSNorm(d_model)
            self.norm2 = RMSNorm(d_model)
        elif norm_type == 'layernorm':
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
        else:
            # GroupNorm with 32 channels per group, applied over the feature dim
            self.norm1 = TokenGroupNorm(d_model // 32, d_model)
            self.norm2 = TokenGroupNorm(d_model // 32, d_model)

        # Attention
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )

        # Gating
        if use_gate:
            self.attn_gate = ViT5AttentionGate(d_model, n_heads)

        # FFN
        self.ffn = GatedFFN(d_model, d_ff, activation)
        if use_gate:
            self.ffn_gate = ViT5FFNGate(d_model, d_ff)

        # DropPath
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        # Attention branch with residual connection (pre-norm)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)

        if hasattr(self, 'attn_gate'):
            attn_out = self.attn_gate(attn_out, h)

        x = x + self.drop_path(attn_out)

        # FFN branch with residual connection (pre-norm)
        h = self.norm2(x)
        ffn_out = self.ffn(h)

        if hasattr(self, 'ffn_gate'):
            ffn_out = self.ffn_gate(ffn_out, h)

        x = x + self.drop_path(ffn_out)

        return x
 
 
class VisionTransformer5(nn.Module):
    """Complete ViT-5 model"""

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        n_classes: int = 1000,
        d_model: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        d_ff: int = None,
        dropout: float = 0.0,
        drop_path_rate: float = 0.1,
        norm_type: str = 'rmsnorm',
        activation: str = 'silu',
        use_gate: bool = True,
        n_learnable_tokens: int = 1,
        use_rope: bool = True,
    ):
        super().__init__()

        # Patch embedding
        self.patch_embed = nn.Conv2d(
            in_channels, d_model,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Positional encoding
        n_patches = (img_size // patch_size) ** 2
        if use_rope:
            # RoPE is applied to q/k inside attention, so no additive table.
            # Note: nn.MultiheadAttention in ViT5Block does not apply RoPE, so
            # this sketch carries no positional signal when use_rope=True.
            self.pos_embed = None
        else:
            self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, d_model))

        # Learnable tokens
        self.cls_tokens = ViT5LearnableTokens(n_learnable_tokens, d_model)

        # Transformer blocks, with DropPath rates rising linearly with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList([
            ViT5Block(
                d_model, n_heads, d_ff, dropout,
                drop_path=dpr[i],
                norm_type=norm_type,
                activation=activation,
                use_gate=use_gate,
            )
            for i in range(n_layers)
        ])

        # Output head
        self.norm = RMSNorm(d_model) if norm_type == 'rmsnorm' else nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_classes)

        # Initialization
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # [B, N, D]

        # Add the positional embedding (learned variant only)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        # Prepend the learnable tokens
        x = self.cls_tokens(x)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Output
        x = self.norm(x)
        cls_out = self.cls_tokens.aggregate(x, n_tokens=self.cls_tokens.cls_tokens.shape[1])

        return self.head(cls_out)
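
The model above relies on a DropPath layer and on a per-layer rate list (dpr), neither of which is explained in the text. The following dependency-free sketch shows what both do, under the usual stochastic-depth definition: each residual branch is dropped for a whole sample with probability rate, and surviving branches are rescaled by 1 / (1 - rate). The function names are illustrative and operate on plain lists, not tensors.

```python
import random

def drop_path_rates(n_layers, max_rate):
    """Linearly spaced rates from 0 to max_rate, mirroring torch.linspace."""
    if n_layers == 1:
        return [0.0]
    return [max_rate * i / (n_layers - 1) for i in range(n_layers)]

def drop_path(branch, rate, training, rng):
    """Stochastic depth for one sample: drop the whole branch or rescale it."""
    if not training or rate == 0.0:
        return list(branch)
    keep = 1.0 - rate
    if rng.random() < keep:
        return [v / keep for v in branch]  # rescale so the expectation is unchanged
    return [0.0 for _ in branch]

rates = drop_path_rates(12, 0.1)  # first block never drops, last drops 10%
rng = random.Random(0)
samples = [
    drop_path([1.0], rates[-1], training=True, rng=rng)[0]
    for _ in range(10000)
]
mean_output = sum(samples) / len(samples)  # stays close to 1.0 on average
eval_output = drop_path([1.0, 2.0], rates[-1], training=False, rng=rng)
```

Because of the rescaling, the expected branch output matches the undropped value, so no correction is needed at inference time.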

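4. Training Configuration

4.1 Optimizer and Learning Rate

class ViT5TrainingConfig:
    """
    ViT-5 training configuration,
    based on best practices for training large vision models
    """

    # Optimizer: FusedLAMB (provided by NVIDIA Apex)
    optimizer = {
        'type': 'FusedLAMB',
        'lr': 1e-3,
        'weight_decay': 0.05,
        'beta1': 0.9,
        'beta2': 0.999,
        'eps': 1e-8,
    }

    # Learning-rate schedule
    scheduler = {
        'type': 'CosineAnnealingLR',
        'T_max': 300,
        'eta_min': 1e-6,
    }

    # Data augmentation
    augmentation = {
        'mixup_alpha': 0.8,
        'cutmix_alpha': 1.0,
        'color_jitter': 0.3,
        'three_augment': True,
        'repeated_aug': True,  # repeated augmentation
    }

    # Regularization
    regularization = {
        'drop_path': 0.05,  # small models; 0.35 for large models
        # Separate keys: a repeated dict key would silently keep only the last value
        'label_smoothing_pretrain': 0.0,  # pre-training
        'label_smoothing_finetune': 0.1,  # fine-tuning
    }

    # Other settings
    warmup_epochs = 5
    epochs = 300
    batch_size = 4096  # large-batch training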
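
The configuration lists a cosine schedule and a warmup length but does not say how the two are combined. The sketch below shows one common arrangement, given here as an assumption: a linear warmup over the first warmup_epochs, followed by cosine decay over the remaining epochs down to eta_min. The function name is illustrative.

```python
import math

def lr_at_epoch(epoch, base_lr=1e-3, eta_min=1e-6,
                warmup_epochs=5, total_epochs=300):
    """Linear warmup followed by cosine decay to eta_min."""
    if epoch < warmup_epochs:
        # Ramp up so the base rate is reached on the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))

schedule = [lr_at_epoch(e) for e in range(300)]
```

With the defaults above, the rate starts at 2e-4, reaches 1e-3 at the end of warmup, and approaches 1e-6 by the final epoch.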

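4.2 Training Recipe Comparison

| Setting | Original ViT | DeiT | ViT-5 |
|---|---|---|---|
| Optimizer | AdamW | AdamW | FusedLAMB |
| Batch size | 4096 | 1024 | 4096 |
| Weight decay | 0.1 | 0.05 | 0.05 |
| Label smoothing | 0.0 | 0.1 | 0.0/0.1 |
| DropPath | 0.0 | 0.0-0.1 | 0.05-0.35 |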

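5. Experimental Results

5.1 Model Variants

| Model | Resolution | Parameters | ImageNet-1K Top-1 |
|---|---|---|---|
| ViT-5-S | 224 | 22M | 82.2% |
| ViT-5-B | 224 | 87M | 84.2% |
| ViT-5-B | 384 | 87M | 85.4% |
| ViT-5-L | 224 | 304M | 84.9% |
| ViT-5-L | 384 | 304M | 86.0% |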
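
The parameter totals in this table appear consistent with the standard ViT-S/B/L shapes. The back-of-the-envelope count below assumes a plain ViT block (four attention projections plus a two-layer FFN with 4x expansion) and the usual widths and depths (384/12, 768/12, 1024/24). Those shapes are assumptions, since the table lists only totals, and the count ignores biases, norms, and position embeddings.

```python
def vit_weight_count(d_model, n_layers, patch_size=16, in_channels=3,
                     n_classes=1000):
    """Approximate weight count of a plain ViT (matrices only)."""
    attention = 4 * d_model * d_model        # Q, K, V, and output projections
    ffn = 2 * d_model * (4 * d_model)        # two linear layers, 4x expansion
    patch_embed = patch_size * patch_size * in_channels * d_model
    head = d_model * n_classes
    return n_layers * (attention + ffn) + patch_embed + head

# Assumed standard ViT widths and depths (not stated in the source table).
configs = {"S": (384, 12), "B": (768, 12), "L": (1024, 24)}
counts_m = {
    name: vit_weight_count(d_model, n_layers) / 1e6
    for name, (d_model, n_layers) in configs.items()
}
```

This gives about 21.9M, 86.3M, and 303.8M weights, which round to the 22M, 87M, and 304M listed once the omitted small terms are added back.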

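5.2 Comparison with Swin Transformer

| Model | Swin-B | ViT-5-B | Swin-L | ViT-5-L |
|---|---|---|---|---|
| Parameters | 88M | 87M | 196M | 304M |
| ImageNet | 83.8% | 84.2% | 85.2% | 85.5% |
| Inference speed | Baseline | 1.1x | Baseline | 0.9x |

5.3 Downstream Tasks

| Task | Swin-B | ViT-5-B |
|---|---|---|
| COCO Detection | 48.5 mAP | 49.1 mAP |
| ADE20K Segmentation | 48.0 mIoU | 48.5 mIoU |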

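6. Comparison with ConvNeXt

Differences in Architectural Philosophy

| Aspect | ViT-5 | ConvNeXt |
|---|---|---|
| Core mechanism | Global self-attention | Large-kernel convolution |
| Design philosophy | Keep ViT simple | Modernize the CNN |
| Window mechanism | None (global attention) | None (pure convolution) |
| Complexity | | |
| Scalability | | |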

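When to Use Which

def choose_architecture():
    """
    A guide to choosing between ViT-5 and ConvNeXt
    """

    # Choose ViT-5 when:
    # - a global receptive field is needed
    # - the task involves long-range dependencies
    # - a simple architecture is preferred
    # - compatibility with pretrained ViT models is required
    vit5_scenarios = [
        "Image classification (high resolution)",
        "Semantic segmentation",
        "Multi-scale object detection",
        "Transfer from pretrained ViT models"
    ]

    # Choose ConvNeXt when:
    # - the hardware favors convolutions
    # - efficient inference is required
    # - the task is dense prediction
    convnext_scenarios = [
        "Real-time applications",
        "Edge-device deployment",
        "Industrial inspection",
        "Video understanding"
    ]

    return vit5_scenarios, convnext_scenarios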

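7. Summary

Key Takeaways from ViT-5

  1. Incremental improvement works: a new architectural paradigm is not required, because systematic component upgrades are also effective
  2. Keep it simple: avoid unnecessary complexity (such as Swin's window mechanism)
  3. Accumulated research has value: integrating scattered improvements can yield a significant gain
  4. ViT has staying power: once modernized, ViT remains among the state-of-the-art architectures

Future Directions

  • More efficient attention mechanisms
  • Adaptive allocation of computation
  • Multimodal fusion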

References


Related topics: ViT in Depth | Swin Transformer | ConvNeXt

Footnotes

  1. Wang et al. (2026). ViT-5: Vision Transformers for The Mid-2020s. arXiv:2602.08071