TransXSSM:统一旋转位置编码的混合Transformer-SSM架构

概述

TransXSSM是一种新型混合序列建模框架,通过**统一旋转位置编码(Unified Rotary Position Embedding, URope)**将Transformer的自注意力机制与状态空间模型(SSM)的线性复杂度建模能力有机融合。1

核心论文:arXiv:2506.095071
研究机构:香港科技大学(广州)

关键贡献

  • 提出URope统一位置编码,解决Transformer与SSM位置表示不一致问题
  • 在保持线性复杂度的同时捕捉长程依赖
  • 在LongBench基准上超越Pure Transformer和Pure SSM基线

1. 背景与动机

1.1 Transformer与SSM的互补性

特性TransformerSSM(Mamba)
计算复杂度
长程依赖全局注意力,但计算重选择性遗忘,可能丢失信息
位置感知依赖位置编码隐式位置编码
并行训练高效高效
推理效率低效(KV Cache大)高效(状态压缩)

1.2 现有混合方法的挑战

核心问题:Transformer和SSM对位置的处理方式不同:

  • Transformer:显式位置编码(绝对/相对/Rope)
  • SSM:隐式位置编码,难以直接融合

现有方法的问题

  1. 并行混合(如Mamba-Transformer):需要两套位置编码,增加参数量
  2. 串行混合(如Mamba-Hybrid):位置表示不统一,融合效果受限
  3. 交替混合:缺乏深层语义融合

1.3 TransXSSM的洞察

“Transformer和SSM的本质差异在于位置信息的表示方式,而非计算范式本身。”

核心洞察:通过统一旋转位置编码,可以在同一表示空间中同时支持注意力和状态空间操作。


2. 核心方法:URope

2.1 旋转位置编码回顾

标准RoPE(Rotary Position Embedding)将位置信息编码为旋转矩阵:

对于第个位置的查询向量

其中旋转矩阵:

优势

  • 相对位置信息通过内积自动编码
  • 无需额外的偏置项
  • 可扩展到高维

2.2 统一旋转位置编码(URope)

问题形式化

传统方法对Transformer和SSM使用不同的位置编码:

URope解决方案

将位置信息统一编码到旋转矩阵 中,SSM通过修改状态转移矩阵来编码位置:

2.3 数学形式化

定理(URope正确性)

为位置 的旋转矩阵, 为与位置 相关的查询向量。则:

对于SSM,状态更新满足:

其中

2.4 URope的物理意义

直觉解释

操作物理意义
将查询向量旋转到位置 的参考系
位置相关的状态转移(旋转坐标系中的动态)
位置相关的输入投影

核心优势

  • 位置信息在Transformer和SSM中一致表示
  • 无需额外的位置偏置
  • 可以无缝切换注意力模式和SSM模式

3. TransXSSM架构

3.1 整体结构

TransXSSM Block
│
├── 输入 X
├── LayerNorm
│
├── ┌─────────────────────────────────────────────┐
│  │ Transformer分支                               │
│  │  ├── QKV投影 + URope                         │
│  │  ├── Flash Attention                         │
│  │  └── 输出投影                                 │
│  └─────────────────────────────────────────────┘
│
├── Gate (可学习)
│
├── ┌─────────────────────────────────────────────┐
│  │ SSM分支                                      │
│  │  ├── 输入投影 + URope                        │
│  │  ├── 选择性SSM (Mamba-style)                 │
│  │  └── 输出投影                                │
│  └─────────────────────────────────────────────┘
│
├── 门控融合
│
└── 输出

3.2 融合机制

门控融合

其中 是可学习的门控权重。

3.3 位置编码统一

class URopeAttention(nn.Module):
    """URope注意力实现"""
    def __init__(self, d_model, n_heads, max_seq_len=4096):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # QKV投影
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # URope:统一的旋转位置编码
        self.rope = URope(d_k=self.d_k, max_seq_len=max_seq_len)
        
    def forward(self, x):
        B, N, D = x.shape
        
        # QKV投影
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # 应用URope(统一位置编码)
        Q = self.rope(Q)
        K = self.rope(K)
        
        # 注意力计算
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        
        return torch.matmul(attn, V)
 
 
class URopeSSM(nn.Module):
    """URope SSM实现"""
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # 输入投影
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1, bias=False)
        
        # 状态矩阵
        self.A_log = nn.Parameter(torch.randn(d_model, d_state))
        self.D = nn.Parameter(torch.ones(d_model))
        
        # URope旋转
        self.rope = URope(d_k=d_model, max_seq_len=4096)
        
    def forward(self, x):
        B, N, D = x.shape
        dtype, device = x.dtype, x.device
        
        # 输入投影获取SSM参数
        x_dbl = self.x_proj(x)
        dt, B_proj, C_proj = x_dbl.split([1, self.d_state, self.d_state], dim=-1)
        
        # 应用URope旋转
        dt = self.rope.apply_rotary(dt)
        B_proj = self.rope.apply_rotary(B_proj)
        C_proj = self.rope.apply_rotary(C_proj)
        
        # 选择性扫描...
        # (省略具体实现细节)

