Transformer-SSM 混合架构理论

近年来,Transformer状态空间模型(SSM) 的混合架构成为 LLM 研究的热点。Transformer 具有强大的全局注意力能力,但计算复杂度为 ;SSM(如 Mamba)具有 的线性复杂度,但长程依赖建模较弱。混合架构旨在结合两者优势。本文系统梳理混合架构的理论基础、设计模式、代表模型和最新进展。


1. 动机:为什么需要混合

1.1 Transformer 的优势与不足

优势

  • 全局注意力:单层即可建模任意距离的依赖
  • 成熟的生态系统:训练、推理、优化工具完善
  • 强大的 in-context learning 能力

不足

  • 计算复杂度,长序列昂贵
  • KV Cache 内存:推理时与序列长度线性增长
  • 外推困难:训练长度有限时泛化到长序列挑战

1.2 SSM 的优势与不足

优势

  • 线性复杂度(取决于实现)
  • 固定状态:推理时内存恒定
  • 循环表示:天然支持无限长度

不足

  • 表达力限制:对某些复杂模式建模困难
  • 检索能力弱:在需要精确查找的任务上劣于注意力
  • 训练复杂:选择性 SSM 等的稳定性挑战

1.3 互补性分析

能力TransformerSSM
全局注意力✅ 强⚠️ 中
长程记忆⚠️ KV Cache✅ 固定状态
精确检索✅ 强❌ 弱
复制任务✅ 强⚠️ 中
计算效率❌ O(n²)✅ O(n)
状态空间建模⚠️ 中✅ 强

结论:两者高度互补,混合架构有理论动机。


2. SSM 基础回顾

2.1 连续状态空间

连续时间 SSM

其中:

  • :隐藏状态
  • :输入
  • :参数矩阵

2.2 离散化

零阶保持

离散 SSM

2.3 选择性 SSM(Mamba)

核心创新:让 依赖输入:

选择性扫描:使用并行扫描算法高效计算。

2.4 SSM 与注意力的关系

关键洞察(Ali et al. 20251):

Mamba 层隐式地实现了某种注意力

形式化:选择性 SSM 可重写为:

其中 输入依赖的权重


3. 混合模式分类

3.1 层级混合(Layer-level Hybrid)

定义:在同一模型中,部分层用 Transformer,部分层用 SSM。

[Transformer] → [Transformer] → [SSM] → [SSM] → [Transformer] → ...

代表

  • Jamba(AI21)
  • Zamba
  • Samba
  • Hymba

3.2 序列级混合(Sequence-level Hybrid)

定义:将输入序列分成段,不同段使用不同模型。

token 1-512  → Transformer
token 513-1024 → SSM
token 1025-1536 → Transformer
...

代表

  • TransMamba(Tencent)
  • RecurrentGemma

3.3 Token 级混合(Token-level Hybrid)

定义:每个 token 由 Transformer 或 SSM 处理(通过路由)。

token 1 → Transformer
token 2 → SSM
token 3 → Transformer
token 4 → SSM

代表

  • MoE-Mamba
  • Switch SSM

3.4 块级混合(Block-level Hybrid)

定义:在同一层内,Transformer 和 SSM 并行处理,然后合并。

         ┌─ Transformer ─┐
input → ├─ SSM ──────────├─ merge → output
         └─ 其他 ────────┘

代表

  • Jamba 的某些变体

4. 代表模型详解

4.1 Jamba(AI21)

核心论文:Lieber et al. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv:2403.19887.

架构

每 8 层中有:
- 7 层 Transformer
- 1 层 Mamba

核心组件

  • Transformer 块:标准注意力 + MoE FFN
  • Mamba 块:选择性 SSM + MoE FFN
  • 共享 MoE:跨块使用

配置(Jamba 1.5 Large):

  • 52 层(Transformer:Mamba = 7:1)
  • 94B 总参数(12B 激活)
  • 256K 上下文

4.2 Jamba 架构实现

