ComplexFormer复数向量注意力
概述
ComplexFormer是NAACL 2025提出的Transformer改进方法,通过复数平面中的头特定旋转向量统一建模语义差异和位置差异。1
传统位置编码方法(如RoPE)通常对所有注意力头应用统一的相对位置调整,限制了表达能力。ComplexFormer允许每个头学习独特的策略来整合语义角度差和相对位置编码。
背景与动机
现有位置编码的局限
| 方法 | 语义-位置整合方式 | 头特异性 |
|---|---|---|
| 绝对位置编码 | 分别编码后相加 | 无 |
| 相对位置编码 | 统一偏置 | 无 |
| RoPE | 旋转角度统一 | 有限 |
| ComplexFormer | 头特定复数空间 | 完全 |
核心洞察
ComplexFormer的核心洞察是:语义信息和位置信息可以在复数平面中统一表示
- 实部:语义相似性
- 虚部:相对位置关系
- 模长:注意力强度
- 幅角:语义-位置平衡
核心方法
复数多头注意力(CMHA)
ComplexFormer引入两个关键创新:
1. 头特定Euler变换
将实值Query/Key投影转换为极坐标形式的复数向量:
其中 是头特定的投影矩阵。
class HeadwiseEulerTransform(nn.Module):
"""
头特定Euler变换:将实值向量转换为极坐标复数向量
"""
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_head = d_model // n_heads
# 每个头的独立投影
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
# 幅度和角度的独立学习
self.angle_net_q = nn.Sequential(
nn.Linear(d_model, n_heads * self.d_head),
nn.Tanh() # 限制在[-1, 1]
)
self.angle_net_k = nn.Sequential(
nn.Linear(d_model, n_heads * self.d_head),
nn.Tanh()
)
def to_polar(self, x, angle_net):
# x: [B, N, D]
B, N, D = x.shape
# 线性投影
x_proj = self.q_proj(x) # [B, N, D]
# 计算角度
angles = angle_net(x) # [B, N, n_heads * d_head]
angles = angles.view(B, N, self.n_heads, self.d_head)
# 重塑为复数形式:实部=cos, 虚部=sin
magnitude = x_proj.view(B, N, self.n_heads, self.d_head)
real = magnitude * angles.cos()
imag = magnitude * angles.sin()
# 返回复数张量(实部在前,虚部在后拼接)
return torch.stack([real, imag], dim=-1) # [B, N, H, D, 2]2. 自适应差分旋转机制
关键创新:允许每个头学习区分策略整合语义角度差和相对位置编码:
其中:
- :语义角度差
- :相对位置编码
- :头特定的可学习权重
class AdaptiveDifferentialRotation(nn.Module):
"""
自适应差分旋转:头特定的语义-位置平衡
"""
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_head = d_model // n_heads
# 语义角度权重(头特定)
self.alpha = nn.Parameter(torch.ones(n_heads, 1, 1))
# 位置编码权重(头特定)
self.beta = nn.Parameter(torch.ones(n_heads, 1, 1))
# 相对位置编码
self.max_pos = 1024
self.rel_pos_emb = nn.Embedding(2 * self.max_pos + 1, d_model)
def forward(self, q_complex, k_complex, positions):
B, N, H, D, _ = q_complex.shape
# 计算语义角度差
# q_complex: [B, N, H, D, 2] (real, imag)
q_real, q_imag = q_complex[..., 0], q_complex[..., 1]
k_real, k_imag = k_complex[..., 0], k_complex[..., 1]
# 复数点积:q * conj(k)
sem_angle_diff = q_real * k_real + q_imag * k_imag # 语义相似性
# 计算相对位置
positions = positions.unsqueeze(1) # [B, 1, N]
rel_pos = positions - positions.transpose(1, 2) # [B, N, N]
rel_pos = rel_pos.clamp(-self.max_pos, self.max_pos) + self.max_pos
rel_emb = self.rel_pos_emb(rel_pos) # [B, N, N, D]
# 展平头维度
sem_angle_diff = sem_angle_diff.permute(0, 3, 1, 2) # [B, D, N, N]
rel_emb = rel_emb.permute(0, 3, 1, 2) # [B, D, N, N]
# 自适应组合
combined = self.alpha.sigmoid() * sem_angle_diff + \
self.beta.sigmoid() * rel_emb.sum(dim=1, keepdim=True)
return combined # [B, 1, N, N]完整ComplexFormer层
class ComplexFormerLayer(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.cmha = ComplexMultiHeadAttention(d_model, n_heads)
self.ffn = FeedForward(d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, positions=None):
B, N, D = x.shape
if positions is None:
positions = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1)
# Complex MHA
attn_out = self.cmha(x, positions)
x = self.norm1(x + attn_out)
# FFN
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
return x理论分析
表达力提升
ComplexFormer相比标准RoPE的表达力提升:
| 特性 | RoPE | ComplexFormer |
|---|---|---|
| 位置编码 | 统一旋转 | 头特定 |
| 语义-位置交互 | 固定比例 | 自适应 |
| 独立头数 | 受限 | 完全独立 |
| 相对距离感知 | 全局 | 全局+局部 |
复杂度分析
| 操作 | 标准MHA | ComplexFormer |
|---|---|---|
| 参数量 | ||
| 注意力计算 | ||
| 额外开销 | - | 约5-10% |
实验结果
语言建模
在WikiText-103上的困惑度:
| 模型 | 参数量 | 困惑度 | 相对提升 |
|---|---|---|---|
| RoPE-Transformer | 247M | 16.4 | - |
| ComplexFormer | 247M | 15.2 | 7.3% |
生成质量
在多种任务上的生成困惑度对比:
| 任务 | RoPE | ComplexFormer | 提升 |
|---|---|---|---|
| WikiText-103 | 16.4 | 15.2 | 7.3% |
| 代码生成 | 23.1 | 21.8 | 5.6% |
| 数学推理 | 34.2 | 31.9 | 6.7% |
长上下文性能
| 序列长度 | RoPE | ComplexFormer |
|---|---|---|
| 4K | 18.2 | 17.6 |
| 16K | 19.8 | 18.3 |
| 64K | 24.3 | 20.1 |
ComplexFormer在长上下文上优势更显著。
与其他位置编码的对比
| 方法 | 头特异性 | 语义-位置交互 | 复杂度 |
|---|---|---|---|
| Sinusoidal | ✗ | 加法 | O(n) |
| RoPE | 有限 | 乘法 | O(1) |
| ALiBi | ✗ | 线性偏置 | O(1) |
| ComplexFormer | 完全 | 自适应 | O(1) |
应用建议
推荐场景
- 长上下文任务:ComplexFormer在64K+序列上优势明显
- 代码生成:变量作用域和位置关系需要精细建模
- 数学推理:精确的位置推理有助于计算
配置建议
# 小模型
config_small = {
"n_heads": 8,
"d_head": 64,
"angle_net_hidden": 32,
}
# 大模型
config_large = {
"n_heads": 32,
"d_head": 128,
"angle_net_hidden": 128,
}参考资料
相关链接
Footnotes
-
“ComplexFormer: Disruptively Advancing Transformer Inference Ability via Head-Specific Complex Vector Attention” NAACL 2025. arXiv:2505.10222 ↩