ComplexFormer复数向量注意力

概述

ComplexFormer是NAACL 2025提出的Transformer改进方法,通过复数平面中的头特定旋转向量统一建模语义差异和位置差异。1

传统位置编码方法(如RoPE)通常对所有注意力头应用统一的相对位置调整,限制了表达能力。ComplexFormer允许每个头学习独特的策略来整合语义角度差和相对位置编码。

背景与动机

现有位置编码的局限

方法语义-位置整合方式头特异性
绝对位置编码分别编码后相加
相对位置编码统一偏置
RoPE旋转角度统一有限
ComplexFormer头特定复数空间完全

核心洞察

ComplexFormer的核心洞察是:语义信息和位置信息可以在复数平面中统一表示

  • 实部:语义相似性
  • 虚部:相对位置关系
  • 模长:注意力强度
  • 幅角:语义-位置平衡

核心方法

复数多头注意力(CMHA)

ComplexFormer引入两个关键创新:

1. 头特定Euler变换

将实值Query/Key投影转换为极坐标形式的复数向量:

其中 是头特定的投影矩阵。

class HeadwiseEulerTransform(nn.Module):
    """
    头特定Euler变换:将实值向量转换为极坐标复数向量
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        
        # 每个头的独立投影
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        
        # 幅度和角度的独立学习
        self.angle_net_q = nn.Sequential(
            nn.Linear(d_model, n_heads * self.d_head),
            nn.Tanh()  # 限制在[-1, 1]
        )
        self.angle_net_k = nn.Sequential(
            nn.Linear(d_model, n_heads * self.d_head),
            nn.Tanh()
        )
        
    def to_polar(self, x, angle_net):
        # x: [B, N, D]
        B, N, D = x.shape
        
        # 线性投影
        x_proj = self.q_proj(x)  # [B, N, D]
        
        # 计算角度
        angles = angle_net(x)  # [B, N, n_heads * d_head]
        angles = angles.view(B, N, self.n_heads, self.d_head)
        
        # 重塑为复数形式:实部=cos, 虚部=sin
        magnitude = x_proj.view(B, N, self.n_heads, self.d_head)
        real = magnitude * angles.cos()
        imag = magnitude * angles.sin()
        
        # 返回复数张量(实部在前,虚部在后拼接)
        return torch.stack([real, imag], dim=-1)  # [B, N, H, D, 2]

2. 自适应差分旋转机制

关键创新:允许每个头学习区分策略整合语义角度差相对位置编码

其中:

  • :语义角度差
  • :相对位置编码
  • :头特定的可学习权重
class AdaptiveDifferentialRotation(nn.Module):
    """
    自适应差分旋转:头特定的语义-位置平衡
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        
        # 语义角度权重(头特定)
        self.alpha = nn.Parameter(torch.ones(n_heads, 1, 1))
        
        # 位置编码权重(头特定)
        self.beta = nn.Parameter(torch.ones(n_heads, 1, 1))
        
        # 相对位置编码
        self.max_pos = 1024
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos + 1, d_model)
        
    def forward(self, q_complex, k_complex, positions):
        B, N, H, D, _ = q_complex.shape
        
        # 计算语义角度差
        # q_complex: [B, N, H, D, 2] (real, imag)
        q_real, q_imag = q_complex[..., 0], q_complex[..., 1]
        k_real, k_imag = k_complex[..., 0], k_complex[..., 1]
        
        # 复数点积:q * conj(k)
        sem_angle_diff = q_real * k_real + q_imag * k_imag  # 语义相似性
        
        # 计算相对位置
        positions = positions.unsqueeze(1)  # [B, 1, N]
        rel_pos = positions - positions.transpose(1, 2)  # [B, N, N]
        rel_pos = rel_pos.clamp(-self.max_pos, self.max_pos) + self.max_pos
        rel_emb = self.rel_pos_emb(rel_pos)  # [B, N, N, D]
        
        # 展平头维度
        sem_angle_diff = sem_angle_diff.permute(0, 3, 1, 2)  # [B, D, N, N]
        rel_emb = rel_emb.permute(0, 3, 1, 2)  # [B, D, N, N]
        
        # 自适应组合
        combined = self.alpha.sigmoid() * sem_angle_diff + \
                   self.beta.sigmoid() * rel_emb.sum(dim=1, keepdim=True)
        
        return combined  # [B, 1, N, N]

完整ComplexFormer层

class ComplexFormerLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.cmha = ComplexMultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
    def forward(self, x, positions=None):
        B, N, D = x.shape
        if positions is None:
            positions = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1)
        
        # Complex MHA
        attn_out = self.cmha(x, positions)
        x = self.norm1(x + attn_out)
        
        # FFN
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        
        return x

理论分析

表达力提升

ComplexFormer相比标准RoPE的表达力提升:

特性RoPEComplexFormer
位置编码统一旋转头特定
语义-位置交互固定比例自适应
独立头数受限完全独立
相对距离感知全局全局+局部

复杂度分析

操作标准MHAComplexFormer
参数量
注意力计算
额外开销-约5-10%

实验结果

语言建模

在WikiText-103上的困惑度:

模型参数量困惑度相对提升
RoPE-Transformer247M16.4-
ComplexFormer247M15.27.3%

生成质量

在多种任务上的生成困惑度对比:

任务RoPEComplexFormer提升
WikiText-10316.415.27.3%
代码生成23.121.85.6%
数学推理34.231.96.7%

长上下文性能

序列长度RoPEComplexFormer
4K18.217.6
16K19.818.3
64K24.320.1

ComplexFormer在长上下文上优势更显著。

与其他位置编码的对比

方法头特异性语义-位置交互复杂度
Sinusoidal加法O(n)
RoPE有限乘法O(1)
ALiBi线性偏置O(1)
ComplexFormer完全自适应O(1)

应用建议

推荐场景

  1. 长上下文任务:ComplexFormer在64K+序列上优势明显
  2. 代码生成:变量作用域和位置关系需要精细建模
  3. 数学推理:精确的位置推理有助于计算

配置建议

# 小模型
config_small = {
    "n_heads": 8,
    "d_head": 64,
    "angle_net_hidden": 32,
}
 
# 大模型
config_large = {
    "n_heads": 32,
    "d_head": 128,
    "angle_net_hidden": 128,
}

参考资料

相关链接

Footnotes

  1. “ComplexFormer: Disruptively Advancing Transformer Inference Ability via Head-Specific Complex Vector Attention” NAACL 2025. arXiv:2505.10222