class JambaBlock(nn.Module):
    """Jamba 风格的混合 Block"""
    
    def __init__(self, d_model, num_heads, d_ff, mamba_d_state, mamba_d_conv,
                 use_attention=True, use_mamba=True, use_moe=True):
        super().__init__()
        
        self.use_attention = use_attention
        self.use_mamba = use_mamba
        
        if use_attention:
            self.attn = MultiHeadAttention(d_model, num_heads)
            self.ln_attn = RMSNorm(d_model)
        
        if use_mamba:
            self.mamba = MambaBlock(d_model, mamba_d_state, mamba_d_conv)
            self.ln_mamba = RMSNorm(d_model)
        
        # MoE FFN
        self.ln_ffn = RMSNorm(d_model)
        self.ffn = MoEFFN(d_model, d_ff) if use_moe else SwiGLU(d_model, d_ff)
    
    def forward(self, x, mask=None):
        if self.use_attention:
            x = x + self.attn(self.ln_attn(x), mask=mask)
        
        if self.use_mamba:
            x = x + self.mamba(self.ln_mamba(x))
        
        x = x + self.ffn(self.ln_ffn(x))
        return x

4.3 Mamba-2

核心论文:Dao & Gu (2024). Transformers are SSMs.

关键洞察:Transformer 的注意力可视为 SSM 的特例。

形式化

SSD(State Space Duality):Mamba-2 利用这种对偶性。

4.4 TransMamba

核心论文:Li et al. (2025). TransMamba: A Sequence-Level Hybrid Transformer-Mamba Language Model. arXiv:2503.24067.

架构

  • 早期层用 SSM(捕获局部模式)
  • 后层用 Transformer(全局推理)

优点

  • 早期层的线性复杂度降低总计算
  • 后层的全局注意力保持质量

4.5 Hymba

核心创新:使用混合头

  • 部分头是注意力
  • 部分头是 SSM
  • 通过学习路由

4.6 混合架构参数对比

模型总参数激活参数上下文Transformer:Mamba 比
Jamba 1.5 Large94B12B256K7:1
Zamba 7B7B7B32K6:1
RecurrentGemma9B9B8K5:1
TransMamba7B7B32K1:1 (段级)
Hymba-1.5B1.5B1.5B8K头级混合

5. 系统设计洞察

5.1 核心论文

Bae, Acun, Lin, Habeeb, Kim, Luo, Wang, Wu (2026). Hybrid Architectures for Language Models: Systematic Analysis and Design Insights. arXiv:2510.04800.2

5.2 关键实验发现

实验 1:混合比例

  • 最佳比例:Transformer 占比 70-85%
  • 极端比例(纯 Transformer 或纯 SSM)效果差

实验 2:放置位置

  • 最佳位置:Transformer 在中后层(处理高阶推理)
  • SSM 在前后层(处理局部和长程模式)

实验 3:训练稳定性

  • 混合架构训练比纯 Transformer 更稳定
  • SSM 提供”梯度平滑”效应

5.3 设计原则

基于系统分析,作者提出以下设计原则:

原则 1:注意力集中在中层

中层(占总层数 30-70%)使用 Transformer,捕获复杂模式。

原则 2:SSM 承担长程依赖

SSM 适合处理远距离的”长程上下文”信息。

原则 3:共享参数

注意力头和 SSM 状态可共享部分参数。

原则 4:避免局部注意力

与 SSM 相比,局部注意力通常效果差。

5.4 系统实验结果

┌────────────────────────────────────────┐
│  任务类型    │ 纯 Transformer │ 混合 │ 纯 SSM │
├────────────────────────────────────────┤
│  语言建模    │     ★★★★      │ ★★★★★│  ★★★  │
│  长上下文    │     ★★★       │ ★★★★★│  ★★★★ │
│  检索任务    │     ★★★★★     │ ★★★★ │  ★★   │
│  推理任务    │     ★★★★★     │ ★★★★ │  ★★★  │
│  训练速度    │     ★★        │ ★★★★ │  ★★★★★│
└────────────────────────────────────────┘

6. 理论分析

