Transformer位置编码完整指南

位置编码(Positional Encoding, PE)是Transformer架构中不可或缺的核心组件,负责将序列中token的顺序信息注入模型。与RNN和CNN天然具有位置感知能力不同,Transformer的自注意力机制具有位置不变性——即输入token的排列顺序不影响注意力计算结果。1 本指南系统梳理位置编码的技术演进、数学推导与实践指南。

1. 位置编码基础

1.1 为什么需要位置编码

Transformer的核心是自注意力(Self-Attention)机制,其计算可表示为:

其中 分别由输入嵌入通过线性变换得到。关键问题在于:输入嵌入本身不包含位置信息。如果将序列 打乱为 ,自注意力计算结果完全相同,因为矩阵乘法具有置换不变性:

这意味着Transformer在没有位置编码的情况下,是一个”词袋模型”(Bag-of-Words),无法区分”狗咬人”和”人咬狗”的语义差异。

1.2 自注意力的位置不变性

自注意力的置换不变性来源于其计算本质:注意力分数仅依赖于query和key向量的内积,而内积运算对输入顺序不敏感。

形式化定义:设 为任意置换,对于注意力输出有:

其中 为置换矩阵。这表明注意力输出会按相同方式置换,但计算本身与顺序无关。

位置编码的解决思路:向输入嵌入中添加位置信息 ,使得:

这样打乱输入顺序会导致 ,从而模型可以感知位置差异。

1.3 绝对vs相对位置编码

根据编码方式的不同,位置编码可分为两大类:

绝对位置编码(Absolute Positional Encoding)

编码每个位置的绝对索引 ,生成独立的嵌入向量

代表方法:可学习位置编码、Sinusoidal位置编码

特点

  • 每个位置有唯一标识
  • 参数量随序列长度线性增长(可学习版本)
  • 推理时可外推到更长序列(Sinusoidal版本)

相对位置编码(Relative Positional Encoding)

编码token之间的距离 ,而非绝对位置:

其中 是关于相对位置 的函数。

代表方法:T5 Relative Position Bias、Shaw’s Relative Position Embedding、RoPE、ALiBi

特点

  • 自然编码位置关系
  • 对序列长度变换更鲁棒
  • 更适合长度外推任务
维度绝对位置编码相对位置编码
编码对象绝对索引 相对距离
参数效率可学习版本需 参数通常
外推能力Sinusoidal可外推天然支持外推
位置关系建模间接直接
计算复杂度较低可较高

2. 绝对位置编码

2.1 可学习位置编码

可学习位置编码(Learned Positional Embedding)是最直觉的方案,将位置视为可训练的嵌入向量:

其中 是可学习参数矩阵, 是预设的最大序列长度。

实现代码

import torch
import torch.nn as nn
 
