TransXSSM:统一旋转位置编码的混合Transformer-SSM架构
概述
TransXSSM是一种新型混合序列建模框架,通过**统一旋转位置编码(Unified Rotary Position Embedding, URope)**将Transformer的自注意力机制与状态空间模型(SSM)的线性复杂度建模能力有机融合。1
核心论文:arXiv:2506.095071
研究机构:香港科技大学(广州)
关键贡献:
- 提出URope统一位置编码,解决Transformer与SSM位置表示不一致问题
- 在保持线性复杂度的同时捕捉长程依赖
- 在LongBench基准上超越Pure Transformer和Pure SSM基线
1. 背景与动机
1.1 Transformer与SSM的互补性
| 特性 | Transformer | SSM(Mamba) |
|---|---|---|
| 计算复杂度 | ||
| 长程依赖 | 全局注意力,但计算重 | 选择性遗忘,可能丢失信息 |
| 位置感知 | 依赖位置编码 | 隐式位置编码 |
| 并行训练 | 高效 | 高效 |
| 推理效率 | 低效(KV Cache大) | 高效(状态压缩) |
1.2 现有混合方法的挑战
核心问题:Transformer和SSM对位置的处理方式不同:
- Transformer:显式位置编码(绝对/相对/Rope)
- SSM:隐式位置编码,难以直接融合
现有方法的问题:
- 并行混合(如Mamba-Transformer):需要两套位置编码,增加参数量
- 串行混合(如Mamba-Hybrid):位置表示不统一,融合效果受限
- 交替混合:缺乏深层语义融合
1.3 TransXSSM的洞察
“Transformer和SSM的本质差异在于位置信息的表示方式,而非计算范式本身。”
核心洞察:通过统一旋转位置编码,可以在同一表示空间中同时支持注意力和状态空间操作。
2. 核心方法:URope
2.1 旋转位置编码回顾
标准RoPE(Rotary Position Embedding)将位置信息编码为旋转矩阵:
对于第个位置的查询向量 :
其中旋转矩阵:
优势:
- 相对位置信息通过内积自动编码
- 无需额外的偏置项
- 可扩展到高维
2.2 统一旋转位置编码(URope)
问题形式化:
传统方法对Transformer和SSM使用不同的位置编码:
URope解决方案:
将位置信息统一编码到旋转矩阵 中,SSM通过修改状态转移矩阵来编码位置:
2.3 数学形式化
定理(URope正确性):
设 为位置 的旋转矩阵, 为与位置 相关的查询向量。则:
对于SSM,状态更新满足:
其中 。
2.4 URope的物理意义
直觉解释:
| 操作 | 物理意义 |
|---|---|
| 将查询向量旋转到位置 的参考系 | |
| 位置相关的状态转移(旋转坐标系中的动态) | |
| 位置相关的输入投影 |
核心优势:
- 位置信息在Transformer和SSM中一致表示
- 无需额外的位置偏置
- 可以无缝切换注意力模式和SSM模式
3. TransXSSM架构
3.1 整体结构
TransXSSM Block
│
├── 输入 X
├── LayerNorm
│
├── ┌─────────────────────────────────────────────┐
│ │ Transformer分支 │
│ │ ├── QKV投影 + URope │
│ │ ├── Flash Attention │
│ │ └── 输出投影 │
│ └─────────────────────────────────────────────┘
│
├── Gate (可学习)
│
├── ┌─────────────────────────────────────────────┐
│ │ SSM分支 │
│ │ ├── 输入投影 + URope │
│ │ ├── 选择性SSM (Mamba-style) │
│ │ └── 输出投影 │
│ └─────────────────────────────────────────────┘
│
├── 门控融合
│
└── 输出
3.2 融合机制
门控融合:
其中 是可学习的门控权重。
3.3 位置编码统一
class URopeAttention(nn.Module):
"""URope注意力实现"""
def __init__(self, d_model, n_heads, max_seq_len=4096):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# QKV投影
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# URope:统一的旋转位置编码
self.rope = URope(d_k=self.d_k, max_seq_len=max_seq_len)
def forward(self, x):
B, N, D = x.shape
# QKV投影
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# 应用URope(统一位置编码)
Q = self.rope(Q)
K = self.rope(K)
# 注意力计算
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = F.softmax(scores, dim=-1)
return torch.matmul(attn, V)
class URopeSSM(nn.Module):
"""URope SSM实现"""
def __init__(self, d_model, d_state=16):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# 输入投影
self.x_proj = nn.Linear(d_model, d_state * 2 + 1, bias=False)
# 状态矩阵
self.A_log = nn.Parameter(torch.randn(d_model, d_state))
self.D = nn.Parameter(torch.ones(d_model))
# URope旋转
self.rope = URope(d_k=d_model, max_seq_len=4096)
def forward(self, x):
B, N, D = x.shape
dtype, device = x.dtype, x.device
# 输入投影获取SSM参数
x_dbl = self.x_proj(x)
dt, B_proj, C_proj = x_dbl.split([1, self.d_state, self.d_state], dim=-1)
# 应用URope旋转
dt = self.rope.apply_rotary(dt)
B_proj = self.rope.apply_rotary(B_proj)
C_proj = self.rope.apply_rotary(C_proj)
# 选择性扫描...
# (省略具体实现细节)3.4 计算复杂度分析
| 组件 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| Transformer分支 | ||
| SSM分支 | ||
| TransXSSM |
其中 为头数, 为模型维度, 为状态维度。
4. 实验结果
4.1 主要结果
LongBench基准测试:
| 模型 | 平均 | NarrativeQA | Qasper | MF-En | TriviaQA |
|---|---|---|---|---|---|
| Mamba | 28.3 | 38.2 | 24.1 | 32.5 | 48.1 |
| Transformer | 29.1 | 40.1 | 25.3 | 30.2 | 52.4 |
| TransXSSM | 31.2 | 42.3 | 27.8 | 34.1 | 53.2 |
4.2 消融实验
| 变体 | 性能 | 说明 |
|---|---|---|
| 基线(无URope) | 28.5 | 独立位置编码 |
| URope-Absolute | 29.8 | 绝对位置 |
| URope-Relative | 30.5 | 相对位置 |
| URope-Full | 31.2 | 完整URope |
5. PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class URope:
"""统一旋转位置编码"""
def __init__(self, d_k, max_seq_len=4096):
self.d_k = d_k
self.max_seq_len = max_seq_len
# 预计算旋转角度
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_k, 2).float() / d_k))
t = torch.arange(max_seq_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
def rotate_half(self, x):
"""将输入分成两半并旋转"""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary(self, x, seq_len=None):
"""应用旋转"""
if seq_len is None:
seq_len = x.shape[1]
cos = self.cos_cached[:seq_len].to(x.device)
sin = self.sin_cached[:seq_len].to(x.device)
return (x * cos.unsqueeze(-1)) + (self.rotate_half(x) * sin.unsqueeze(-1))
class TransXSSMBlock(nn.Module):
"""TransXSSM block实现"""
def __init__(self, d_model, d_state=16, n_heads=8, dropout=0.1):
super().__init__()
self.d_model = d_model
# URope
self.rope = URope(d_model, max_seq_len=8192)
# Transformer分支
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.attn_norm = nn.LayerNorm(d_model)
# SSM分支
self.ssm = SelectiveSSM(d_model, d_state)
self.ssm_norm = nn.LayerNorm(d_model)
# 融合门控
self.gate = nn.Sequential(
nn.Linear(d_model, d_model),
nn.Sigmoid()
)
# FFN
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_model * 4, d_model)
)
def forward(self, x):
residual = x
# Transformer分支
q = self.rope.apply_rotary(self.attn.W_q(x))
k = self.rope.apply_rotary(self.attn.W_k(x))
v = self.attn.W_v(x)
attn_out, _ = self.attn(q, k, v)
attn_out = self.attn_norm(attn_out + residual)
# SSM分支
ssm_out = self.ssm(x)
ssm_out = self.ssm_norm(ssm_out + residual)
# 门控融合
g = self.gate(x)
fused = g * attn_out + (1 - g) * ssm_out
# FFN
return self.ffn(fused) + fused
class SelectiveSSM(nn.Module):
"""选择性SSM(Mamba风格)"""
def __init__(self, d_model, d_state=16):
super().__init__()
self.d_state = d_state
self.d_inner = d_model + d_state * 2
# 输入投影
self.x_proj = nn.Linear(d_model, self.d_inner, bias=False)
# 状态矩阵
self.A_log = nn.Parameter(torch.randn(d_model, d_state))
self.D = nn.Parameter(torch.ones(d_model))
def forward(self, x):
B, L, D = x.shape
# 输入投影
x_dbl = self.x_proj(x)
dt, B_proj, C_proj = x_dbl.split([D, self.d_state, self.d_state], dim=-1)
# 软dt投影
dt = F.softplus(dt)
# 选择性扫描(简化实现)
# 完整实现需要并行前缀扫描
A = -torch.exp(self.A_log)
# 离散化
dA = torch.exp(dt.unsqueeze(-1) * A)
dB = dt.unsqueeze(-1) * B_proj.unsqueeze(-1)
# 扫描
h = torch.zeros(B, D, self.d_state, device=x.device)
outputs = []
for i in range(L):
h = dA[:, i] * h + dB[:, i] * x[:, i:i+1].unsqueeze(-1)
y = torch.einsum('bdn,bn->bd', h, C_proj[:, i])
outputs.append(y)
return torch.stack(outputs, dim=1) + self.D * x6. 总结
核心贡献
- URope:统一旋转位置编码,使Transformer和SSM在相同位置空间中操作
- TransXSSM Block:无缝融合注意力和状态空间建模
- 门控融合:可学习的权重平衡两种建模方式
关键洞察
位置表示的统一是混合架构成功的关键。通过将位置信息编码为旋转矩阵的变换,TransXSSM实现了Transformer和SSM的深层融合。
局限与未来方向
- 计算开销:注意力分支仍需 计算
- 门控机制:可探索更动态的门控策略
- 长上下文:在超长序列上的性能待验证