概述
ALiBi(Attention with Linear Biases)是一种创新的位置编码方法,通过在注意力分数上添加线性偏置来编码位置信息,无需显式的位置嵌入向量。1 ALiBi最初由Press等人于2021年提出,在训练效率和外推能力上展现出显著优势。
核心思想
与传统位置编码不同,ALiBi的核心创新在于:
- 无额外参数:不引入可学习的位置嵌入
- 线性偏置:直接在注意力分数上添加位置相关的偏置
- 天然外推:支持处理比训练更长的序列
设计动机
传统位置编码的问题
1. 可学习位置嵌入
问题:
- 需要预定义最大长度
- 无法外推到更长序列
- 增加参数量
2. Sinusoidal位置编码
问题:
- 与词嵌入难以融合
- 外推效果不理想
- 计算开销
3. RoPE(旋转位置编码)
优点:
- 相对位置编码
- 支持外推
限制:
- 需要修改Q/K向量
- 实现复杂度较高
ALiBi的设计目标
- 简单性:无需修改注意力计算的核心逻辑
- 高效性:零额外参数
- 外推性:自然支持长度外推
- 兼容性:可与FlashAttention等优化技术无缝结合
数学推导
基础公式
对于序列中的位置 和 ,ALiBi添加的偏置为:
带衰减的ALiBi
实际使用中,使用基于距离的衰减:
其中 是可学习的每头衰减系数。
多头注意力集成
对于多头注意力,每个头可以有独立的衰减系数:
其中 , 是第 个头的衰减斜率。
衰减系数设置
论文建议的衰减系数配置:
这种设置使不同头的衰减速率不同,有助于捕获多尺度的位置关系。
与RoPE的对比
编码方式对比
| 特性 | RoPE | ALiBi |
|---|---|---|
| 编码位置 | 旋转Q/K向量 | 注意力分数加偏置 |
| 参数需求 | 可选可学习 | 零参数 |
| 实现复杂度 | 中等 | 极简 |
| 与词嵌入融合 | 逐元素乘法 | 加法 |
| 理论优雅性 | 高 | 中 |
数学性质对比
RoPE
ALiBi
外推能力对比
| 场景 | RoPE | ALiBi |
|---|---|---|
| 训练: 1K → 测试: 2K | 良好 | 优秀 |
| 训练: 1K → 测试: 8K | 需要NTK-Scaling | 良好 |
| 训练: 4K → 测试: 32K | 可行 | 可行 |
适用场景选择
| 场景 | 推荐 |
|---|---|
| 短序列,精度要求高 | ALiBi |
| 长序列,需要精细控制 | RoPE + NTK-Scaling |
| 实现简单性优先 | ALiBi |
| 需要相对位置精确建模 | RoPE |
PyTorch实现
基础ALiBi实现
import torch
import torch.nn as nn
import math
class ALiBiAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1, max_seq_len=1024):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
# ALiBi偏置矩阵
self.register_buffer('alibi_bias', self._generate_alibi_bias(num_heads, max_seq_len))
def _generate_alibi_bias(self, num_heads, max_seq_len):
"""生成ALiBi偏置矩阵"""
# 创建相对距离矩阵: (max_seq_len, max_seq_len)
position = torch.arange(max_seq_len).unsqueeze(0) # (1, seq_len)
relative_pos = (position - position.T) # (seq_len, seq_len), 值范围 [-seq_len+1, seq_len-1]
# 绝对值: |i - j|
distance = torch.abs(relative_pos)
# 计算衰减系数: m_h = 2^{-8/h}
# 生成每个头的斜率
slopes = torch.pow(2, -8.0 / torch.arange(1, num_heads + 1).float())
# 对于超过8个头的情况,使用slope=1
slopes = torch.where(
torch.arange(1, num_heads + 1).float() > 8,
torch.ones(num_heads),
slopes
)
# 计算偏置: -|i-j| * m_h
# slopes: (num_heads,)
# distance: (max_seq_len, max_seq_len)
# result: (num_heads, max_seq_len, max_seq_len)
alibi_bias = -distance.unsqueeze(0) * slopes.view(-1, 1, 1)
return alibi_bias
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
seq_len = query.size(1)
# QKV投影
Q = self.W_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# 注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# 应用ALiBi偏置
alibi_bias = self.alibi_bias[:self.num_heads, :seq_len, :seq_len]
scores = scores + alibi_bias.unsqueeze(0)
# 应用mask
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax
attn_weights = torch.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 加权求和
output = torch.matmul(attn_weights, V)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.W_o(output)与FlashAttention集成
import torch
import torch.nn.functional as F
try:
from flash_attn import flash_attn_func
FLASH_AVAILABLE = True
except ImportError:
FLASH_AVAILABLE = False
class FlashALiBiAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_qkv = nn.Linear(d_model, 3 * d_model)
self.W_o = nn.Linear(d_model, d_model)
# ALiBi斜率
self.register_buffer('slopes', self._get_slopes(num_heads))
def _get_slopes(self, num_heads):
"""获取每个头的衰减斜率"""
slopes = torch.pow(2, -8.0 / torch.arange(1, num_heads + 1).float())
slopes = torch.where(
torch.arange(1, num_heads + 1).float() > 8,
torch.ones(num_heads),
slopes
)
return slopes
def forward(self, x, mask=None):
B, N, C = x.shape
# QKV投影
qkv = self.W_qkv(x).reshape(B, N, 3, self.num_heads, self.d_k)
Q, K, V = qkv[..., 0, :, :], qkv[..., 1, :, :], qkv[..., 2, :, :]
if FLASH_AVAILABLE:
# FlashAttention: 集成ALiBi需要自定义kernel
# 这里使用简化版本
output = flash_attn_func(Q, K, V, causal=True)
else:
# 标准实现
scores = torch.einsum('bhnd,bhmd->bhnm', Q, K) / math.sqrt(self.d_k)
# ALiBi偏置
positions = torch.arange(N, device=x.device).unsqueeze(0)
distance = torch.abs(positions - positions.T)
alibi_bias = -distance.unsqueeze(0) * self.slopes.view(-1, 1, 1)
scores = scores + alibi_bias
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhnm,bhmd->bhnd', attn_weights, V)
output = output.reshape(B, N, C)
return self.W_o(output)Transformer层集成
class ALiBiTransformerLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.attention = ALiBiAttention(d_model, num_heads, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model)
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Pre-norm架构
attn_output = self.attention(self.norm1(x), mask=mask)
x = x + self.dropout(attn_output)
ffn_output = self.ffn(self.norm2(x))
x = x + self.dropout(ffn_output)
return x外推能力分析
长度外推实验设置
| 训练长度 | 1K tokens |
|---|---|
| 测试长度 | 1K, 2K, 4K, 8K |
外推机制
ALiBi的外推能力来源于:
1. 线性偏置的连续性
位置偏置 是连续的线性函数:
训练位置范围: [0, 1024)
外推位置范围: [0, 8192)
对于位置对 (i=5000, j=0):
- 训练时从未见过此距离
- 但偏置 = -(5000-0) * m
- 仍然是有效的负数偏置
2. 衰减机制
较大的距离获得较大的负偏置,注意力自然分散:
这意味着远距离token的注意力趋于零,避免噪声。
3. 软饱和
Softmax的饱和特性使外推更加平滑:
外推效果对比
| 方法 | 1K→2K | 1K→4K | 1K→8K |
|---|---|---|---|
| Learned PE | 严重退化 | 无法使用 | 无法使用 |
| Sinusoidal | 轻度退化 | 中度退化 | 严重退化 |
| RoPE | 良好 | 需NTK-Scaling | 需NTK-Scaling |
| ALiBi | 优秀 | 良好 | 可接受 |
改进方向
动态缩放
class DynamicALiBi(nn.Module):
def __init__(self, num_heads, max_len=8192):
super().__init__()
self.num_heads = num_heads
self.register_buffer('slopes', self._get_slopes(num_heads))
# 可学习的动态缩放
self.scale_factor = nn.Parameter(torch.ones(1))
def forward(self, seq_len):
positions = torch.arange(seq_len)
distance = torch.abs(positions.unsqueeze(1) - positions.unsqueeze(0))
bias = -distance * self.slopes.view(-1, 1, 1) * self.scale_factor
return bias渐进式衰减
def progressive_alibi_bias(seq_len, num_heads):
"""
渐进式衰减:距离越远,衰减越慢
"""
positions = torch.arange(seq_len)
distance = positions.unsqueeze(1) - positions.unsqueeze(0)
distance = torch.clamp(-distance, 0, seq_len) # 只考虑未来位置
slopes = torch.pow(2, -8 / torch.arange(1, num_heads + 1).float())
# 渐进衰减因子
decay = 1 / (1 + distance.float() / 1024) # 距离越大,衰减越小
bias = -distance.float() * slopes.view(-1, 1, 1) * decay
return bias变体与发展
1. T5 Relative Position Bias
T5的相对位置编码是ALiBi的先驱:
其中 是可学习的嵌入表。
2. DeBERTa的相对位置
DeBERTa使用解耦的相对位置编码:
3. CoPE (Conditional Position Encoding)
CoPE根据内容动态决定位置编码:
4. XPos (Exponential Decay Position Encoding)
XPos使用指数衰减替代线性衰减:
实战调参指南
关键超参数
| 参数 | 推荐值 | 说明 |
|---|---|---|
| 衰减斜率 | 可微调 | |
| 头数 | 8-16 | 更多头更精细 |
| 最大长度 | 2×训练长度 | 留外推空间 |
训练策略
1. 位置Dropout
class ALiBiWithDropout(nn.Module):
def __init__(self, base_attention, dropout_prob=0.1):
super().__init__()
self.attention = base_attention
self.dropout_prob = dropout_prob
def forward(self, x, mask=None):
# 随机丢弃部分位置信息
if self.training and self.dropout_prob > 0:
B, N, C = x.shape
drop_mask = torch.rand(B, N, 1, device=x.device) > self.dropout_prob
x = x * drop_mask.float()
return self.attention(x, mask)2. 混合训练长度
def collate_with_padding(batch, max_len=None):
"""动态最大长度,有助于外推"""
if max_len is None:
max_len = max([x['input_ids'].shape[0] for x in batch])
max_len = ((max_len - 1) // 32 + 1) * 32 # 对齐到32
# padding
for x in batch:
pad_len = max_len - x['input_ids'].shape[0]
x['input_ids'] = F.pad(x['input_ids'], (0, pad_len))
return default_collate(batch)与其他位置编码的融合
ALiBi + RoPE
虽然两者都是位置编码,但可以组合使用:
class HybridPositionEncoding(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.rope = RotaryPositionEmbedding(d_model)
self.alibi_slopes = self._get_slopes(num_heads)
def forward(self, Q, K, V, seq_len):
# 先应用RoPE
Q = self.rope.rotate_queries(Q)
K = self.rope.rotate_keys(K)
# 再应用ALiBi
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
alibi_bias = self._get_alibi_bias(seq_len)
scores = scores + alibi_bias
return torch.softmax(scores, dim=-1) @ V参考
相关词条:Transformer数学基础,稀疏注意力与长度外推,Vision Transformer详解
Footnotes
-
Press et al., “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation”, ICLR 2022 ↩