概述

ALiBi(Attention with Linear Biases)是一种创新的位置编码方法,通过在注意力分数上添加线性偏置来编码位置信息,无需显式的位置嵌入向量。1 ALiBi最初由Press等人于2021年提出,在训练效率和外推能力上展现出显著优势。

核心思想

与传统位置编码不同,ALiBi的核心创新在于:

  1. 无额外参数:不引入可学习的位置嵌入
  2. 线性偏置:直接在注意力分数上添加位置相关的偏置
  3. 天然外推:支持处理比训练更长的序列

设计动机

传统位置编码的问题

1. 可学习位置嵌入

问题:
- 需要预定义最大长度
- 无法外推到更长序列
- 增加参数量

2. Sinusoidal位置编码

问题:
- 与词嵌入难以融合
- 外推效果不理想
- 计算开销

3. RoPE(旋转位置编码)

优点:
- 相对位置编码
- 支持外推

限制:
- 需要修改Q/K向量
- 实现复杂度较高

ALiBi的设计目标

  1. 简单性:无需修改注意力计算的核心逻辑
  2. 高效性:零额外参数
  3. 外推性:自然支持长度外推
  4. 兼容性:可与FlashAttention等优化技术无缝结合

数学推导

基础公式

对于序列中的位置 ,ALiBi添加的偏置为:

带衰减的ALiBi

实际使用中,使用基于距离的衰减:

其中 是可学习的每头衰减系数。

多头注意力集成

对于多头注意力,每个头可以有独立的衰减系数:

其中 是第 个头的衰减斜率。

衰减系数设置

论文建议的衰减系数配置:

这种设置使不同头的衰减速率不同,有助于捕获多尺度的位置关系。


与RoPE的对比

编码方式对比

特性RoPEALiBi
编码位置旋转Q/K向量注意力分数加偏置
参数需求可选可学习零参数
实现复杂度中等极简
与词嵌入融合逐元素乘法加法
理论优雅性

数学性质对比

RoPE

ALiBi

外推能力对比

场景RoPEALiBi
训练: 1K → 测试: 2K良好优秀
训练: 1K → 测试: 8K需要NTK-Scaling良好
训练: 4K → 测试: 32K可行可行

适用场景选择

场景推荐
短序列,精度要求高ALiBi
长序列,需要精细控制RoPE + NTK-Scaling
实现简单性优先ALiBi
需要相对位置精确建模RoPE

PyTorch实现

基础ALiBi实现

import torch
import torch.nn as nn
import math
 
class ALiBiAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1, max_seq_len=1024):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
        # ALiBi偏置矩阵
        self.register_buffer('alibi_bias', self._generate_alibi_bias(num_heads, max_seq_len))
    
    def _generate_alibi_bias(self, num_heads, max_seq_len):
        """生成ALiBi偏置矩阵"""
        # 创建相对距离矩阵: (max_seq_len, max_seq_len)
        position = torch.arange(max_seq_len).unsqueeze(0)  # (1, seq_len)
        relative_pos = (position - position.T)  # (seq_len, seq_len), 值范围 [-seq_len+1, seq_len-1]
        
        # 绝对值: |i - j|
        distance = torch.abs(relative_pos)
        
        # 计算衰减系数: m_h = 2^{-8/h}
        # 生成每个头的斜率
        slopes = torch.pow(2, -8.0 / torch.arange(1, num_heads + 1).float())
        
        # 对于超过8个头的情况,使用slope=1
        slopes = torch.where(
            torch.arange(1, num_heads + 1).float() > 8,
            torch.ones(num_heads),
            slopes
        )
        
        # 计算偏置: -|i-j| * m_h
        # slopes: (num_heads,)
        # distance: (max_seq_len, max_seq_len)
        # result: (num_heads, max_seq_len, max_seq_len)
        alibi_bias = -distance.unsqueeze(0) * slopes.view(-1, 1, 1)
        
        return alibi_bias
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        seq_len = query.size(1)
        
        # QKV投影
        Q = self.W_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # 注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # 应用ALiBi偏置
        alibi_bias = self.alibi_bias[:self.num_heads, :seq_len, :seq_len]
        scores = scores + alibi_bias.unsqueeze(0)
        
        # 应用mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # 加权求和
        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        return self.W_o(output)

与FlashAttention集成

import torch
import torch.nn.functional as F
try:
    from flash_attn import flash_attn_func
    FLASH_AVAILABLE = True
except ImportError:
    FLASH_AVAILABLE = False
 
class FlashALiBiAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        # ALiBi斜率
        self.register_buffer('slopes', self._get_slopes(num_heads))
    
    def _get_slopes(self, num_heads):
        """获取每个头的衰减斜率"""
        slopes = torch.pow(2, -8.0 / torch.arange(1, num_heads + 1).float())
        slopes = torch.where(
            torch.arange(1, num_heads + 1).float() > 8,
            torch.ones(num_heads),
            slopes
        )
        return slopes
    
    def forward(self, x, mask=None):
        B, N, C = x.shape
        
        # QKV投影
        qkv = self.W_qkv(x).reshape(B, N, 3, self.num_heads, self.d_k)
        Q, K, V = qkv[..., 0, :, :], qkv[..., 1, :, :], qkv[..., 2, :, :]
        
        if FLASH_AVAILABLE:
            # FlashAttention: 集成ALiBi需要自定义kernel
            # 这里使用简化版本
            output = flash_attn_func(Q, K, V, causal=True)
        else:
            # 标准实现
            scores = torch.einsum('bhnd,bhmd->bhnm', Q, K) / math.sqrt(self.d_k)
            
            # ALiBi偏置
            positions = torch.arange(N, device=x.device).unsqueeze(0)
            distance = torch.abs(positions - positions.T)
            alibi_bias = -distance.unsqueeze(0) * self.slopes.view(-1, 1, 1)
            
            scores = scores + alibi_bias
            
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float('-inf'))
            
            attn_weights = F.softmax(scores, dim=-1)
            output = torch.einsum('bhnm,bhmd->bhnd', attn_weights, V)
        
        output = output.reshape(B, N, C)
        return self.W_o(output)

