Transformer数学基础
Transformer架构的核心是自注意力机制(Self-Attention),它通过并行计算序列中所有位置之间的依赖关系,彻底解决了传统RNN的序列依赖问题。本章深入剖析Transformer的数学原理。
Self-Attention机制
核心公式
标准Scaled Dot-Product Attention定义为:
逐步骤推导
Step 1:QKV投影
设输入序列 :
其中:
- :可学习的投影矩阵
- :查询、键、值矩阵
Step 2:计算注意力分数
矩阵元素 表示第 个位置对第 个位置的注意力权重(未归一化)。
Step 3:缩放
为什么需要缩放?
假设 和 的各分量是均值为0、方差为1的独立随机变量,则:
因此 的方差与 成正比。当 较大时,点积的量级会很大,导致Softmax进入饱和区域(梯度趋近于0)。
缩放因子 将点积的方差归一化为1。
Step 4:Softmax归一化
其中 按行归一化:
Step 5:加权求和
输出是值的加权平均,权重由注意力矩阵决定。
计算复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 投影 | ||
| Softmax | ||
| 总计 |
核心瓶颈: 的空间和时间复杂度是处理长序列的主要障碍。
Softmax数值稳定性
问题分析
Softmax函数为:
当 很大时, 可能上溢为 ;当 很小时, 下溢为 0,导致数值不稳定。
Log-Sum-Exp技巧
数学恒等式:
减去最大值后,所有指数都在 范围内,避免上溢。
PyTorch实现
import torch
import torch.nn.functional as F
def stable_softmax(logits, dim=-1):
"""数值稳定的softmax实现"""
# 减去最大值
logits_minus_max = logits - logits.max(dim=dim, keepdim=True).values
exp_logits = torch.exp(logits_minus_max)
return exp_logits / exp_logits.sum(dim=dim, keepdim=True)
# PyTorch内置实现已经是数值稳定的
probs = F.softmax(logits, dim=-1)多头注意力
数学定义
其中每个注意力头:
多头的几何意义
| 视角 | 解释 |
|---|---|
| 子空间分解 | 每个头在不同的 维子空间中计算注意力 |
| 多关系建模 | 不同头捕获不同类型的依赖关系(句法、语义、位置等) |
| 特征解耦 | 允许头之间学习独立的信息流 |
| 信息融合 | 融合各头的输出 |
参数量分析
设 ,,:
| 参数 | 数量 |
|---|---|
| 总计 | 约360K参数(与单头 相当) |
PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.scale = math.sqrt(self.d_k)
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
B = query.size(0)
# 线性投影并分头: (B, n, d_model) -> (B, h, n, d_k)
Q = self.W_q(query).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
# 注意力计算: (B, h, n, n)
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 加权求和: (B, h, n, d_k) -> (B, n, d_model)
context = torch.matmul(attn_weights, V)
context = context.transpose(1, 2).contiguous().view(B, -1, self.d_model)
return self.W_o(context)位置编码
为什么需要位置编码
自注意力机制是置换不变的:打乱输入序列的顺序,输出不变。这与语言/图像的顺序敏感性矛盾,因此需要注入位置信息。
Sinusoidal位置编码(原始Transformer)
Vaswani et al. (2017) 提出:
核心性质
| 性质 | 公式 | 意义 |
|---|---|---|
| 唯一性 | for | 每个位置有唯一编码 |
| 有界性 | 防止数值问题 | |
| 可推广性 | 任意位置可外推 | 无需学习所有位置 |
相对位置的几何表示
利用三角恒等式:
这意味着 可由 通过旋转矩阵得到,因此Sinusoidal编码隐式编码了相对位置信息。
PyTorch实现
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (batch_size, seq_len, d_model)
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)Rotary Position Embedding (RoPE)
设计目标
RoPE (Su et al., 2022) 寻找函数 和 ,使得:
即:通过绝对位置编码实现相对位置感知。
2D情况推导
设旋转矩阵:
则:
关键性质:旋转不改变点积的大小关系,只改变方向。这意味着相对位置信息被编码在点积中。
n维推广
将向量分成 对,对每对应用独立的2D旋转:
高效实现
使用逐元素运算替代矩阵乘法:
def apply_rotary_pos_emb(x, cos, sin):
"""
x: (batch, heads, seq_len, d_k) where d_k is even
"""
d_k = x.shape[-1]
x1 = x[..., :d_k//2] # 奇数维度
x2 = x[..., d_k//2:] # 偶数维度
# 旋转:x' = x * cos + (-x2, x1) * sin
return torch.cat([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos
], dim=-1)RoPE vs Sinusoidal
| 特性 | Sinusoidal | RoPE |
|---|---|---|
| 编码方式 | 加到embedding | 旋转Q/K向量 |
| 相对位置 | 隐式 | 显式 |
| KV Cache兼容 | 需重新计算 | ✅ 自然兼容 |
| 长上下文 | 需外推方法 | NTK-Scaling等 |
高效注意力机制
FlashAttention核心思想
IO-Aware设计:减少GPU HBM(高带宽内存)和SRAM之间的数据传输。
标准Attention:
1. 计算完整 QK^T → 存HBM (O(n²)空间)
2. Softmax → 存HBM
3. 乘V → 存HBM
4. 输出 → 存HBM
FlashAttention (Tiling):
1. 将Q/K/V分块读入SRAM
2. 逐块计算局部attention
3. 在线更新最终结果
4. 无需存储完整attention矩阵
FlashAttention数学等价性
FlashAttention精确等价于标准attention,无近似误差。
复杂度改进
| 指标 | 标准 | FlashAttention |
|---|---|---|
| 时间 | (相同) | |
| HBM访问 | ||
| 显存 |
FlashAttention-2优化
- 更好的warp分工
- 更大的block size
- 更高效的softmax
# 使用FlashAttention
from flash_attn import flash_attn_func
# Q, K, V: (batch, seq_len, num_heads, head_dim)
output = flash_attn_func(Q, K, V, causal=True)Transformer编码器与解码器
编码器结构
class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Pre-norm架构(优于原始post-norm)
x = x + self.dropout(self.self_attn(self.norm1(x), mask))
x = x + self.dropout(self.ffn(self.norm2(x)))
return x解码器结构
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
# 掩码自注意力
x = x + self.self_attn(self.norm1(x), tgt_mask)
# 交叉注意力
x = x + self.cross_attn(self.norm2(x), encoder_output, src_mask)
# FFN
x = x + self.ffn(self.norm3(x))
return xPre-Norm vs Post-Norm
| 架构 | 公式 | 特点 |
|---|---|---|
| Post-Norm(原始) | 训练不稳定,需warmup | |
| Pre-Norm(现代) | 训练更稳定,效果更好 |
研究表明,Pre-Norm在深层网络中能更好地保持梯度稳定。1
参考
相关词条:Transformer与注意力机制,LLM理论,Transformer演进
Footnotes
-
Nguyen, Salazar. “Transformers without Normalization”. 2023. ↩