概述

TransMamba是一种序列级(Sequence-Level)混合框架,通过共享参数矩阵和Memory Converter机制,在同一模型中统一了Transformer和Mamba两种范式。1

核心创新:

  1. 参数共享:QKV与CBx共享同一组参数
  2. Memory Converter:无损信息转换机制
  3. TransPoint调度:动态切换Attention/SSM模式
  4. 训练效率:相比Transformer提升25%

1. 设计动机

1.1 现有混合方法的局限

方法代表模型局限
并行混合Jamba需要独立的两套参数
串行混合Hybrid参数量增加
交替混合Mamba-Hybrid缺乏深层融合

1.2 TransMamba的洞察

TransMamba的核心洞察:Transformer和Mamba本质上是同一计算图的不同分解

  • Transformer,其中 来自输入的独立投影
  • Mamba,其中状态由输入通过 投影构造

参数共享的可能性


2. 架构设计

2.1 整体结构

TransMamba的层结构:

TransMamba Block
├── Input Norm
├── Attention Mode
│   ├── QKV Projection (共享)
│   ├── RoPE
│   ├── Flash Attention
│   └── Output Projection
├── Memory Converter (← 关键)
├── SSM Mode
│   ├── CBx Projection (共享)
│   ├── SSM Discretization
│   ├── SSD Scan
│   └── Output Projection
└── Output Norm

2.2 参数共享机制

共享策略:同一组参数在不同模式下复用

class TransMambaBlock(nn.Module):
    def __init__(self, d_model, d_state=128):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # 共享的输入投影
        # Attention模式: Q, K, V
        # SSM模式: C, B, x
        self.shared_proj = nn.Linear(d_model, d_model * 3)
        
        # 独立参数
        self.rope = RoPE(d_model)
        self.ssm_gate = nn.Parameter(torch.ones(d_model))
        
    def forward(self, x, mode='attention'):
        # 共享投影
        qkv = self.shared_proj(x)  # [B, L, 3D]
        q, k, v = qkv.chunk(3, dim=-1)
        
        if mode == 'attention':
            # Attention模式
            q = self.rope(q)
            k = self.rope(k)
            out = self.attention(q, k, v)
        else:
            # SSM模式
            c = q[:, :, :self.d_state]  # Q → C
            b = k[:, :, :self.d_state]  # K → B
            x_ssm = v                    # V → x
            out = self.ssm_scan(x_ssm, c, b)
        
        return out

2.3 Memory Converter

问题:QKV投影和CBx投影的维度空间不同

  • QKV空间: → 需要分离Q, K, V
  • CBx空间: → 需要构造状态

解决方案:Memory Converter进行无损信息转换

class MemoryConverter(nn.Module):
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # 维度映射
        self.q_to_c = nn.Linear(d_model, d_state)
        self.k_to_b = nn.Linear(d_model, d_state)
        self.v_adapter = nn.Linear(d_model, d_model)
        
        # 状态门控
        self.state_gate = nn.Parameter(torch.ones(d_state))
        
    def forward(self, q, k, v):
        # Q → C (Attention → SSM)
        c = self.q_to_c(q) * torch.sigmoid(self.state_gate)
        
        # K → B
        b = self.k_to_b(k) * torch.sigmoid(self.state_gate)
        
        # V → x (适配维度)
        x = self.v_adapter(v)
        
        return c, b, x

3. TransPoint调度策略

3.1 模式切换机制

TransMamba支持动态模式切换

class TransPointScheduler:
    def __init__(self, num_layers, cycle_length=8, offset_start=2):
        """
        Args:
            cycle_length: 切换周期 (论文建议8)
            offset_start: SSM开始的偏移 (论文建议2)
        """
        self.num_layers = num_layers
        self.cycle_length = cycle_length
        self.offset_start = offset_start
        
    def get_mode(self, layer_idx):
        """返回当前层应该使用的模式"""
        adjusted_idx = layer_idx + self.offset_start
        
        if adjusted_idx % self.cycle_length == 0:
            return 'ssm'
        else:
            return 'attention'
    
    def create_schedule(self):
        """生成完整的模式调度表"""
        schedule = []
        for i in range(self.num_layers):
            mode = self.get_mode(i)
            schedule.append({
                'layer': i,
                'mode': mode,
                'memory_converter': mode == 'attention'  # 只在切换时使用
            })
        return schedule