Transformer层集成

class ALiBiTransformerLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = ALiBiAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Pre-norm架构
        attn_output = self.attention(self.norm1(x), mask=mask)
        x = x + self.dropout(attn_output)
        
        ffn_output = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_output)
        
        return x

外推能力分析

长度外推实验设置

训练长度1K tokens
测试长度1K, 2K, 4K, 8K

外推机制

ALiBi的外推能力来源于:

1. 线性偏置的连续性

位置偏置 连续的线性函数:

训练位置范围: [0, 1024)
外推位置范围: [0, 8192)

对于位置对 (i=5000, j=0):
- 训练时从未见过此距离
- 但偏置 = -(5000-0) * m
- 仍然是有效的负数偏置

2. 衰减机制

较大的距离获得较大的负偏置,注意力自然分散:

这意味着远距离token的注意力趋于零,避免噪声。

3. 软饱和

Softmax的饱和特性使外推更加平滑:

外推效果对比

方法1K→2K1K→4K1K→8K
Learned PE严重退化无法使用无法使用
Sinusoidal轻度退化中度退化严重退化
RoPE良好需NTK-Scaling需NTK-Scaling
ALiBi优秀良好可接受

改进方向

动态缩放

class DynamicALiBi(nn.Module):
    def __init__(self, num_heads, max_len=8192):
        super().__init__()
        self.num_heads = num_heads
        self.register_buffer('slopes', self._get_slopes(num_heads))
        
        # 可学习的动态缩放
        self.scale_factor = nn.Parameter(torch.ones(1))
    
    def forward(self, seq_len):
        positions = torch.arange(seq_len)
        distance = torch.abs(positions.unsqueeze(1) - positions.unsqueeze(0))
        bias = -distance * self.slopes.view(-1, 1, 1) * self.scale_factor
        return bias

渐进式衰减

def progressive_alibi_bias(seq_len, num_heads):
    """
    渐进式衰减:距离越远,衰减越慢
    """
    positions = torch.arange(seq_len)
    distance = positions.unsqueeze(1) - positions.unsqueeze(0)
    distance = torch.clamp(-distance, 0, seq_len)  # 只考虑未来位置
    
    slopes = torch.pow(2, -8 / torch.arange(1, num_heads + 1).float())
    
    # 渐进衰减因子
    decay = 1 / (1 + distance.float() / 1024)  # 距离越大,衰减越小
    bias = -distance.float() * slopes.view(-1, 1, 1) * decay
    
    return bias

变体与发展

1. T5 Relative Position Bias

T5的相对位置编码是ALiBi的先驱:

其中 是可学习的嵌入表。

2. DeBERTa的相对位置

DeBERTa使用解耦的相对位置编码:

3. CoPE (Conditional Position Encoding)

CoPE根据内容动态决定位置编码:

4. XPos (Exponential Decay Position Encoding)

XPos使用指数衰减替代线性衰减:


实战调参指南

关键超参数

参数推荐值说明
衰减斜率 可微调
头数8-16更多头更精细
最大长度2×训练长度留外推空间

训练策略

1. 位置Dropout

class ALiBiWithDropout(nn.Module):
    def __init__(self, base_attention, dropout_prob=0.1):
        super().__init__()
        self.attention = base_attention
        self.dropout_prob = dropout_prob
    
    def forward(self, x, mask=None):
        # 随机丢弃部分位置信息
        if self.training and self.dropout_prob > 0:
            B, N, C = x.shape
            drop_mask = torch.rand(B, N, 1, device=x.device) > self.dropout_prob
            x = x * drop_mask.float()
        
        return self.attention(x, mask)

2. 混合训练长度

def collate_with_padding(batch, max_len=None):
    """动态最大长度,有助于外推"""
    if max_len is None:
        max_len = max([x['input_ids'].shape[0] for x in batch])
    max_len = ((max_len - 1) // 32 + 1) * 32  # 对齐到32
    
    # padding
    for x in batch:
        pad_len = max_len - x['input_ids'].shape[0]
        x['input_ids'] = F.pad(x['input_ids'], (0, pad_len))
    
    return default_collate(batch)

与其他位置编码的融合

ALiBi + RoPE

虽然两者都是位置编码,但可以组合使用:

class HybridPositionEncoding(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.rope = RotaryPositionEmbedding(d_model)
        self.alibi_slopes = self._get_slopes(num_heads)
    
    def forward(self, Q, K, V, seq_len):
        # 先应用RoPE
        Q = self.rope.rotate_queries(Q)
        K = self.rope.rotate_keys(K)
        
        # 再应用ALiBi
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
        alibi_bias = self._get_alibi_bias(seq_len)
        scores = scores + alibi_bias
        
        return torch.softmax(scores, dim=-1) @ V

参考


相关词条:Transformer数学基础稀疏注意力与长度外推Vision Transformer详解

Footnotes

  1. Press et al., “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation”, ICLR 2022