Transformer数学基础

Transformer架构的核心是自注意力机制(Self-Attention),它通过并行计算序列中所有位置之间的依赖关系,彻底解决了传统RNN的序列依赖问题。本章深入剖析Transformer的数学原理。

Self-Attention机制

核心公式

标准Scaled Dot-Product Attention定义为:

逐步骤推导

Step 1:QKV投影

设输入序列

其中:

  • :可学习的投影矩阵
  • :查询、键、值矩阵

Step 2:计算注意力分数

矩阵元素 表示第 个位置对第 个位置的注意力权重(未归一化)。

Step 3:缩放

为什么需要缩放?

假设 的各分量是均值为0、方差为1的独立随机变量,则:

因此 的方差与 成正比。当 较大时,点积的量级会很大,导致Softmax进入饱和区域(梯度趋近于0)。

缩放因子 将点积的方差归一化为1。

Step 4:Softmax归一化

其中 按行归一化:

Step 5:加权求和

输出是值的加权平均,权重由注意力矩阵决定。

计算复杂度分析

操作时间复杂度空间复杂度
投影
Softmax
总计

核心瓶颈 的空间和时间复杂度是处理长序列的主要障碍。

Softmax数值稳定性

问题分析

Softmax函数为:

很大时, 可能上溢为 ;当 很小时, 下溢为 0,导致数值不稳定。

Log-Sum-Exp技巧

数学恒等式:

减去最大值后,所有指数都在 范围内,避免上溢。

PyTorch实现

import torch
import torch.nn.functional as F
 
def stable_softmax(logits, dim=-1):
    """数值稳定的softmax实现"""
    # 减去最大值
    logits_minus_max = logits - logits.max(dim=dim, keepdim=True).values
    exp_logits = torch.exp(logits_minus_max)
    return exp_logits / exp_logits.sum(dim=dim, keepdim=True)
 
# PyTorch内置实现已经是数值稳定的
probs = F.softmax(logits, dim=-1)

多头注意力

数学定义

其中每个注意力头:

多头的几何意义

视角解释
子空间分解每个头在不同的 维子空间中计算注意力
多关系建模不同头捕获不同类型的依赖关系(句法、语义、位置等)
特征解耦允许头之间学习独立的信息流
信息融合 融合各头的输出

参数量分析

参数数量
总计约360K参数(与单头 相当)

PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
        B = query.size(0)
        
        # 线性投影并分头: (B, n, d_model) -> (B, h, n, d_k)
        Q = self.W_q(query).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 注意力计算: (B, h, n, n)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # 加权求和: (B, h, n, d_k) -> (B, n, d_model)
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        
        return self.W_o(context)

位置编码

为什么需要位置编码

自注意力机制是置换不变的:打乱输入序列的顺序,输出不变。这与语言/图像的顺序敏感性矛盾,因此需要注入位置信息。

Sinusoidal位置编码(原始Transformer)

Vaswani et al. (2017) 提出:

核心性质

性质公式意义
唯一性 for 每个位置有唯一编码
有界性防止数值问题
可推广性任意位置可外推无需学习所有位置

相对位置的几何表示

利用三角恒等式:

这意味着 可由 通过旋转矩阵得到,因此Sinusoidal编码隐式编码了相对位置信息。

PyTorch实现

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

Rotary Position Embedding (RoPE)

设计目标

RoPE (Su et al., 2022) 寻找函数 ,使得:

即:通过绝对位置编码实现相对位置感知

2D情况推导

设旋转矩阵:

则:

关键性质:旋转不改变点积的大小关系,只改变方向。这意味着相对位置信息被编码在点积中。

n维推广

将向量分成 对,对每对应用独立的2D旋转:

高效实现

使用逐元素运算替代矩阵乘法:

def apply_rotary_pos_emb(x, cos, sin):
    """
    x: (batch, heads, seq_len, d_k) where d_k is even
    """
    d_k = x.shape[-1]
    x1 = x[..., :d_k//2]  # 奇数维度
    x2 = x[..., d_k//2:]  # 偶数维度
    
    # 旋转:x' = x * cos + (-x2, x1) * sin
    return torch.cat([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos
    ], dim=-1)

RoPE vs Sinusoidal

特性SinusoidalRoPE
编码方式加到embedding旋转Q/K向量
相对位置隐式显式
KV Cache兼容需重新计算✅ 自然兼容
长上下文需外推方法NTK-Scaling等

高效注意力机制

FlashAttention核心思想

IO-Aware设计:减少GPU HBM(高带宽内存)和SRAM之间的数据传输。

标准Attention:
1. 计算完整 QK^T → 存HBM (O(n²)空间)
2. Softmax → 存HBM
3. 乘V → 存HBM
4. 输出 → 存HBM

FlashAttention (Tiling):
1. 将Q/K/V分块读入SRAM
2. 逐块计算局部attention
3. 在线更新最终结果
4. 无需存储完整attention矩阵

FlashAttention数学等价性

FlashAttention精确等价于标准attention,无近似误差。

复杂度改进

指标标准FlashAttention
时间(相同)
HBM访问
显存

FlashAttention-2优化

  • 更好的warp分工
  • 更大的block size
  • 更高效的softmax
# 使用FlashAttention
from flash_attn import flash_attn_func
 
# Q, K, V: (batch, seq_len, num_heads, head_dim)
output = flash_attn_func(Q, K, V, causal=True)

Transformer编码器与解码器

编码器结构

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Pre-norm架构(优于原始post-norm)
        x = x + self.dropout(self.self_attn(self.norm1(x), mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x

解码器结构

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # 掩码自注意力
        x = x + self.self_attn(self.norm1(x), tgt_mask)
        # 交叉注意力
        x = x + self.cross_attn(self.norm2(x), encoder_output, src_mask)
        # FFN
        x = x + self.ffn(self.norm3(x))
        return x

Pre-Norm vs Post-Norm

架构公式特点
Post-Norm(原始)训练不稳定,需warmup
Pre-Norm(现代)训练更稳定,效果更好

研究表明,Pre-Norm在深层网络中能更好地保持梯度稳定。1

参考


相关词条:Transformer与注意力机制LLM理论Transformer演进

Footnotes

  1. Nguyen, Salazar. “Transformers without Normalization”. 2023.