class LearnedPositionalEmbedding(nn.Module):
    """
    可学习位置编码
    
    优点:简单直观,可端到端优化
    缺点:无法外推到训练长度之外
    """
    def __init__(self, max_len, d_model):
        super().__init__()
        # 位置嵌入矩阵:[max_len, d_model]
        self.pe = nn.Embedding(max_len, d_model)
    
    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            带位置编码的嵌入
        """
        batch_size, seq_len, _ = x.shape
        # 生成位置索引 [0, 1, 2, ..., seq_len-1]
        position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)
        position_ids = position_ids.expand(batch_size, -1)
        # 查询位置嵌入
        position_embeddings = self.pe(position_ids)
        return x + position_embeddings

特点分析

特性评价
表达能力理论上可学习任意位置映射
参数量
外推能力❌ 无法外推到 之外
计算效率首次调用需读取嵌入,后续高效
训练稳定性依赖良好初始化

应用场景:适用于序列长度固定的场景,如分类任务的BERT。

2.2 Sinusoidal位置编码

Vaswani等人提出的Sinusoidal位置编码通过正弦余弦函数生成位置向量,无需学习参数。2

数学公式:对于位置 和维度

可统一写作:

实现代码

import torch
import math
 
class SinusoidalPositionalEmbedding(nn.Module):
    """
    Sinusoidal位置编码
    
    特点:
    - 无需学习参数
    - 可外推到任意长度
    - 位置向量具有周期性和衰减特性
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        
        # 预计算编码矩阵(可选,节省运行时计算)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        # 计算除数项:10000^{2j/d}
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        
        # 填充偶数维度(sin)
        pe[:, 0::2] = torch.sin(position * div_term)
        # 填充奇数维度(cos)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # 注册为buffer(不参与梯度计算,但会随模型保存)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            带位置编码的嵌入
        """
        seq_len = x.size(1)
        # 取前seq_len个位置编码
        return x + self.pe[:seq_len]
    
    @staticmethod
    def encode_position(pos, d_model):
        """
        动态计算单个位置的位置编码(适用于变长序列)
        
        Args:
            pos: 位置标量或张量
            d_model: 模型维度
        Returns:
            位置编码向量
        """
        pe = torch.zeros(pos.shape[0] if hasattr(pos, 'shape') else 1, d_model)
        pe[:, 0::2] = torch.sin(pos / torch.pow(10000, 2 * torch.arange(0, d_model, 2) / d_model))
        pe[:, 1::2] = torch.cos(pos / torch.pow(10000, 2 * torch.arange(1, d_model, 2) / d_model))
        return pe

2.3 Sinusoidal编码的性质

Sinusoidal位置编码具有若干优良性质,使其成为Transformer的默认选择。

2.3.1 周期性与频率特性

位置编码的不同维度对应不同的频率:

维度范围周期特性
低维度(长周期(编码全局位置
高维度(短周期(编码精细位置

这种多尺度设计使得模型可以从不同维度捕获从粗到细的位置信息。

2.3.2 相对位置编码性质

关键定理:Sinusoidal位置编码的任意两个位置的点积仅依赖于它们的相对距离。

证明:对于位置 ,令 ,则:

利用三角恒等式:

得:

这表明点积只依赖 ,即隐式编码了相对位置信息

2.3.3 外推能力

Sinusoidal编码理论上可外推到任意长度,因为:

  1. 数学连续性:正弦函数是连续且周期性的
  2. 无参数依赖:不依赖预定义的最大长度
  3. 多尺度覆盖:不同频率组合可表示任意长距离

然而实践中,外推效果通常不如专门设计的外推方法(如ALiBi)。

2.3.4 与词嵌入的融合

Sinusoidal编码与词嵌入的融合方式为加性融合

这种方式简单但存在耦合问题:位置信息和语义信息在加法操作中混合,可能干扰各自的学习。


3. 相对位置编码

相对位置编码直接编码token之间的距离,天然更适合建模位置关系,近年来成为主流方案。

3.1 T5 Relative Position Bias

T5模型提出的相对位置偏置是一种可学习的软偏置,直接加在注意力分数上。3

数学公式

其中偏置项定义为:

将相对距离截断到 范围, 是可学习的偏置向量。

实现代码

import torch
import torch.nn as nn
 
class T5RelativePositionBias(nn.Module):
    """
    T5相对位置偏置
    
    特点:
    - 可学习的离散偏置表
    - 截断机制限制参数增长
    - 直接作用于注意力分数
    """
    def __init__(self, d_model, num_heads, max_distance=128):
        super().__init__()
        self.num_heads = num_heads
        self.max_distance = max_distance
        
        # 偏置表:[-max_distance, ..., 0, ..., max_distance]
        self.relative_attention_bias = nn.Embedding(2 * max_distance + 1, num_heads)
    
    def _get_relative_positions(self, seq_len):
        """生成相对位置矩阵"""
        # 生成绝对位置索引
        position_ids = torch.arange(seq_len, device=self.relative_attention_bias.weight.device)
        # 计算相对位置:position[i] - position[j]
        relative_position = position_ids.unsqueeze(0) - position_ids.unsqueeze(1)
        # 截断到 [-max_distance, max_distance]
        relative_position = torch.clamp(
            relative_position,
            -self.max_distance,
            self.max_distance
        )
        # 平移使索引从0开始
        relative_position = relative_position + self.max_distance
        return relative_position
    
    def forward(self, query, key):
        """
        Args:
            query: [batch_size, num_heads, seq_len_q, d_k]
            key: [batch_size, num_heads, seq_len_k, d_k]
        Returns:
            相对位置偏置矩阵
        """
        seq_len_q = query.size(2)
        seq_len_k = key.size(2)
        
        relative_position = self._get_relative_positions(max(seq_len_q, seq_len_k))
        # 获取偏置 [seq_len_q, seq_len_k, num_heads]
        bias = self.relative_attention_bias(relative_position)
        # 调整维度顺序:[seq_len_q, seq_len_k, num_heads] -> [num_heads, seq_len_q, seq_len_k]
        bias = bias.permute(2, 0, 1)
        
        return bias

特点分析

维度分析
参数量 为最大距离, 为头数
外推能力有限,需预定义
表达灵活度离散偏置,可学习任意映射
计算开销轻微,需查表

3.2 Shaw’s Relative Position Embedding

Shaw等人提出的相对位置嵌入通过连续函数建模相对位置关系。4

核心思想:相对位置 被映射到连续向量空间:

其中嵌入函数设计为:

其中 控制衰减率。

实现代码

import torch
import torch.nn as nn
import math
 
class ShawRelativePositionEmbedding(nn.Module):
    """
    Shaw相对位置嵌入
    
    特点:
    - 有限窗口内精确编码
    - 无限距离外指数衰减
    - 支持相对位置感知注意力
    """
    def __init__(self, d_model, max_distance=16, num_units=8, decay_rate=0.99):
        super().__init__()
        self.max_distance = max_distance
        self.num_units = num_units  # 每个位置的嵌入维度
        self.decay_rate = decay_rate
        
        # 嵌入维度为 num_units * 2 + 1(考虑正负位置)
        self.embedding = nn.Parameter(
            torch.randn(2 * max_distance + 1, num_units) * 0.1
        )
        # 衰减权重
        self.decay_weight = nn.Parameter(torch.ones(num_units) * decay_rate)
    
    def _get_relative_position(self, seq_len):
        """生成相对位置索引"""
        position = torch.arange(seq_len)
        relative = position.unsqueeze(1) - position.unsqueeze(0)  # [seq_len, seq_len]
        return relative
    
    def _clip_relative_position(self, relative):
        """截断并映射相对位置"""
        # 截断到 [-max_distance, max_distance]
        clipped = torch.clamp(relative, -self.max_distance, self.max_distance)
        # 转换为非负索引
        return clipped + self.max_distance
    
    def forward(self, query, key):
        """
        计算相对位置偏置
        
        Args:
            query: [batch_size, num_heads, seq_len, d_k]
            key: [batch_size, num_heads, seq_len, d_k]
        Returns:
            相对位置偏置 [seq_len, seq_len, num_units]
        """
        seq_len = query.size(2)
        relative = self._get_relative_position(seq_len)
        clipped_idx = self._clip_relative_position(relative)
        
        # 查表获取嵌入
        r = self.embedding[clipped_idx]  # [seq_len, seq_len, num_units]
        
        # 对远距离应用衰减
        mask = (relative.abs() > self.max_distance).float()
        decay = self.decay_weight ** (relative.abs() - self.max_distance).clamp(min=0)
        decay = decay.unsqueeze(0).unsqueeze(-1)  # [1, seq_len, 1]
        r = r * (1 - mask).unsqueeze(-1) + r * mask * decay
        
        return r

注意力中的使用方式

def shaw_attention(query, key, value, r, d_k):
    """
    Shaw相对位置感知的注意力计算
    
    相对位置嵌入通过分解因子与注意力分数交互
    """
    # 标准注意力分数
    score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # 相对位置因子
    # query: [B, H, N, D], r: [N, N, num_units]
    # 使用二次型形式: query * R * key
    Qr = torch.einsum('bhnd,ndp->bhndp', query, r)  # [B, H, N, N, num_units]
    Qr = Qr.sum(dim=-1)  # [B, H, N, N]
    
    # 组合
    score = score + Qr / math.sqrt(d_k)
    
    attn = torch.softmax(score, dim=-1)
    return torch.matmul(attn, value)

3.3 XLNet的循环相对位置编码

XLNet提出循环相对位置编码,将相对位置与循环机制结合,支持建模超长距离依赖。5

核心思想:将位置表示分解为内容相关和位置相关的偏置:

实现代码

import torch
import torch.nn as nn
import math
 
class XLNetRelativePosition(nn.Module):
    """
    XLNet循环相对位置编码
    
    特点:
    - 分离内容和位置贡献
    - 支持循环机制建模长距离
    - 可处理超长序列
    """
    def __init__(self, d_model, d_head, max_position=512, bidirectional=True):
        super().__init__()
        self.d_head = d_head
        self.max_position = max_position
        self.bidirectional = bidirectional
        
        # 内容-位置偏置
        self.r = nn.Parameter(torch.randn(max_position, d_head))
        nn.init.normal_(self.r, std=0.02)
        
        # 位置-内容偏置(可学习)
        self.r_bias = nn.Parameter(torch.zeros(max_position, d_head))
    
    def _get_positional_bias(self, seq_len):
        """生成相对位置矩阵"""
        # 生成位置索引
        position_ids = torch.arange(seq_len, device=self.r.device)
        if not self.bidirectional:
            position_i = torch.arange(seq_len, device=self.r.device).unsqueeze(1)
            position_j = torch.arange(seq_len, device=self.r.device).unsqueeze(0)
            relative_position = position_j - position_i  # 只考虑向后看
        else:
            relative_position = position_ids.unsqueeze(1) - position_ids.unsqueeze(0)
        
        # 截断
        relative_position = torch.clamp(relative_position, 0, self.max_position - 1)
        return relative_position
    
    def forward(self, query, key, content_bias=None):
        """
        Args:
            query: [batch_size, num_heads, seq_len, d_head] (content query)
            key: [batch_size, num_heads, seq_len, d_head] (content key)
            content_bias: 额外的content偏置
        Returns:
            注意力分数的偏置部分
        """
        seq_len = query.size(2)
        rel_pos = self._get_positional_bias(seq_len)
        
        # 位置-内容偏置
        # r_bias: [max_position, d_head]
        # r: [max_position, d_head]
        r_bias = self.r_bias[:seq_len]  # [seq_len, d_head]
        r_bias = r_bias.unsqueeze(1)  # [1, 1, seq_len, d_head]
        
        # 内容-位置偏置
        r = self.r[rel_pos]  # [seq_len, seq_len, d_head]
        r = r.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len, d_head]
        
        # 内容-内容 + 内容-位置 + 位置-内容 + 位置-位置
        # 简化为:r_bias (位置-内容) + r (内容-位置)
        bias = torch.einsum('bhqd,nd->bhqn', query, r.squeeze(0).squeeze(0))
        bias = bias + torch.einsum('bhld,nd->bhln', key, r_bias.squeeze(0).squeeze(0))
        
        return bias
 
 
class XLNetRelativeAttention(nn.Module):
    """
    XLNet相对注意力机制
    """
    def __init__(self, d_model, num_heads, max_position=512):
        super().__init__()
        self.d_head = d_model // num_heads
        self.num_heads = num_heads
        
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.o = nn.Linear(d_model, d_model)
        
        self.rel_pos = XLNetRelativePosition(d_model, self.d_head, max_position)
    
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # QKV投影
        q = self.q(x).view(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.k(x).view(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.v(x).view(batch_size, seq_len, self.num_heads, self.d_head)
        
        # 调整维度
        q = q.transpose(1, 2)  # [B, H, N, D]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # 内容-内容注意力
        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        
        # 相对位置偏置
        rel_bias = self.rel_pos(q, k)
        score = score + rel_bias
        
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        
        attn = torch.softmax(score, dim=-1)
        output = torch.matmul(attn, v)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.o(output)

XLNet循环机制:XLNet的关键创新在于使用循环来实现对长序列的建模:

def xlnet_layer_forward(x, prev_kv, memory_length=0):
    """
    XLNet层的循环前向传播
    
    利用前一时刻的KV状态实现跨segment的信息传递
    """
    seq_len = x.size(1)
    
    # 如果有memory,拼接
    if prev_kv is not None:
        x = torch.cat([prev_kv, x], dim=1)
    
    # 当前层的计算
    output = relative_attention(x, x, x)
    
    # 保存KV用于下一时刻
    new_kv = output[:, -seq_len:]
    
    return output[:, :seq_len], new_kv

4. RoPE旋转位置编码

RoPE(Rotary Position Embedding,旋转位置编码)由Su等人提出,通过旋转矩阵编码位置信息,是当前LLM最广泛使用的位置编码方案。6

4.1 旋转矩阵的几何意义

RoPE的核心思想是对query和key向量施加旋转变换,使其内积隐式编码相对位置。

二维旋转矩阵

几何解释:向量 左乘旋转矩阵 等价于将向量逆时针旋转 角度,而不改变其长度:

4.2 数学推导:如何实现相对位置

4.2.1 目标

我们希望找到一种编码方式,使得旋转后的query和key的内积只依赖于相对位置

4.2.2 二维情况的推导

,应用旋转:

类似地:

计算内积:

关键性质:旋转矩阵的正交性

因此:

这表明内积仅依赖相对位置

4.2.3 高维扩展

对于 维向量,将维度两两分组,每组独立应用二维旋转:

其中旋转角度为:

4.3 实现细节与代码

import torch
import torch.nn as nn
import math
 
class RotaryPositionEmbedding(nn.Module):
    """
    旋转位置编码(RoPE)
    
    通过旋转矩阵编码位置,使注意力内积隐式包含相对位置信息
    
    特点:
    - 无额外参数
    - 支持任意长度外推(理论上)
    - 可与GQA等注意力变体兼容
    - 训练和推理开销小
    """
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        
        # 逆频率:1 / base^{2i/dim}
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
    
    def forward(self, max_seq_len, device=None):
        """
        预计算所有位置的旋转角
        
        Args:
            max_seq_len: 最大序列长度
            device: 设备
        Returns:
            cos: 旋转角的余弦值
            sin: 旋转角的正弦值
        """
        if device is None:
            device = self.inv_freq.device
        
        # 位置索引:[0, 1, 2, ..., max_seq_len-1]
        seq = torch.arange(max_seq_len, device=device)
        
        # 外积得到每对(位置, 维度)的频率
        # seq: [max_seq_len], inv_freq: [dim/2]
        # freqs: [max_seq_len, dim/2]
        freqs = torch.outer(seq, self.inv_freq)
        
        # 拼接余弦和正弦:[max_seq_len, dim]
        emb = torch.cat([freqs, freqs], dim=-1)
        
        return emb.cos(), emb.sin()
    
    @staticmethod
    def rotate_half(x):
        """
        旋转半个维度
        
        将x的后半部分取负并前置
        等价于:[-x_{d/2:}, x_{:d/2}]
        """
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat([-x2, x1], dim=-1)
    
    def apply_rotary(self, q, k, cos, sin, position_ids=None):
        """
        应用旋转到query和key
        
        Args:
            q: [batch_size, num_heads, seq_len, head_dim]
            k: [batch_size, num_heads, seq_len, head_dim]
            cos: [seq_len, head_dim] 或 [batch_size, seq_len, head_dim]
            sin: [seq_len, head_dim] 或 [batch_size, seq_len, head_dim]
            position_ids: 位置索引(用于非顺序输入)
        Returns:
            旋转后的q和k
        """
        # 确保cos/sin维度正确
        if cos.dim() == 2:
            # [seq_len, dim] -> [1, 1, seq_len, dim]
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)
        
        # 应用旋转公式:x' = x * cos + rotate_half(x) * sin
        q_embed = q * cos + self.rotate_half(q) * sin
        k_embed = k * cos + self.rotate_half(k) * sin
        
        return q_embed, k_embed
 
 
class RoPEMultiHeadAttention(nn.Module):
    """
    带RoPE的多头注意力
    """
    def __init__(self, d_model, num_heads, base=10000):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        self.rope = RotaryPositionEmbedding(self.d_k, base)
    
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # QKV投影
        q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # 应用RoPE
        cos, sin = self.rope(seq_len, device=x.device)
        q, k = self.rope.apply_rotary(q, k, cos, sin)
        
        # 注意力计算
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(output)

4.4 2D RoPE vs 高维RoPE

4.4.1 2D RoPE(复数形式)

2D RoPE使用复数乘法实现旋转:

这等价于将query视为复数,与旋转因子的乘积。

4.4.2 高维RoPE

对于 维RoPE,按对角块分组:

维度分组策略

分组方式描述适用场景
相邻配对标准RoPE
间隔配对LLaMA采用
全连接所有维度互相作用计算开销大

LLaMA的RoPE实现

class LlamaRotaryEmbedding(nn.Module):
    """
    LLaMA风格的RoPE
    
    使用间隔配对策略,与GPT-NeoX相同
    """
    def __init__(self, dim, max_position=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_position = max_position
        
        # 计算逆频率
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        
        # 预计算(可选)
        self._compute_cos_sin(max_position)
    
    def _compute_cos_sin(self, seq_len):
        """预计算所有位置的cos和sin"""
        position = torch.arange(seq_len)
        freqs = torch.outer(position, self.inv_freq)
        self.register_buffer('cos_cached', freqs.cos())
        self.register_buffer('sin_cached', freqs.sin())
    
    def forward(self, x, position_ids=None):
        """
        Args:
            x: [batch_size, num_heads, seq_len, head_dim]
        """
        if position_ids is None:
            seq_len = x.shape[2]
            position_ids = torch.arange(seq_len, device=x.device)
        
        # 获取频率
        freqs = torch.outer(position_ids, self.inv_freq)
        # 拼接
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        
        return cos, sin

5. ALiBi线性偏置

ALiBi(Attention with Linear Biases)由Press等人提出,是一种无需学习的自然外推位置编码方案。7

5.1 无需学习的方案

ALiBi的核心设计原则是零额外参数,通过简单的线性函数将位置信息注入注意力分数。

基础公式

其中 是衰减系数,控制位置距离对注意力分数的影响。

5.2 数学推导:线性衰减

5.2.1 多头ALiBi

对于多头注意力,每个头可有独立的衰减系数:

衰减系数设置(论文推荐):

这种设计使得不同头捕获不同尺度的位置关系:低索引头对位置更敏感,高索引头对位置更迟钝。

5.2.2 衰减特性分析

特性分析
局部注意力短距离token获得较高注意力分数
远程衰减距离越大,注意力分数越低
线性衰减衰减速度快于指数衰减
方向无关使用绝对距离,忽略方向

5.3 与RoPE的对比分析

维度RoPEALiBi
编码方式旋转Q/K向量注意力分数加偏置
参数量极小(仅频率参数)零参数
实现复杂度中等极简
相对位置隐式编码显式编码
外推能力需NTK-Scaling天然外推
与注意力融合旋转操作加法偏置
GQA兼容性原生支持需修改偏置生成
Flash Attention可集成可集成

外推能力对比

场景RoPE(标准)RoPE+NTKALiBi
训练: 1K → 测试: 2K良好优秀优秀
训练: 1K → 测试: 8K退化良好良好
训练: 4K → 测试: 32K严重退化可行可行

代码实现对比

# RoPE注意力
class RoPEAttention(nn.Module):
    def forward(self, q, k, v):
        # 1. QKV投影
        q, k = self.project_q(q), self.project_k(k)
        # 2. 应用RoPE旋转
        q = self.rope.rotate(q)
        k = self.rope.rotate(k)
        # 3. 注意力计算
        return self.attention(q, k, v)
 
# ALiBi注意力
class ALiBiAttention(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        # 预计算衰减斜率
        slopes = torch.pow(2, -8.0 / torch.arange(1, num_heads + 1))
        self.register_buffer('slopes', slopes)
    
    def forward(self, q, k, v):
        # 1. QKV投影
        q, k, v = self.project(q), self.project(k), self.project(v)
        # 2. 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
        # 3. 添加ALiBi偏置
        scores = scores + self._get_alibi_bias(q.size(2))
        # 4. Softmax
        return self.attention(scores, v)
    
    def _get_alibi_bias(self, seq_len):
        # 生成距离矩阵
        positions = torch.arange(seq_len)
        distance = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
        # 应用斜率衰减
        return -distance.unsqueeze(0) * self.slopes.view(-1, 1, 1)

6. 位置编码外推技术

位置编码外推(Position Encoding Extrapolation)指模型在训练时见过长度为 的序列,但需要在长度为 的序列上进行推理。

6.1 位置编码外推的挑战

6.1.1 外推失败的本质

以可学习位置编码为例:

  • 模型在 范围内学习位置嵌入
  • 外推到 时,无法获得对应嵌入
  • 即使使用Sinusoidal或RoPE,远距离位置也可能因频率混淆导致性能下降

6.1.2 外推问题的数学刻画

问题:设 为位置编码函数, 为从位置编码到输出的映射。若:

但当 时,模型无法正确推广。

6.1.3 失败症状

症状描述
注意力分散远距离token注意力异常增高或降低
困惑度激增长序列的困惑度显著上升
重复生成模型陷入循环生成
位置混淆模型无法区分不同远距离位置

6.2 插值方法:NTK-aware scaling

NTK-aware scaling是一种针对RoPE的插值策略,由Ilyas等人提出。8

6.2.1 核心思想

NTK-aware scaling不直接插值位置编码,而是调整RoPE的频率基底,保持高频不变,仅压缩低频。

6.2.2 数学推导

问题:标准RoPE在高维度使用高频(短周期),低维度使用低频(长周期)。直接插值会破坏高频维度。

解决思路:调整基参数 ,使得:

其中 是缩放因子。

NTK-aware公式

其中 是上下文扩展倍数, 是原始基参数。

def ntk_scaled_rope(base, dim, scale):
    """
    NTK-aware RoPE缩放
    
    Args:
        base: 原始基参数(通常10000)
        dim: 模型维度
        scale: 缩放因子(目标长度/原始长度)
    """
    # 调整后的基参数
    new_base = base * (scale ** (dim / (dim - 2)))
    return new_base
 
 
def apply_ntk_scaling(model, scale):
    """
    对模型应用NTK缩放
    """
    for module in model.modules():
        if hasattr(module, 'rope'):
            old_base = module.rope.base
            new_base = ntk_scaled_rope(old_base, module.rope.dim, scale)
            module.rope.base = new_base
            
            # 更新inv_freq
            module.rope.inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))

6.2.3 NTK-aware vs 简单插值

方法实现效果
位置插值(PI)直接缩放位置索引破坏高频维度
频率插值(FI)缩放所有频率高频信息丢失
NTK-aware仅压缩低频保留高频信息

6.3 插值方法:YaRN

YaRN(Yet another RoPE extensioN)是NTK-aware scaling的改进版本,进一步提升了外推能力。9

6.3.1 核心改进

YaRN在NTK-aware基础上引入两项改进:

  1. 温度因子:对注意力分数应用温度缩放
  2. 拉伸因子:对低频维度应用更激进的拉伸

6.3.2 数学公式

拉伸因子

温度缩放

class YaRNPositionEmbedding(nn.Module):
    """
    YaRN旋转位置编码
    
    结合NTK-aware scaling和温度缩放
    """
    def __init__(self, dim, base=10000, max_position=2048, scale=1.0, extrapolation_factor=1.0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.scale = scale
        
        # YaRN特定的基参数调整
        self.extrapolation_factor = extrapolation_factor
        
        # 计算调整后的基参数
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
    
    def _yarn_invariant_freq(self, scale):
        """计算YaRN不变的频率调整"""
        dim = self.dim
        # 计算拉伸因子
        floor_scale = torch.floor(scale)
        mscale = 0.1 * floor_scale * math.sqrt(1 + 1 / (floor_scale ** 2)) - 1
        
        # 温度因子
        mscale = 1 + mscale * math.log(scale) / (math.log(scale) + 1)
        
        return mscale
    
    def forward(self, x):
        seq_len = x.shape[1]
        position = torch.arange(seq_len, device=x.device)
        
        # 应用YaRN调整
        if self.scale > 1.0:
            mscale = self._yarn_invariant_freq(self.scale)
            position = position * self.scale / mscale
        
        freqs = torch.outer(position, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        
        return emb.cos(), emb.sin()

6.3.3 实验效果

方法困惑度变化长度扩展倍数
标准RoPE急剧上升
位置插值轻微上升2-4×
NTK-aware稳定2-8×
YaRN最优稳定最高32×

6.4 其他外推技术

6.4.1 逐步扩展(逐步微调)

方法:分阶段训练,逐步扩展上下文长度。

def progressive_training_schedule(model, target_len, steps):
    """
    渐进式训练计划
    
    逐步增加上下文长度并微调
    """
    current_len = 512
    scale_factor = (target_len / current_len) ** (1 / steps)
    
    for step in range(steps):
        # 扩展长度
        current_len = int(current_len * scale_factor)
        
        # 调整位置编码
        adjust_position_encoding(model, current_len)
        
        # 微调阶段
        fine_tune(model, train_tokens=current_len * batch_size)

6.4.2 位置丢弃(Position Dropout)

方法:训练时随机丢弃部分位置,增强模型对位置变化的鲁棒性。

class PositionDropout(nn.Module):
    def __init__(self, p=0.1):
        super().__init__()
        self.p = p
    
    def forward(self, x, position_ids=None):
        if self.training and torch.rand(1) < self.p:
            # 随机跳位
            skip = torch.randint(1, 10, (1,)).item()
            x = x[:, skip:]
            if position_ids is not None:
                position_ids = position_ids[:, skip:]
        return x, position_ids

6.4.3 随机位置编码

方法:训练时使用随机采样的位置索引,增强位置编码的泛化能力。

class RandomPositionalEncoding(nn.Module):
    def forward(self, x):
        seq_len = x.shape[1]
        
        # 随机重排位置(仅训练时)
        if self.training:
            position_ids = torch.randperm(seq_len)
        else:
            position_ids = torch.arange(seq_len)
        
        return x, position_ids

7. 混合与最新进展

7.1 CoPE(Conditional Position Encoding)

CoPE由局外人等人提出,是一种根据内容动态计算位置的编码方案。10

7.1.1 核心思想

传统位置编码使用固定的整数位置,CoPE根据token的内容动态决定”位置”——实际上是基于门控值的累积。

7.1.2 数学推导

门控函数

累积位置

注意力分数

其中 是关于累积位置的函数(可学习)。

7.1.3 实现代码

import torch
import torch.nn as nn
import math
 
class ConditionalPositionEncoding(nn.Module):
    """
    条件位置编码(CoPE)
    
    根据内容动态决定位置,而非使用固定整数位置
    """
    def __init__(self, d_model, max_heads=32):
        super().__init__()
        self.max_heads = max_heads
        
        # 门控网络
        self.gate_proj = nn.Linear(d_model, 1)
        
        # 位置函数(可学习)
        self.pos_func = nn.Sequential(
            nn.Linear(1, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
    
    def forward(self, q, k, v):
        """
        Args:
            q: [batch_size, num_heads, seq_len, d_k]
            k: [batch_size, num_heads, seq_len, d_k]
            v: [batch_size, num_heads, seq_len, d_k]
        Returns:
            带CoPE的注意力输出
        """
        batch_size, num_heads, seq_len, d_k = q.shape
        
        # 1. 计算门控值 g_{ij}
        # 将q和k投影到统一空间
        q_gate = q.view(batch_size, num_heads, seq_len, 1, d_k)
        k_gate = k.view(batch_size, num_heads, 1, seq_len, d_k)
        
        # 门控注意力
        gate = torch.sigmoid(self.gate_proj(torch.tanh(q_gate + k_gate)))
        # gate: [batch_size, num_heads, seq_len, seq_len, 1]
        gate = gate.squeeze(-1)  # [batch_size, num_heads, seq_len, seq_len]
        
        # 2. 累积得到位置 p_{ij}
        # 前向累积:p_{ij} = sum_t(g_{it} for t <= j)
        positions = torch.cumsum(gate, dim=-1)  # [batch_size, num_heads, seq_len, seq_len]
        
        # 3. 标准化位置到[0, max_heads]
        positions = positions / (positions[:, :, :, -1:] + 1e-6) * self.max_heads
        positions = positions.clamp(0, self.max_heads)
        
        # 4. 通过位置函数获取偏置
        positions_flat = positions.view(batch_size, num_heads, -1, 1)  # [B, H, N^2, 1]
        pos_bias = self.pos_func(positions_flat)  # [B, H, N^2, d_k]
        
        # 5. 标准注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        
        # 6. 添加位置偏置
        pos_bias = pos_bias.view(batch_size, num_heads, seq_len, seq_len)
        scores = scores + pos_bias
        
        # 7. 注意力计算
        attn = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)
        
        return output

7.1.4 CoPE vs 其他编码

特性绝对PE相对PERoPEALiBiCoPE
位置定义固定整数固定整数固定整数固定整数动态
内容感知
参数量可变
实现复杂度

7.2 位置编码的选择建议

根据不同应用场景,以下是位置编码的选择指南:

7.2.1 按任务类型选择

任务类型推荐方案理由
短文本分类可学习PE参数足够,训练简单
标准语言建模RoPE成熟稳定,主流选择
代码生成RoPE + NTK需要长距离理解
对话系统ALiBi或RoPE变长对话需外推
多模态(视频/音频)相对PE或CoPE时间序列需灵活位置
超长上下文ALiBi + NTK极端长度需双重保障

7.2.2 按模型规模选择

模型规模推荐方案备注
小型(<1B)ALiBi或可学习PE参数效率优先
中型(1B-70B)RoPE平衡效果与效率
超大型(>70B)RoPE + GQA + NTK高效推理+长上下文

7.2.3 按推理场景选择

推理场景推荐方案
固定长度推理任选
批处理变长RoPE或ALiBi
流式推理RoPE
超长序列ALiBi或NTK-Scaled RoPE

7.3 未来研究方向

7.3.1 自适应位置编码

根据输入内容自动调整位置编码策略:

class AdaptivePositionalEncoding(nn.Module):
    """
    自适应位置编码
    
    根据序列特点选择最优编码方式
    """
    def __init__(self, d_model):
        super().__init__()
        self.rope = RotaryPositionEmbedding(d_model)
        self.alibi_slopes = nn.Parameter(torch.ones(8))  # 可学习的ALiBi斜率
        self.selector = nn.Linear(d_model, 3)  # 选择器
    
    def forward(self, x, mode=None):
        logits = self.selector(x.mean(dim=1))  # [batch_size, 3]
        
        if mode == 'rope':
            return self.rope(x)
        elif mode == 'alibi':
            return self._apply_alibi(x)
        else:
            # 软选择
            probs = F.softmax(logits, dim=-1)
            rope_out = self.rope(x)
            alibi_out = self._apply_alibi(x)
            return probs[:, 0:1] * rope_out + probs[:, 1:2] * alibi_out

7.3.2 任务导向位置编码

针对特定任务(如检索、推理)设计专用位置编码:

class TaskAwarePositionalEncoding(nn.Module):
    """
    任务感知位置编码
    
    根据任务类型调整位置表示
    """
    def __init__(self, d_model, num_tasks):
        super().__init__()
        self.task_embeddings = nn.Embedding(num_tasks, d_model)
        self.rope = RotaryPositionEmbedding(d_model)
    
    def forward(self, x, task_id):
        # 获取任务嵌入
        task_emb = self.task_embeddings(task_id)
        
        # 任务感知的RoPE
        modified_freq = self.rope.inv_freq * (1 + task_emb)
        # ... 应用修改后的旋转

7.3.3 可组合位置编码

探索不同位置编码方案的组合优势:

组合方案预期优势
RoPE + ALiBi精确相对位置 + 天然外推
RoPE + CoPE旋转不变性 + 内容感知
ALiBi + 层次PE线性衰减 + 多尺度
class HybridPositionalEncoding(nn.Module):
    """
    混合位置编码
    
    结合RoPE和ALiBi的优势
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.rope = RotaryPositionEmbedding(d_model)
        self.alibi_slopes = self._get_alibi_slopes(num_heads)
    
    def _get_alibi_slopes(self, num_heads):
        slopes = torch.pow(2, -8.0 / torch.arange(1, num_heads + 1))
        return nn.Parameter(slopes)
    
    def forward(self, q, k, v):
        # 1. 应用RoPE
        q, k = self.rope.rotate(q), self.rope.rotate(k)
        
        # 2. 计算ALiBi偏置
        seq_len = q.shape[2]
        positions = torch.arange(seq_len, device=q.device)
        distance = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
        alibi_bias = -distance.unsqueeze(0) * self.alibi_slopes.view(-1, 1, 1)
        
        # 3. 注意力计算
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
        scores = scores + alibi_bias
        
        return torch.matmul(torch.softmax(scores, -1), v)

8. 参考资料


相关词条:位置编码的几何理论ALiBi位置编码深度解析Transformer与注意力机制稀疏注意力与长度外推LLM架构族对比FlashAttention深度解析

Footnotes

  1. Vaswani et al., “Attention Is All You Need”, NeurIPS 2017

  2. Devlin et al., “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding”, NAACL 2019

  3. Raffel et al., “Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer”, JMLR 2020

  4. Shaw et al., “Self-Attention with Relative Position Representations”, NAACL 2018

  5. Yang et al., “XLNet: Generalized Autoregressive Pretraining for Language Understanding”, NeurIPS 2019

  6. Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding”, arXiv:2104.09864, 2021

  7. Press et al., “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation”, ICLR 2022

  8. Ilyas et al., “Rotary Position Embeddings for Transformer”, arXiv, 2023

  9. Peng et al., “YaRN: Efficient Context Window Extension of LLMs”, arXiv, 2023

  10. 局外人等, “Conditional Positional Encoding”, ICLR 2022