6.1 表达能力

定理 6.1:单层选择性 SSM 与单层线性注意力等价(在某些参数化下)。

含义:SSM 与注意力不是完全不同的模型,而是同一家族的成员

6.2 计算复杂度

Transformer

  • 注意力矩阵:
  • 矩阵乘法:

SSM

  • 选择性扫描:
  • 矩阵乘法:(主导项)

混合

  • 其中 是 Transformer 比例

6.3 内存复杂度

Transformer 推理

  • 模型参数:
  • KV Cache:(随长度增长)

SSM 推理

  • 模型参数:
  • 状态:(固定)

混合推理

  • 模型参数:
  • 部分 KV Cache + 固定状态

6.4 长程依赖能力

关键定理(形式化):

Transformer 在单层内可编码任意位置对的依赖。
SSM 通过循环隐藏状态可编码固定长度的依赖。

含义

  • Transformer:理论无限长程(实际受限于训练长度)
  • SSM:固定窗口长程(与状态维度 相关)

混合:Transformer 处理任意距离 + SSM 巩固长程记忆。


7. Mamba 的隐藏注意力

7.1 核心论文

Ali, Zimerman, Wolf (2025). The Hidden Attention of Mamba Models. ACL 2025.3

7.2 关键发现

Mamba 在内部实现了类似注意力的机制

形式化:Mamba 的输出可重写为:

其中 输入依赖的权重。

与注意力的差异

  • 注意力:
  • Mamba 隐藏注意力: 通过 SSM 循环动态生成

7.3 实验验证

研究方法

  1. 在 Mamba 模型上做”注意力可视化”
  2. 找出 与真正注意力的相似度

发现

  • Mamba 的”隐藏注意力”确实存在
  • 模式与真正的注意力部分相似
  • 效率不同

7.4 含义

理论含义

  • 纯 SSM 与混合架构不是非此即彼
  • 某些 SSM 已经隐式包含”注意力”

实践含义

  • 即使纯 SSM 模型也有”长程能力”
  • 混合架构的边界模糊

8. 训练动力学

8.1 混合架构的训练优势

稳定性

  • SSM 提供梯度平滑
  • Transformer 提供表达力
  • 组合效果更好

收敛速度

  • 通常比纯 Transformer 更快收敛到同等性能

8.2 损失景观

混合架构的损失景观

  • 较”平坦”
  • 局部最小值更少
  • 对超参数不敏感

8.3 训练策略

class HybridTrainingStrategy:
    """混合架构训练策略"""
    
    def __init__(self, model):
        self.model = model
    
    def get_layer_learning_rates(self):
        """为不同层设置不同学习率"""
        rates = []
        for layer in self.model.layers:
            if isinstance(layer, TransformerBlock):
                # Transformer 层使用稍低学习率
                rates.append(0.9 * self.base_lr)
            else:  # MambaBlock
                rates.append(self.base_lr)
        return rates
    
    def selective_warmup(self, step):
        """选择性 warmup"""
        if step < 1000:
            # 只训练 SSM 部分
            for p in self.model.transformer_params():
                p.requires_grad = False
        else:
            # 训练全部
            for p in self.model.parameters():
                p.requires_grad = True

9. 推理优化

9.1 推理流程

Transformer 部分

  • 标准 KV Cache 管理
  • Flash Attention 加速

SSM 部分

  • 固定状态维护
  • 并行扫描

9.2 内存优化

class HybridInferenceEngine:
    """混合架构推理引擎"""
    
    def __init__(self, model, max_seq_len=128000):
        self.model = model
        self.max_seq_len = max_seq_len
        
        # KV Cache 仅用于 Transformer 部分
        self.kv_cache = {}
        # SSM 状态固定
        self.ssm_states = {}
    
    def forward(self, x, past_state=None):
        for i, layer in enumerate(self.model.layers):
            if isinstance(layer, TransformerBlock):
                # 标准 transformer 推理 + KV Cache
                x, kv = layer(x, past_kv=self.kv_cache.get(i))
                self.kv_cache[i] = kv
            else:
                # SSM 推理 + 状态
                x, state = layer(x, state=self.ssm_states.get(i))
                self.ssm_states[i] = state
        return x