3.4 计算复杂度分析

组件时间复杂度空间复杂度
Transformer分支
SSM分支
TransXSSM

其中 为头数, 为模型维度, 为状态维度。


4. 实验结果

4.1 主要结果

LongBench基准测试

模型平均NarrativeQAQasperMF-EnTriviaQA
Mamba28.338.224.132.548.1
Transformer29.140.125.330.252.4
TransXSSM31.242.327.834.153.2

4.2 消融实验

变体性能说明
基线(无URope)28.5独立位置编码
URope-Absolute29.8绝对位置
URope-Relative30.5相对位置
URope-Full31.2完整URope

5. PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class URope:
    """统一旋转位置编码"""
    def __init__(self, d_k, max_seq_len=4096):
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        
        # 预计算旋转角度
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_k, 2).float() / d_k))
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()
    
    def rotate_half(self, x):
        """将输入分成两半并旋转"""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)
    
    def apply_rotary(self, x, seq_len=None):
        """应用旋转"""
        if seq_len is None:
            seq_len = x.shape[1]
        cos = self.cos_cached[:seq_len].to(x.device)
        sin = self.sin_cached[:seq_len].to(x.device)
        return (x * cos.unsqueeze(-1)) + (self.rotate_half(x) * sin.unsqueeze(-1))
 
 
class TransXSSMBlock(nn.Module):
    """TransXSSM block实现"""
    def __init__(self, d_model, d_state=16, n_heads=8, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        
        # URope
        self.rope = URope(d_model, max_seq_len=8192)
        
        # Transformer分支
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.attn_norm = nn.LayerNorm(d_model)
        
        # SSM分支
        self.ssm = SelectiveSSM(d_model, d_state)
        self.ssm_norm = nn.LayerNorm(d_model)
        
        # 融合门控
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Sigmoid()
        )
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model)
        )
        
    def forward(self, x):
        residual = x
        
        # Transformer分支
        q = self.rope.apply_rotary(self.attn.W_q(x))
        k = self.rope.apply_rotary(self.attn.W_k(x))
        v = self.attn.W_v(x)
        attn_out, _ = self.attn(q, k, v)
        attn_out = self.attn_norm(attn_out + residual)
        
        # SSM分支
        ssm_out = self.ssm(x)
        ssm_out = self.ssm_norm(ssm_out + residual)
        
        # 门控融合
        g = self.gate(x)
        fused = g * attn_out + (1 - g) * ssm_out
        
        # FFN
        return self.ffn(fused) + fused
 
 
class SelectiveSSM(nn.Module):
    """选择性SSM(Mamba风格)"""
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_state = d_state
        self.d_inner = d_model + d_state * 2
        
        # 输入投影
        self.x_proj = nn.Linear(d_model, self.d_inner, bias=False)
        
        # 状态矩阵
        self.A_log = nn.Parameter(torch.randn(d_model, d_state))
        self.D = nn.Parameter(torch.ones(d_model))
        
    def forward(self, x):
        B, L, D = x.shape
        
        # 输入投影
        x_dbl = self.x_proj(x)
        dt, B_proj, C_proj = x_dbl.split([D, self.d_state, self.d_state], dim=-1)
        
        # 软dt投影
        dt = F.softplus(dt)
        
        # 选择性扫描(简化实现)
        # 完整实现需要并行前缀扫描
        A = -torch.exp(self.A_log)
        
        # 离散化
        dA = torch.exp(dt.unsqueeze(-1) * A)
        dB = dt.unsqueeze(-1) * B_proj.unsqueeze(-1)
        
        # 扫描
        h = torch.zeros(B, D, self.d_state, device=x.device)
        outputs = []
        
        for i in range(L):
            h = dA[:, i] * h + dB[:, i] * x[:, i:i+1].unsqueeze(-1)
            y = torch.einsum('bdn,bn->bd', h, C_proj[:, i])
            outputs.append(y)
        
        return torch.stack(outputs, dim=1) + self.D * x

6. 总结

核心贡献

  1. URope:统一旋转位置编码,使Transformer和SSM在相同位置空间中操作
  2. TransXSSM Block:无缝融合注意力和状态空间建模
  3. 门控融合:可学习的权重平衡两种建模方式

关键洞察

位置表示的统一是混合架构成功的关键。通过将位置信息编码为旋转矩阵的变换,TransXSSM实现了Transformer和SSM的深层融合。

局限与未来方向

  1. 计算开销:注意力分支仍需 计算
  2. 门控机制:可探索更动态的门控策略
  3. 长上下文:在超长序列上的性能待验证

参考资料

Footnotes

  1. Wu et al. (2025). TransXSSM: A Hybrid Transformer–State Space Model with Unified Rotary Position Embedding. arXiv:2506.09507. 2