3.2 调度模式分析

周期长度Attention层占比SSM层占比性能
475%25%接近Transformer
887.5%12.5%最优
1693.75%6.25%略下降
连续Attention100%0%Transformer基线

3.3 细粒度调度

TransMamba还支持Token级别的细粒度调度

class FineGrainedTransMamba(nn.Module):
    def forward(self, x, attention_mask=None):
        B, L, D = x.shape
        
        # Token级别的模式预测
        mode_logits = self.mode_predictor(x)  # [B, L, 1]
        mode_probs = torch.softmax(mode_logits, dim=1)
        
        # 软模式混合
        soft_mode = mode_probs.squeeze(-1)  # [B, L]
        
        # 分离处理
        attn_out = self.attention_branch(x)
        ssm_out = self.ssm_branch(x)
        
        # 加权融合
        weights = torch.stack([1 - soft_mode, soft_mode], dim=-1)
        outputs = torch.stack([attn_out, ssm_out], dim=-1)
        out = (outputs * weights.unsqueeze(-1)).sum(dim=-1)
        
        return out

4. 实验结果

4.1 标准基准测试

模型400M参数1.5B参数
PPL加速PPL加速
Transformer18.21.0x14.11.0x
Mamba-217.82.1x13.62.3x
Hybrid17.51.5x13.31.6x
TransMamba17.11.8x13.01.9x

4.2 长文本基准测试

在LongBench-v2上的结果:

任务TransMambaHybridDelta
NarrativeQA38.2%36.8%+1.4
QMSum23.4%22.1%+1.3
MultiFieldQA44.2%42.9%+1.3
平均35.3%33.9%+1.4

4.3 训练效率

指标TransformerTransMamba改善
FLOPs (T=8K)100%43.6%-56.4%
训练时间100%75%-25%
显存占用100%78%-22%

5. 与其他方法的对比

5.1 参数效率

方法参数量有效参数利用率
Transformer1.0x100%
Mamba1.0x~95%
Hybrid1.1x~90%
TransMamba1.0x~98%

5.2 架构对比

特性TransformerMambaHybridTransMamba
注意力部分动态
SSM部分动态
参数共享
无损转换--
动态调度固定

6. 实现细节

6.1 代码结构

# TransMamba实现结构
class TransMambaModel(nn.Module):
    def __init__(self, config):
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.layers = nn.ModuleList([
            TransMambaBlock(config) 
            for _ in range(config.num_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size)
        
    def forward(self, input_ids, mode_schedule=None):
        x = self.embeddings(input_ids)
        
        for i, layer in enumerate(self.layers):
            mode = mode_schedule[i] if mode_schedule else 'attention'
            x = layer(x, mode=mode)
        
        return self.norm(x)

6.2 训练配置

config = {
    "model_type": "transmamba",
    "hidden_dim": 2048,
    "intermediate_dim": 5632,
    "num_layers": 24,
    "d_state": 128,
    "transpoint_cycle": 8,
    "transpoint_offset": 2,
    "vocab_size": 50257,
}

7. 总结

TransMamba通过序列级统一实现了Transformer和Mamba的真正融合:

  1. 参数共享:QKV↔CBx复用,减少冗余
  2. 无损转换:Memory Converter保持信息流
  3. 动态调度:TransPoint实现最优模式切换
  4. 效率提升:25%训练速度改善

TransMamba代表了混合架构设计的新方向:不是简单堆叠,而是深度统一


参考资料


相关文档:[[mamba-2-state-space-duality-deep-theory)、[mamba-2-hybrid-architecture-design)、[hybrid-ssm-transformer]]

Footnotes

  1. Li, Y. et al. (2025). TransMamba: A Sequence-Level Hybrid Transformer-Mamba. arXiv:2503.24067.