概述
TransMamba是一种序列级(Sequence-Level)混合框架,通过共享参数矩阵和Memory Converter机制,在同一模型中统一了Transformer和Mamba两种范式。1
核心创新:
- 参数共享:QKV与CBx共享同一组参数
- Memory Converter:无损信息转换机制
- TransPoint调度:动态切换Attention/SSM模式
- 训练效率:相比Transformer提升25%
1. 设计动机
1.1 现有混合方法的局限
| 方法 | 代表模型 | 局限 |
|---|---|---|
| 并行混合 | Jamba | 需要独立的两套参数 |
| 串行混合 | Hybrid | 参数量增加 |
| 交替混合 | Mamba-Hybrid | 缺乏深层融合 |
1.2 TransMamba的洞察
TransMamba的核心洞察:Transformer和Mamba本质上是同一计算图的不同分解。
- Transformer:,其中 来自输入的独立投影
- Mamba:,其中状态由输入通过 投影构造
参数共享的可能性:
2. 架构设计
2.1 整体结构
TransMamba的层结构:
TransMamba Block
├── Input Norm
├── Attention Mode
│ ├── QKV Projection (共享)
│ ├── RoPE
│ ├── Flash Attention
│ └── Output Projection
├── Memory Converter (← 关键)
├── SSM Mode
│ ├── CBx Projection (共享)
│ ├── SSM Discretization
│ ├── SSD Scan
│ └── Output Projection
└── Output Norm
2.2 参数共享机制
共享策略:同一组参数在不同模式下复用
class TransMambaBlock(nn.Module):
def __init__(self, d_model, d_state=128):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# 共享的输入投影
# Attention模式: Q, K, V
# SSM模式: C, B, x
self.shared_proj = nn.Linear(d_model, d_model * 3)
# 独立参数
self.rope = RoPE(d_model)
self.ssm_gate = nn.Parameter(torch.ones(d_model))
def forward(self, x, mode='attention'):
# 共享投影
qkv = self.shared_proj(x) # [B, L, 3D]
q, k, v = qkv.chunk(3, dim=-1)
if mode == 'attention':
# Attention模式
q = self.rope(q)
k = self.rope(k)
out = self.attention(q, k, v)
else:
# SSM模式
c = q[:, :, :self.d_state] # Q → C
b = k[:, :, :self.d_state] # K → B
x_ssm = v # V → x
out = self.ssm_scan(x_ssm, c, b)
return out2.3 Memory Converter
问题:QKV投影和CBx投影的维度空间不同
- QKV空间: → 需要分离Q, K, V
- CBx空间: → 需要构造状态
解决方案:Memory Converter进行无损信息转换
class MemoryConverter(nn.Module):
def __init__(self, d_model, d_state):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# 维度映射
self.q_to_c = nn.Linear(d_model, d_state)
self.k_to_b = nn.Linear(d_model, d_state)
self.v_adapter = nn.Linear(d_model, d_model)
# 状态门控
self.state_gate = nn.Parameter(torch.ones(d_state))
def forward(self, q, k, v):
# Q → C (Attention → SSM)
c = self.q_to_c(q) * torch.sigmoid(self.state_gate)
# K → B
b = self.k_to_b(k) * torch.sigmoid(self.state_gate)
# V → x (适配维度)
x = self.v_adapter(v)
return c, b, x3. TransPoint调度策略
3.1 模式切换机制
TransMamba支持动态模式切换:
class TransPointScheduler:
def __init__(self, num_layers, cycle_length=8, offset_start=2):
"""
Args:
cycle_length: 切换周期 (论文建议8)
offset_start: SSM开始的偏移 (论文建议2)
"""
self.num_layers = num_layers
self.cycle_length = cycle_length
self.offset_start = offset_start
def get_mode(self, layer_idx):
"""返回当前层应该使用的模式"""
adjusted_idx = layer_idx + self.offset_start
if adjusted_idx % self.cycle_length == 0:
return 'ssm'
else:
return 'attention'
def create_schedule(self):
"""生成完整的模式调度表"""
schedule = []
for i in range(self.num_layers):
mode = self.get_mode(i)
schedule.append({
'layer': i,
'mode': mode,
'memory_converter': mode == 'attention' # 只在切换时使用
})
return schedule3.2 调度模式分析
| 周期长度 | Attention层占比 | SSM层占比 | 性能 |
|---|---|---|---|
| 4 | 75% | 25% | 接近Transformer |
| 8 | 87.5% | 12.5% | 最优 |
| 16 | 93.75% | 6.25% | 略下降 |
| 连续Attention | 100% | 0% | Transformer基线 |
3.3 细粒度调度
TransMamba还支持Token级别的细粒度调度:
class FineGrainedTransMamba(nn.Module):
def forward(self, x, attention_mask=None):
B, L, D = x.shape
# Token级别的模式预测
mode_logits = self.mode_predictor(x) # [B, L, 1]
mode_probs = torch.softmax(mode_logits, dim=1)
# 软模式混合
soft_mode = mode_probs.squeeze(-1) # [B, L]
# 分离处理
attn_out = self.attention_branch(x)
ssm_out = self.ssm_branch(x)
# 加权融合
weights = torch.stack([1 - soft_mode, soft_mode], dim=-1)
outputs = torch.stack([attn_out, ssm_out], dim=-1)
out = (outputs * weights.unsqueeze(-1)).sum(dim=-1)
return out4. 实验结果
4.1 标准基准测试
| 模型 | 400M参数 | 1.5B参数 | ||
|---|---|---|---|---|
| PPL | 加速 | PPL | 加速 | |
| Transformer | 18.2 | 1.0x | 14.1 | 1.0x |
| Mamba-2 | 17.8 | 2.1x | 13.6 | 2.3x |
| Hybrid | 17.5 | 1.5x | 13.3 | 1.6x |
| TransMamba | 17.1 | 1.8x | 13.0 | 1.9x |
4.2 长文本基准测试
在LongBench-v2上的结果:
| 任务 | TransMamba | Hybrid | Delta |
|---|---|---|---|
| NarrativeQA | 38.2% | 36.8% | +1.4 |
| QMSum | 23.4% | 22.1% | +1.3 |
| MultiFieldQA | 44.2% | 42.9% | +1.3 |
| 平均 | 35.3% | 33.9% | +1.4 |
4.3 训练效率
| 指标 | Transformer | TransMamba | 改善 |
|---|---|---|---|
| FLOPs (T=8K) | 100% | 43.6% | -56.4% |
| 训练时间 | 100% | 75% | -25% |
| 显存占用 | 100% | 78% | -22% |
5. 与其他方法的对比
5.1 参数效率
| 方法 | 参数量 | 有效参数利用率 |
|---|---|---|
| Transformer | 1.0x | 100% |
| Mamba | 1.0x | ~95% |
| Hybrid | 1.1x | ~90% |
| TransMamba | 1.0x | ~98% |
5.2 架构对比
| 特性 | Transformer | Mamba | Hybrid | TransMamba |
|---|---|---|---|---|
| 注意力 | ✅ | ❌ | 部分 | 动态 |
| SSM | ❌ | ✅ | 部分 | 动态 |
| 参数共享 | ❌ | ❌ | ❌ | ✅ |
| 无损转换 | - | - | ❌ | ✅ |
| 动态调度 | ❌ | ❌ | 固定 | ✅ |
6. 实现细节
6.1 代码结构
# TransMamba实现结构
class TransMambaModel(nn.Module):
def __init__(self, config):
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_dim)
self.layers = nn.ModuleList([
TransMambaBlock(config)
for _ in range(config.num_layers)
])
self.norm = nn.LayerNorm(config.hidden_dim)
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size)
def forward(self, input_ids, mode_schedule=None):
x = self.embeddings(input_ids)
for i, layer in enumerate(self.layers):
mode = mode_schedule[i] if mode_schedule else 'attention'
x = layer(x, mode=mode)
return self.norm(x)6.2 训练配置
config = {
"model_type": "transmamba",
"hidden_dim": 2048,
"intermediate_dim": 5632,
"num_layers": 24,
"d_state": 128,
"transpoint_cycle": 8,
"transpoint_offset": 2,
"vocab_size": 50257,
}7. 总结
TransMamba通过序列级统一实现了Transformer和Mamba的真正融合:
- 参数共享:QKV↔CBx复用,减少冗余
- 无损转换:Memory Converter保持信息流
- 动态调度:TransPoint实现最优模式切换
- 效率提升:25%训练速度改善
TransMamba代表了混合架构设计的新方向:不是简单堆叠,而是深度统一。
参考资料
相关文档:[[mamba-2-state-space-duality-deep-theory)、[mamba-2-hybrid-architecture-design)、[hybrid-ssm-transformer]]
Footnotes
-
Li, Y. et al. (2025). TransMamba: A Sequence-Level Hybrid Transformer-Mamba. arXiv:2503.24067. ↩