9.3 吞吐量优势

典型结果

  • Jamba 与同规模 Transformer 相比:2-3x 吞吐量提升
  • 主要来自长上下文场景

10. 评估混合架构

10.1 评估基准

任务评估指标
长上下文检索needle-in-haystack
长上下文推理LongBench, RULER
语言建模perplexity
标准基准MMLU, HellaSwag
推理速度tokens/second

10.2 评估代码

def evaluate_hybrid_model(model, eval_datasets):
    """评估混合架构模型"""
    results = {}
    
    for name, dataset in eval_datasets.items():
        if 'long' in name.lower():
            # 长上下文评估
            results[name] = evaluate_long_context(model, dataset)
        elif 'reasoning' in name.lower():
            # 推理评估
            results[name] = evaluate_reasoning(model, dataset)
        else:
            # 标准评估
            results[name] = evaluate_standard(model, dataset)
    
    return results

11. 设计实践指南

11.1 选择混合比例

基于任务

def recommend_hybrid_ratio(target_task, target_seq_len):
    """根据任务推荐混合比例"""
    
    if target_task in ['retrieval', 'precise_lookup']:
        # 检索任务需要更多注意力
        return {'transformer_ratio': 0.8, 'ssm_ratio': 0.2}
    
    elif target_task in ['long_context_summarization']:
        # 长上下文任务需要 SSM
        return {'transformer_ratio': 0.5, 'ssm_ratio': 0.5}
    
    elif target_seq_len > 100000:
        # 超长序列
        return {'transformer_ratio': 0.3, 'ssm_ratio': 0.7}
    
    else:
        # 默认
        return {'transformer_ratio': 0.7, 'ssm_ratio': 0.3}

11.2 层放置策略

经验法则

  • 浅层(25%):SSM(局部模式)
  • 中层(50%):Transformer(核心推理)
  • 深层(25%):Transformer + SSM(输出整合)

11.3 超参数选择

参数Transformer 部分SSM 部分
学习率标准标准(可略高)
Warmup 步数标准较少
初始化标准谨慎(HiPPO 等)

12. 未来方向

12.1 待解决问题

  1. 最优混合比例:是否有理论指导?
  2. 动态路由:是否能根据输入自适应选择?
  3. 统一架构:能否设计一个统一框架容纳两者?

12.2 潜在方向

方向 1:状态空间注意力

将注意力视为 SSM 的特例,统一两者。

方向 2:神经架构搜索

自动搜索最优混合架构。

方向 3:模块化设计

设计可插拔的混合模块。


13. 关键论文清单

基础理论

  1. Gu, Goel, Ré (2022) — Efficiently Modeling Long Sequences with Structured State Spaces (S4)
  2. Gu & Dao (2023) — Mamba
  3. Dao & Gu (2024) — Transformers are SSMs (Mamba-2)

混合架构

  1. Lieber et al. (2024) — Jamba
  2. Li et al. (2025) — TransMamba
  3. Bae et al. (2026) — Hybrid Architectures: Systematic Analysis (FAIR Meta)

理论分析

  1. Ali et al. (2025) — Hidden Attention of Mamba (ACL 2025)
  2. Wang et al. (2024) — State Space Duality Theory

实践

  1. NVIDIA (2024) — HybridSSM
  2. Microsoft (2025) — Hymba

14. 与相关专题的连接

14.1 Transformer 架构专题

14.2 Mamba/SSM 相关

14.3 应用


最后更新:2026-06-21

Footnotes

  1. Ali, Zimerman, Wolf (2025). The Hidden Attention of Mamba Models. ACL 2025.

  2. Bae et al. (2026). Hybrid Architectures: Systematic Analysis. arXiv:2510.04800. (FAIR at Meta)

  3. Ali et al. (2025). Hidden Attention of Mamba. ACL 2025.