Transformer位置编码完整指南
位置编码(Positional Encoding, PE)是Transformer架构中不可或缺的核心组件,负责将序列中token的顺序信息注入模型。与RNN和CNN天然具有位置感知能力不同,Transformer的自注意力机制具有位置不变性——即输入token的排列顺序不影响注意力计算结果。1 本指南系统梳理位置编码的技术演进、数学推导与实践指南。
1. 位置编码基础
1.1 为什么需要位置编码
Transformer的核心是自注意力(Self-Attention)机制,其计算可表示为:
其中 、、 分别由输入嵌入通过线性变换得到。关键问题在于:输入嵌入本身不包含位置信息。如果将序列 打乱为 ,自注意力计算结果完全相同,因为矩阵乘法具有置换不变性:
这意味着Transformer在没有位置编码的情况下,是一个”词袋模型”(Bag-of-Words),无法区分”狗咬人”和”人咬狗”的语义差异。
1.2 自注意力的位置不变性
自注意力的置换不变性来源于其计算本质:注意力分数仅依赖于query和key向量的内积,而内积运算对输入顺序不敏感。
形式化定义:设 为任意置换,对于注意力输出有:
其中 为置换矩阵。这表明注意力输出会按相同方式置换,但计算本身与顺序无关。
位置编码的解决思路:向输入嵌入中添加位置信息 ,使得:
这样打乱输入顺序会导致 ,从而模型可以感知位置差异。
1.3 绝对vs相对位置编码
根据编码方式的不同,位置编码可分为两大类:
绝对位置编码(Absolute Positional Encoding)
编码每个位置的绝对索引 ,生成独立的嵌入向量 :
代表方法:可学习位置编码、Sinusoidal位置编码
特点:
- 每个位置有唯一标识
- 参数量随序列长度线性增长(可学习版本)
- 推理时可外推到更长序列(Sinusoidal版本)
相对位置编码(Relative Positional Encoding)
编码token之间的距离 ,而非绝对位置:
其中 是关于相对位置 的函数。
代表方法:T5 Relative Position Bias、Shaw’s Relative Position Embedding、RoPE、ALiBi
特点:
- 自然编码位置关系
- 对序列长度变换更鲁棒
- 更适合长度外推任务
| 维度 | 绝对位置编码 | 相对位置编码 |
|---|---|---|
| 编码对象 | 绝对索引 | 相对距离 |
| 参数效率 | 可学习版本需 参数 | 通常 或 |
| 外推能力 | Sinusoidal可外推 | 天然支持外推 |
| 位置关系建模 | 间接 | 直接 |
| 计算复杂度 | 较低 | 可较高 |
2. 绝对位置编码
2.1 可学习位置编码
可学习位置编码(Learned Positional Embedding)是最直觉的方案,将位置视为可训练的嵌入向量:
其中 是可学习参数矩阵, 是预设的最大序列长度。
实现代码:
import torch
import torch.nn as nn
class LearnedPositionalEmbedding(nn.Module):
"""
可学习位置编码
优点:简单直观,可端到端优化
缺点:无法外推到训练长度之外
"""
def __init__(self, max_len, d_model):
super().__init__()
# 位置嵌入矩阵:[max_len, d_model]
self.pe = nn.Embedding(max_len, d_model)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
Returns:
带位置编码的嵌入
"""
batch_size, seq_len, _ = x.shape
# 生成位置索引 [0, 1, 2, ..., seq_len-1]
position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)
position_ids = position_ids.expand(batch_size, -1)
# 查询位置嵌入
position_embeddings = self.pe(position_ids)
return x + position_embeddings特点分析:
| 特性 | 评价 |
|---|---|
| 表达能力 | 理论上可学习任意位置映射 |
| 参数量 | |
| 外推能力 | ❌ 无法外推到 之外 |
| 计算效率 | 首次调用需读取嵌入,后续高效 |
| 训练稳定性 | 依赖良好初始化 |
应用场景:适用于序列长度固定的场景,如分类任务的BERT。
2.2 Sinusoidal位置编码
Vaswani等人提出的Sinusoidal位置编码通过正弦余弦函数生成位置向量,无需学习参数。2
数学公式:对于位置 和维度 :
可统一写作:
实现代码:
import torch
import math
class SinusoidalPositionalEmbedding(nn.Module):
"""
Sinusoidal位置编码
特点:
- 无需学习参数
- 可外推到任意长度
- 位置向量具有周期性和衰减特性
"""
def __init__(self, d_model, max_len=5000):
super().__init__()
self.d_model = d_model
# 预计算编码矩阵(可选,节省运行时计算)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
# 计算除数项:10000^{2j/d}
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model)
)
# 填充偶数维度(sin)
pe[:, 0::2] = torch.sin(position * div_term)
# 填充奇数维度(cos)
pe[:, 1::2] = torch.cos(position * div_term)
# 注册为buffer(不参与梯度计算,但会随模型保存)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
Returns:
带位置编码的嵌入
"""
seq_len = x.size(1)
# 取前seq_len个位置编码
return x + self.pe[:seq_len]
@staticmethod
def encode_position(pos, d_model):
"""
动态计算单个位置的位置编码(适用于变长序列)
Args:
pos: 位置标量或张量
d_model: 模型维度
Returns:
位置编码向量
"""
pe = torch.zeros(pos.shape[0] if hasattr(pos, 'shape') else 1, d_model)
pe[:, 0::2] = torch.sin(pos / torch.pow(10000, 2 * torch.arange(0, d_model, 2) / d_model))
pe[:, 1::2] = torch.cos(pos / torch.pow(10000, 2 * torch.arange(1, d_model, 2) / d_model))
return pe2.3 Sinusoidal编码的性质
Sinusoidal位置编码具有若干优良性质,使其成为Transformer的默认选择。
2.3.1 周期性与频率特性
位置编码的不同维度对应不同的频率:
| 维度范围 | 周期 | 特性 |
|---|---|---|
| 低维度() | 长周期() | 编码全局位置 |
| 高维度() | 短周期() | 编码精细位置 |
这种多尺度设计使得模型可以从不同维度捕获从粗到细的位置信息。
2.3.2 相对位置编码性质
关键定理:Sinusoidal位置编码的任意两个位置的点积仅依赖于它们的相对距离。
证明:对于位置 和 ,令 ,则:
利用三角恒等式:
得:
这表明点积只依赖 ,即隐式编码了相对位置信息。
2.3.3 外推能力
Sinusoidal编码理论上可外推到任意长度,因为:
- 数学连续性:正弦函数是连续且周期性的
- 无参数依赖:不依赖预定义的最大长度
- 多尺度覆盖:不同频率组合可表示任意长距离
然而实践中,外推效果通常不如专门设计的外推方法(如ALiBi)。
2.3.4 与词嵌入的融合
Sinusoidal编码与词嵌入的融合方式为加性融合:
这种方式简单但存在耦合问题:位置信息和语义信息在加法操作中混合,可能干扰各自的学习。
3. 相对位置编码
相对位置编码直接编码token之间的距离,天然更适合建模位置关系,近年来成为主流方案。
3.1 T5 Relative Position Bias
T5模型提出的相对位置偏置是一种可学习的软偏置,直接加在注意力分数上。3
数学公式:
其中偏置项定义为:
将相对距离截断到 范围, 是可学习的偏置向量。
实现代码:
import torch
import torch.nn as nn
class T5RelativePositionBias(nn.Module):
"""
T5相对位置偏置
特点:
- 可学习的离散偏置表
- 截断机制限制参数增长
- 直接作用于注意力分数
"""
def __init__(self, d_model, num_heads, max_distance=128):
super().__init__()
self.num_heads = num_heads
self.max_distance = max_distance
# 偏置表:[-max_distance, ..., 0, ..., max_distance]
self.relative_attention_bias = nn.Embedding(2 * max_distance + 1, num_heads)
def _get_relative_positions(self, seq_len):
"""生成相对位置矩阵"""
# 生成绝对位置索引
position_ids = torch.arange(seq_len, device=self.relative_attention_bias.weight.device)
# 计算相对位置:position[i] - position[j]
relative_position = position_ids.unsqueeze(0) - position_ids.unsqueeze(1)
# 截断到 [-max_distance, max_distance]
relative_position = torch.clamp(
relative_position,
-self.max_distance,
self.max_distance
)
# 平移使索引从0开始
relative_position = relative_position + self.max_distance
return relative_position
def forward(self, query, key):
"""
Args:
query: [batch_size, num_heads, seq_len_q, d_k]
key: [batch_size, num_heads, seq_len_k, d_k]
Returns:
相对位置偏置矩阵
"""
seq_len_q = query.size(2)
seq_len_k = key.size(2)
relative_position = self._get_relative_positions(max(seq_len_q, seq_len_k))
# 获取偏置 [seq_len_q, seq_len_k, num_heads]
bias = self.relative_attention_bias(relative_position)
# 调整维度顺序:[seq_len_q, seq_len_k, num_heads] -> [num_heads, seq_len_q, seq_len_k]
bias = bias.permute(2, 0, 1)
return bias特点分析:
| 维度 | 分析 |
|---|---|
| 参数量 | , 为最大距离, 为头数 |
| 外推能力 | 有限,需预定义 |
| 表达灵活度 | 离散偏置,可学习任意映射 |
| 计算开销 | 轻微,需查表 |
3.2 Shaw’s Relative Position Embedding
Shaw等人提出的相对位置嵌入通过连续函数建模相对位置关系。4
核心思想:相对位置 被映射到连续向量空间:
其中嵌入函数设计为:
其中 控制衰减率。
实现代码:
import torch
import torch.nn as nn
import math
class ShawRelativePositionEmbedding(nn.Module):
"""
Shaw相对位置嵌入
特点:
- 有限窗口内精确编码
- 无限距离外指数衰减
- 支持相对位置感知注意力
"""
def __init__(self, d_model, max_distance=16, num_units=8, decay_rate=0.99):
super().__init__()
self.max_distance = max_distance
self.num_units = num_units # 每个位置的嵌入维度
self.decay_rate = decay_rate
# 嵌入维度为 num_units * 2 + 1(考虑正负位置)
self.embedding = nn.Parameter(
torch.randn(2 * max_distance + 1, num_units) * 0.1
)
# 衰减权重
self.decay_weight = nn.Parameter(torch.ones(num_units) * decay_rate)
def _get_relative_position(self, seq_len):
"""生成相对位置索引"""
position = torch.arange(seq_len)
relative = position.unsqueeze(1) - position.unsqueeze(0) # [seq_len, seq_len]
return relative
def _clip_relative_position(self, relative):
"""截断并映射相对位置"""
# 截断到 [-max_distance, max_distance]
clipped = torch.clamp(relative, -self.max_distance, self.max_distance)
# 转换为非负索引
return clipped + self.max_distance
def forward(self, query, key):
"""
计算相对位置偏置
Args:
query: [batch_size, num_heads, seq_len, d_k]
key: [batch_size, num_heads, seq_len, d_k]
Returns:
相对位置偏置 [seq_len, seq_len, num_units]
"""
seq_len = query.size(2)
relative = self._get_relative_position(seq_len)
clipped_idx = self._clip_relative_position(relative)
# 查表获取嵌入
r = self.embedding[clipped_idx] # [seq_len, seq_len, num_units]
# 对远距离应用衰减
mask = (relative.abs() > self.max_distance).float()
decay = self.decay_weight ** (relative.abs() - self.max_distance).clamp(min=0)
decay = decay.unsqueeze(0).unsqueeze(-1) # [1, seq_len, 1]
r = r * (1 - mask).unsqueeze(-1) + r * mask * decay
return r注意力中的使用方式:
def shaw_attention(query, key, value, r, d_k):
"""
Shaw相对位置感知的注意力计算
相对位置嵌入通过分解因子与注意力分数交互
"""
# 标准注意力分数
score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 相对位置因子
# query: [B, H, N, D], r: [N, N, num_units]
# 使用二次型形式: query * R * key
Qr = torch.einsum('bhnd,ndp->bhndp', query, r) # [B, H, N, N, num_units]
Qr = Qr.sum(dim=-1) # [B, H, N, N]
# 组合
score = score + Qr / math.sqrt(d_k)
attn = torch.softmax(score, dim=-1)
return torch.matmul(attn, value)3.3 XLNet的循环相对位置编码
XLNet提出循环相对位置编码,将相对位置与循环机制结合,支持建模超长距离依赖。5
核心思想:将位置表示分解为内容相关和位置相关的偏置:
实现代码:
import torch
import torch.nn as nn
import math
class XLNetRelativePosition(nn.Module):
"""
XLNet循环相对位置编码
特点:
- 分离内容和位置贡献
- 支持循环机制建模长距离
- 可处理超长序列
"""
def __init__(self, d_model, d_head, max_position=512, bidirectional=True):
super().__init__()
self.d_head = d_head
self.max_position = max_position
self.bidirectional = bidirectional
# 内容-位置偏置
self.r = nn.Parameter(torch.randn(max_position, d_head))
nn.init.normal_(self.r, std=0.02)
# 位置-内容偏置(可学习)
self.r_bias = nn.Parameter(torch.zeros(max_position, d_head))
def _get_positional_bias(self, seq_len):
"""生成相对位置矩阵"""
# 生成位置索引
position_ids = torch.arange(seq_len, device=self.r.device)
if not self.bidirectional:
position_i = torch.arange(seq_len, device=self.r.device).unsqueeze(1)
position_j = torch.arange(seq_len, device=self.r.device).unsqueeze(0)
relative_position = position_j - position_i # 只考虑向后看
else:
relative_position = position_ids.unsqueeze(1) - position_ids.unsqueeze(0)
# 截断
relative_position = torch.clamp(relative_position, 0, self.max_position - 1)
return relative_position
def forward(self, query, key, content_bias=None):
"""
Args:
query: [batch_size, num_heads, seq_len, d_head] (content query)
key: [batch_size, num_heads, seq_len, d_head] (content key)
content_bias: 额外的content偏置
Returns:
注意力分数的偏置部分
"""
seq_len = query.size(2)
rel_pos = self._get_positional_bias(seq_len)
# 位置-内容偏置
# r_bias: [max_position, d_head]
# r: [max_position, d_head]
r_bias = self.r_bias[:seq_len] # [seq_len, d_head]
r_bias = r_bias.unsqueeze(1) # [1, 1, seq_len, d_head]
# 内容-位置偏置
r = self.r[rel_pos] # [seq_len, seq_len, d_head]
r = r.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len, d_head]
# 内容-内容 + 内容-位置 + 位置-内容 + 位置-位置
# 简化为:r_bias (位置-内容) + r (内容-位置)
bias = torch.einsum('bhqd,nd->bhqn', query, r.squeeze(0).squeeze(0))
bias = bias + torch.einsum('bhld,nd->bhln', key, r_bias.squeeze(0).squeeze(0))
return bias
class XLNetRelativeAttention(nn.Module):
"""
XLNet相对注意力机制
"""
def __init__(self, d_model, num_heads, max_position=512):
super().__init__()
self.d_head = d_model // num_heads
self.num_heads = num_heads
self.q = nn.Linear(d_model, d_model)
self.k = nn.Linear(d_model, d_model)
self.v = nn.Linear(d_model, d_model)
self.o = nn.Linear(d_model, d_model)
self.rel_pos = XLNetRelativePosition(d_model, self.d_head, max_position)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# QKV投影
q = self.q(x).view(batch_size, seq_len, self.num_heads, self.d_head)
k = self.k(x).view(batch_size, seq_len, self.num_heads, self.d_head)
v = self.v(x).view(batch_size, seq_len, self.num_heads, self.d_head)
# 调整维度
q = q.transpose(1, 2) # [B, H, N, D]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# 内容-内容注意力
score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
# 相对位置偏置
rel_bias = self.rel_pos(q, k)
score = score + rel_bias
if mask is not None:
score = score.masked_fill(mask == 0, -1e9)
attn = torch.softmax(score, dim=-1)
output = torch.matmul(attn, v)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return self.o(output)XLNet循环机制:XLNet的关键创新在于使用循环来实现对长序列的建模:
def xlnet_layer_forward(x, prev_kv, memory_length=0):
"""
XLNet层的循环前向传播
利用前一时刻的KV状态实现跨segment的信息传递
"""
seq_len = x.size(1)
# 如果有memory,拼接
if prev_kv is not None:
x = torch.cat([prev_kv, x], dim=1)
# 当前层的计算
output = relative_attention(x, x, x)
# 保存KV用于下一时刻
new_kv = output[:, -seq_len:]
return output[:, :seq_len], new_kv4. RoPE旋转位置编码
RoPE(Rotary Position Embedding,旋转位置编码)由Su等人提出,通过旋转矩阵编码位置信息,是当前LLM最广泛使用的位置编码方案。6
4.1 旋转矩阵的几何意义
RoPE的核心思想是对query和key向量施加旋转变换,使其内积隐式编码相对位置。
二维旋转矩阵:
几何解释:向量 左乘旋转矩阵 等价于将向量逆时针旋转 角度,而不改变其长度:
4.2 数学推导:如何实现相对位置
4.2.1 目标
我们希望找到一种编码方式,使得旋转后的query和key的内积只依赖于相对位置 :
4.2.2 二维情况的推导
设 ,应用旋转:
类似地:
计算内积:
关键性质:旋转矩阵的正交性
因此:
这表明内积仅依赖相对位置 。
4.2.3 高维扩展
对于 维向量,将维度两两分组,每组独立应用二维旋转:
其中旋转角度为:
4.3 实现细节与代码
import torch
import torch.nn as nn
import math
class RotaryPositionEmbedding(nn.Module):
"""
旋转位置编码(RoPE)
通过旋转矩阵编码位置,使注意力内积隐式包含相对位置信息
特点:
- 无额外参数
- 支持任意长度外推(理论上)
- 可与GQA等注意力变体兼容
- 训练和推理开销小
"""
def __init__(self, dim, base=10000):
super().__init__()
self.dim = dim
self.base = base
# 逆频率:1 / base^{2i/dim}
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, max_seq_len, device=None):
"""
预计算所有位置的旋转角
Args:
max_seq_len: 最大序列长度
device: 设备
Returns:
cos: 旋转角的余弦值
sin: 旋转角的正弦值
"""
if device is None:
device = self.inv_freq.device
# 位置索引:[0, 1, 2, ..., max_seq_len-1]
seq = torch.arange(max_seq_len, device=device)
# 外积得到每对(位置, 维度)的频率
# seq: [max_seq_len], inv_freq: [dim/2]
# freqs: [max_seq_len, dim/2]
freqs = torch.outer(seq, self.inv_freq)
# 拼接余弦和正弦:[max_seq_len, dim]
emb = torch.cat([freqs, freqs], dim=-1)
return emb.cos(), emb.sin()
@staticmethod
def rotate_half(x):
"""
旋转半个维度
将x的后半部分取负并前置
等价于:[-x_{d/2:}, x_{:d/2}]
"""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat([-x2, x1], dim=-1)
def apply_rotary(self, q, k, cos, sin, position_ids=None):
"""
应用旋转到query和key
Args:
q: [batch_size, num_heads, seq_len, head_dim]
k: [batch_size, num_heads, seq_len, head_dim]
cos: [seq_len, head_dim] 或 [batch_size, seq_len, head_dim]
sin: [seq_len, head_dim] 或 [batch_size, seq_len, head_dim]
position_ids: 位置索引(用于非顺序输入)
Returns:
旋转后的q和k
"""
# 确保cos/sin维度正确
if cos.dim() == 2:
# [seq_len, dim] -> [1, 1, seq_len, dim]
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
# 应用旋转公式:x' = x * cos + rotate_half(x) * sin
q_embed = q * cos + self.rotate_half(q) * sin
k_embed = k * cos + self.rotate_half(k) * sin
return q_embed, k_embed
class RoPEMultiHeadAttention(nn.Module):
"""
带RoPE的多头注意力
"""
def __init__(self, d_model, num_heads, base=10000):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.rope = RotaryPositionEmbedding(self.d_k, base)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# QKV投影
q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
k = self.w_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
v = self.w_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# 应用RoPE
cos, sin = self.rope(seq_len, device=x.device)
q, k = self.rope.apply_rotary(q, k, cos, sin)
# 注意力计算
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.w_o(output)4.4 2D RoPE vs 高维RoPE
4.4.1 2D RoPE(复数形式)
2D RoPE使用复数乘法实现旋转:
这等价于将query视为复数,与旋转因子的乘积。
4.4.2 高维RoPE
对于 维RoPE,按对角块分组:
维度分组策略:
| 分组方式 | 描述 | 适用场景 |
|---|---|---|
| 相邻配对 | 标准RoPE | |
| 间隔配对 | LLaMA采用 | |
| 全连接 | 所有维度互相作用 | 计算开销大 |
LLaMA的RoPE实现:
class LlamaRotaryEmbedding(nn.Module):
"""
LLaMA风格的RoPE
使用间隔配对策略,与GPT-NeoX相同
"""
def __init__(self, dim, max_position=2048, base=10000):
super().__init__()
self.dim = dim
self.base = base
self.max_position = max_position
# 计算逆频率
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# 预计算(可选)
self._compute_cos_sin(max_position)
def _compute_cos_sin(self, seq_len):
"""预计算所有位置的cos和sin"""
position = torch.arange(seq_len)
freqs = torch.outer(position, self.inv_freq)
self.register_buffer('cos_cached', freqs.cos())
self.register_buffer('sin_cached', freqs.sin())
def forward(self, x, position_ids=None):
"""
Args:
x: [batch_size, num_heads, seq_len, head_dim]
"""
if position_ids is None:
seq_len = x.shape[2]
position_ids = torch.arange(seq_len, device=x.device)
# 获取频率
freqs = torch.outer(position_ids, self.inv_freq)
# 拼接
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos, sin5. ALiBi线性偏置
ALiBi(Attention with Linear Biases)由Press等人提出,是一种无需学习的自然外推位置编码方案。7
5.1 无需学习的方案
ALiBi的核心设计原则是零额外参数,通过简单的线性函数将位置信息注入注意力分数。
基础公式:
其中 是衰减系数,控制位置距离对注意力分数的影响。
5.2 数学推导:线性衰减
5.2.1 多头ALiBi
对于多头注意力,每个头可有独立的衰减系数:
衰减系数设置(论文推荐):
这种设计使得不同头捕获不同尺度的位置关系:低索引头对位置更敏感,高索引头对位置更迟钝。
5.2.2 衰减特性分析
| 特性 | 分析 |
|---|---|
| 局部注意力 | 短距离token获得较高注意力分数 |
| 远程衰减 | 距离越大,注意力分数越低 |
| 线性衰减 | 衰减速度快于指数衰减 |
| 方向无关 | 使用绝对距离,忽略方向 |
5.3 与RoPE的对比分析
| 维度 | RoPE | ALiBi |
|---|---|---|
| 编码方式 | 旋转Q/K向量 | 注意力分数加偏置 |
| 参数量 | 极小(仅频率参数) | 零参数 |
| 实现复杂度 | 中等 | 极简 |
| 相对位置 | 隐式编码 | 显式编码 |
| 外推能力 | 需NTK-Scaling | 天然外推 |
| 与注意力融合 | 旋转操作 | 加法偏置 |
| GQA兼容性 | 原生支持 | 需修改偏置生成 |
| Flash Attention | 可集成 | 可集成 |
外推能力对比:
| 场景 | RoPE(标准) | RoPE+NTK | ALiBi |
|---|---|---|---|
| 训练: 1K → 测试: 2K | 良好 | 优秀 | 优秀 |
| 训练: 1K → 测试: 8K | 退化 | 良好 | 良好 |
| 训练: 4K → 测试: 32K | 严重退化 | 可行 | 可行 |
代码实现对比:
# RoPE注意力
class RoPEAttention(nn.Module):
def forward(self, q, k, v):
# 1. QKV投影
q, k = self.project_q(q), self.project_k(k)
# 2. 应用RoPE旋转
q = self.rope.rotate(q)
k = self.rope.rotate(k)
# 3. 注意力计算
return self.attention(q, k, v)
# ALiBi注意力
class ALiBiAttention(nn.Module):
def __init__(self, num_heads):
super().__init__()
# 预计算衰减斜率
slopes = torch.pow(2, -8.0 / torch.arange(1, num_heads + 1))
self.register_buffer('slopes', slopes)
def forward(self, q, k, v):
# 1. QKV投影
q, k, v = self.project(q), self.project(k), self.project(v)
# 2. 计算注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
# 3. 添加ALiBi偏置
scores = scores + self._get_alibi_bias(q.size(2))
# 4. Softmax
return self.attention(scores, v)
def _get_alibi_bias(self, seq_len):
# 生成距离矩阵
positions = torch.arange(seq_len)
distance = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
# 应用斜率衰减
return -distance.unsqueeze(0) * self.slopes.view(-1, 1, 1)6. 位置编码外推技术
位置编码外推(Position Encoding Extrapolation)指模型在训练时见过长度为 的序列,但需要在长度为 的序列上进行推理。
6.1 位置编码外推的挑战
6.1.1 外推失败的本质
以可学习位置编码为例:
- 模型在 范围内学习位置嵌入
- 外推到 时,无法获得对应嵌入
- 即使使用Sinusoidal或RoPE,远距离位置也可能因频率混淆导致性能下降
6.1.2 外推问题的数学刻画
问题:设 为位置编码函数, 为从位置编码到输出的映射。若:
但当 时,模型无法正确推广。
6.1.3 失败症状
| 症状 | 描述 |
|---|---|
| 注意力分散 | 远距离token注意力异常增高或降低 |
| 困惑度激增 | 长序列的困惑度显著上升 |
| 重复生成 | 模型陷入循环生成 |
| 位置混淆 | 模型无法区分不同远距离位置 |
6.2 插值方法:NTK-aware scaling
NTK-aware scaling是一种针对RoPE的插值策略,由Ilyas等人提出。8
6.2.1 核心思想
NTK-aware scaling不直接插值位置编码,而是调整RoPE的频率基底,保持高频不变,仅压缩低频。
6.2.2 数学推导
问题:标准RoPE在高维度使用高频(短周期),低维度使用低频(长周期)。直接插值会破坏高频维度。
解决思路:调整基参数 ,使得:
其中 是缩放因子。
NTK-aware公式:
其中 是上下文扩展倍数, 是原始基参数。
def ntk_scaled_rope(base, dim, scale):
"""
NTK-aware RoPE缩放
Args:
base: 原始基参数(通常10000)
dim: 模型维度
scale: 缩放因子(目标长度/原始长度)
"""
# 调整后的基参数
new_base = base * (scale ** (dim / (dim - 2)))
return new_base
def apply_ntk_scaling(model, scale):
"""
对模型应用NTK缩放
"""
for module in model.modules():
if hasattr(module, 'rope'):
old_base = module.rope.base
new_base = ntk_scaled_rope(old_base, module.rope.dim, scale)
module.rope.base = new_base
# 更新inv_freq
module.rope.inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))6.2.3 NTK-aware vs 简单插值
| 方法 | 实现 | 效果 |
|---|---|---|
| 位置插值(PI) | 直接缩放位置索引 | 破坏高频维度 |
| 频率插值(FI) | 缩放所有频率 | 高频信息丢失 |
| NTK-aware | 仅压缩低频 | 保留高频信息 |
6.3 插值方法:YaRN
YaRN(Yet another RoPE extensioN)是NTK-aware scaling的改进版本,进一步提升了外推能力。9
6.3.1 核心改进
YaRN在NTK-aware基础上引入两项改进:
- 温度因子:对注意力分数应用温度缩放
- 拉伸因子:对低频维度应用更激进的拉伸
6.3.2 数学公式
拉伸因子:
温度缩放:
class YaRNPositionEmbedding(nn.Module):
"""
YaRN旋转位置编码
结合NTK-aware scaling和温度缩放
"""
def __init__(self, dim, base=10000, max_position=2048, scale=1.0, extrapolation_factor=1.0):
super().__init__()
self.dim = dim
self.base = base
self.scale = scale
# YaRN特定的基参数调整
self.extrapolation_factor = extrapolation_factor
# 计算调整后的基参数
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def _yarn_invariant_freq(self, scale):
"""计算YaRN不变的频率调整"""
dim = self.dim
# 计算拉伸因子
floor_scale = torch.floor(scale)
mscale = 0.1 * floor_scale * math.sqrt(1 + 1 / (floor_scale ** 2)) - 1
# 温度因子
mscale = 1 + mscale * math.log(scale) / (math.log(scale) + 1)
return mscale
def forward(self, x):
seq_len = x.shape[1]
position = torch.arange(seq_len, device=x.device)
# 应用YaRN调整
if self.scale > 1.0:
mscale = self._yarn_invariant_freq(self.scale)
position = position * self.scale / mscale
freqs = torch.outer(position, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
return emb.cos(), emb.sin()6.3.3 实验效果
| 方法 | 困惑度变化 | 长度扩展倍数 |
|---|---|---|
| 标准RoPE | 急剧上升 | 1× |
| 位置插值 | 轻微上升 | 2-4× |
| NTK-aware | 稳定 | 2-8× |
| YaRN | 最优稳定 | 最高32× |
6.4 其他外推技术
6.4.1 逐步扩展(逐步微调)
方法:分阶段训练,逐步扩展上下文长度。
def progressive_training_schedule(model, target_len, steps):
"""
渐进式训练计划
逐步增加上下文长度并微调
"""
current_len = 512
scale_factor = (target_len / current_len) ** (1 / steps)
for step in range(steps):
# 扩展长度
current_len = int(current_len * scale_factor)
# 调整位置编码
adjust_position_encoding(model, current_len)
# 微调阶段
fine_tune(model, train_tokens=current_len * batch_size)6.4.2 位置丢弃(Position Dropout)
方法:训练时随机丢弃部分位置,增强模型对位置变化的鲁棒性。
class PositionDropout(nn.Module):
def __init__(self, p=0.1):
super().__init__()
self.p = p
def forward(self, x, position_ids=None):
if self.training and torch.rand(1) < self.p:
# 随机跳位
skip = torch.randint(1, 10, (1,)).item()
x = x[:, skip:]
if position_ids is not None:
position_ids = position_ids[:, skip:]
return x, position_ids6.4.3 随机位置编码
方法:训练时使用随机采样的位置索引,增强位置编码的泛化能力。
class RandomPositionalEncoding(nn.Module):
def forward(self, x):
seq_len = x.shape[1]
# 随机重排位置(仅训练时)
if self.training:
position_ids = torch.randperm(seq_len)
else:
position_ids = torch.arange(seq_len)
return x, position_ids7. 混合与最新进展
7.1 CoPE(Conditional Position Encoding)
CoPE由局外人等人提出,是一种根据内容动态计算位置的编码方案。10
7.1.1 核心思想
传统位置编码使用固定的整数位置,CoPE根据token的内容动态决定”位置”——实际上是基于门控值的累积。
7.1.2 数学推导
门控函数:
累积位置:
注意力分数:
其中 是关于累积位置的函数(可学习)。
7.1.3 实现代码
import torch
import torch.nn as nn
import math
class ConditionalPositionEncoding(nn.Module):
"""
条件位置编码(CoPE)
根据内容动态决定位置,而非使用固定整数位置
"""
def __init__(self, d_model, max_heads=32):
super().__init__()
self.max_heads = max_heads
# 门控网络
self.gate_proj = nn.Linear(d_model, 1)
# 位置函数(可学习)
self.pos_func = nn.Sequential(
nn.Linear(1, d_model),
nn.ReLU(),
nn.Linear(d_model, d_model)
)
def forward(self, q, k, v):
"""
Args:
q: [batch_size, num_heads, seq_len, d_k]
k: [batch_size, num_heads, seq_len, d_k]
v: [batch_size, num_heads, seq_len, d_k]
Returns:
带CoPE的注意力输出
"""
batch_size, num_heads, seq_len, d_k = q.shape
# 1. 计算门控值 g_{ij}
# 将q和k投影到统一空间
q_gate = q.view(batch_size, num_heads, seq_len, 1, d_k)
k_gate = k.view(batch_size, num_heads, 1, seq_len, d_k)
# 门控注意力
gate = torch.sigmoid(self.gate_proj(torch.tanh(q_gate + k_gate)))
# gate: [batch_size, num_heads, seq_len, seq_len, 1]
gate = gate.squeeze(-1) # [batch_size, num_heads, seq_len, seq_len]
# 2. 累积得到位置 p_{ij}
# 前向累积:p_{ij} = sum_t(g_{it} for t <= j)
positions = torch.cumsum(gate, dim=-1) # [batch_size, num_heads, seq_len, seq_len]
# 3. 标准化位置到[0, max_heads]
positions = positions / (positions[:, :, :, -1:] + 1e-6) * self.max_heads
positions = positions.clamp(0, self.max_heads)
# 4. 通过位置函数获取偏置
positions_flat = positions.view(batch_size, num_heads, -1, 1) # [B, H, N^2, 1]
pos_bias = self.pos_func(positions_flat) # [B, H, N^2, d_k]
# 5. 标准注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
# 6. 添加位置偏置
pos_bias = pos_bias.view(batch_size, num_heads, seq_len, seq_len)
scores = scores + pos_bias
# 7. 注意力计算
attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, v)
return output7.1.4 CoPE vs 其他编码
| 特性 | 绝对PE | 相对PE | RoPE | ALiBi | CoPE |
|---|---|---|---|---|---|
| 位置定义 | 固定整数 | 固定整数 | 固定整数 | 固定整数 | 动态 |
| 内容感知 | ❌ | ❌ | ❌ | ❌ | ✅ |
| 参数量 | 可变 | ||||
| 实现复杂度 | 低 | 中 | 中 | 低 | 高 |
7.2 位置编码的选择建议
根据不同应用场景,以下是位置编码的选择指南:
7.2.1 按任务类型选择
| 任务类型 | 推荐方案 | 理由 |
|---|---|---|
| 短文本分类 | 可学习PE | 参数足够,训练简单 |
| 标准语言建模 | RoPE | 成熟稳定,主流选择 |
| 代码生成 | RoPE + NTK | 需要长距离理解 |
| 对话系统 | ALiBi或RoPE | 变长对话需外推 |
| 多模态(视频/音频) | 相对PE或CoPE | 时间序列需灵活位置 |
| 超长上下文 | ALiBi + NTK | 极端长度需双重保障 |
7.2.2 按模型规模选择
| 模型规模 | 推荐方案 | 备注 |
|---|---|---|
| 小型(<1B) | ALiBi或可学习PE | 参数效率优先 |
| 中型(1B-70B) | RoPE | 平衡效果与效率 |
| 超大型(>70B) | RoPE + GQA + NTK | 高效推理+长上下文 |
7.2.3 按推理场景选择
| 推理场景 | 推荐方案 |
|---|---|
| 固定长度推理 | 任选 |
| 批处理变长 | RoPE或ALiBi |
| 流式推理 | RoPE |
| 超长序列 | ALiBi或NTK-Scaled RoPE |
7.3 未来研究方向
7.3.1 自适应位置编码
根据输入内容自动调整位置编码策略:
class AdaptivePositionalEncoding(nn.Module):
"""
自适应位置编码
根据序列特点选择最优编码方式
"""
def __init__(self, d_model):
super().__init__()
self.rope = RotaryPositionEmbedding(d_model)
self.alibi_slopes = nn.Parameter(torch.ones(8)) # 可学习的ALiBi斜率
self.selector = nn.Linear(d_model, 3) # 选择器
def forward(self, x, mode=None):
logits = self.selector(x.mean(dim=1)) # [batch_size, 3]
if mode == 'rope':
return self.rope(x)
elif mode == 'alibi':
return self._apply_alibi(x)
else:
# 软选择
probs = F.softmax(logits, dim=-1)
rope_out = self.rope(x)
alibi_out = self._apply_alibi(x)
return probs[:, 0:1] * rope_out + probs[:, 1:2] * alibi_out7.3.2 任务导向位置编码
针对特定任务(如检索、推理)设计专用位置编码:
class TaskAwarePositionalEncoding(nn.Module):
"""
任务感知位置编码
根据任务类型调整位置表示
"""
def __init__(self, d_model, num_tasks):
super().__init__()
self.task_embeddings = nn.Embedding(num_tasks, d_model)
self.rope = RotaryPositionEmbedding(d_model)
def forward(self, x, task_id):
# 获取任务嵌入
task_emb = self.task_embeddings(task_id)
# 任务感知的RoPE
modified_freq = self.rope.inv_freq * (1 + task_emb)
# ... 应用修改后的旋转7.3.3 可组合位置编码
探索不同位置编码方案的组合优势:
| 组合方案 | 预期优势 |
|---|---|
| RoPE + ALiBi | 精确相对位置 + 天然外推 |
| RoPE + CoPE | 旋转不变性 + 内容感知 |
| ALiBi + 层次PE | 线性衰减 + 多尺度 |
class HybridPositionalEncoding(nn.Module):
"""
混合位置编码
结合RoPE和ALiBi的优势
"""
def __init__(self, d_model, num_heads):
super().__init__()
self.rope = RotaryPositionEmbedding(d_model)
self.alibi_slopes = self._get_alibi_slopes(num_heads)
def _get_alibi_slopes(self, num_heads):
slopes = torch.pow(2, -8.0 / torch.arange(1, num_heads + 1))
return nn.Parameter(slopes)
def forward(self, q, k, v):
# 1. 应用RoPE
q, k = self.rope.rotate(q), self.rope.rotate(k)
# 2. 计算ALiBi偏置
seq_len = q.shape[2]
positions = torch.arange(seq_len, device=q.device)
distance = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
alibi_bias = -distance.unsqueeze(0) * self.alibi_slopes.view(-1, 1, 1)
# 3. 注意力计算
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
scores = scores + alibi_bias
return torch.matmul(torch.softmax(scores, -1), v)8. 参考资料
相关词条:位置编码的几何理论,ALiBi位置编码深度解析,Transformer与注意力机制,稀疏注意力与长度外推,LLM架构族对比,FlashAttention深度解析
Footnotes
-
Vaswani et al., “Attention Is All You Need”, NeurIPS 2017 ↩
-
Devlin et al., “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding”, NAACL 2019 ↩
-
Raffel et al., “Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer”, JMLR 2020 ↩
-
Shaw et al., “Self-Attention with Relative Position Representations”, NAACL 2018 ↩
-
Yang et al., “XLNet: Generalized Autoregressive Pretraining for Language Understanding”, NeurIPS 2019 ↩
-
Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding”, arXiv:2104.09864, 2021 ↩
-
Press et al., “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation”, ICLR 2022 ↩
-
Ilyas et al., “Rotary Position Embeddings for Transformer”, arXiv, 2023 ↩
-
Peng et al., “YaRN: Efficient Context Window Extension of LLMs”, arXiv, 2023 ↩
-
局外人等, “Conditional Positional Encoding”, ICLR 2022 ↩