1. 引言

状态空间模型(State Space Models, SSM)特别是Mamba架构近年来引起了广泛关注。理解SSM的表达能力对于设计更好的序列模型至关重要。

本文从多项式表达能力的角度分析选择性SSM(S6/Mamba)的理论极限,揭示其与Transformer的根本差异。1

2. 从LTI SSM到选择性SSM

2.1 经典LTI SSM

经典的线性时不变(Linear Time-Invariant, LTI)SSM定义为:

其中:

  • :输入
  • :隐藏状态
  • :输出
  • 固定的参数矩阵

2.2 LTI SSM的局限性

LTI SSM的参数不随输入变化,导致:

  • 表达能力受限:无法根据输入内容选择性地处理信息
  • 线性核限制:只能学习线性时不变关系
  • 与RNN类似:缺乏注意力机制的灵活性

2.3 Mamba:选择性SSM

Mamba引入了选择性机制,使参数成为输入的函数:

关键区别:所有参数现在都依赖于输入

3. 多项式表达能力分析

3.1 展开递归

将选择性SSM展开 步:

3.2 多项式形式

核心发现:展开后的S6输出是输入的多变量多项式

具体地,对于单通道情况():

这是关于 次数为 的多变量多项式

3.3 定理

定理 3.1(多项式次数):S6层输出的多项式次数为序列长度 的函数:

相比之下,线性注意力的输出是3次多项式

3.4 表达能力差距

模型单层多项式次数达到 次多项式所需层数
线性注意力3
S6/Mamba1
LTI SSM1

4. 理论深度:为什么选择性强?

4.1 门控机制的作用

选择性的核心是输入依赖的门控

这引入了乘法交互

多次乘法导致多项式次数指数增长。

4.2 对比:线性注意力

线性注意力使用核函数近似:

其中 是特征映射。由于使用加法而非乘法,表达能力受限。

4.3 时间尺度的作用

的选择性控制信息的时间尺度:

  • :快速衰减,只关注近期信息
  • :缓慢衰减,保留长期依赖

这使得S6可以自适应地选择关注的时间范围。

5. 表达能力上界:TC⁰

5.1 电路复杂度回顾

复杂度类电路深度门类型
AC⁰AND, OR, NOT
TC⁰AND, OR, MAJORITY
NC¹AND, OR, NOT
P多项式任意

5.2 主要定理

定理 5.1(TC⁰上界):S4和Mamba都在TC⁰电路复杂度类中。2

即:存在深度 的阈值电路模拟任何 步的SSM。

5.3 推论

  1. 多项式可解:SSM无法高效解决需要超过TC⁰计算的问题
  2. 与Transformer关系:Transformer(Softmax注意力)也在TC⁰中
  3. 表达能力相似:在电路复杂度层面,SSM和Transformer表达能力相当

6. 形式语言视角

6.1 无穷语言识别

定理 6.1(Star-Free语言):SSM可以实现所有star-free正则语言的长度泛化识别。3

Star-free语言包括:

  • 有序语言:
  • 布尔语言:补运算封闭的正则语言
  • 模计数语言:

6.2 与Transformer的对比

语言类型TransformerSSM
有序语言 困难容易
模计数语言困难容易
交换语言容易困难
排列不变性容易困难

6.3 选择性设计的意义

非选择性SSM(S4)由于其线性时间不变性,无法区分:

选择性机制使Mamba可以根据内容选择,从而实现有序依赖。

7. 通用多项式表达

7.1 主要结果

定理 7.1(通用性):4层堆叠的Mamba模型可以表示任意有界次数的多变量多项式

即:对于任何多项式 ,存在一个4层Mamba使得:

7.2 构造性证明

4层的通用多项式表达能力来自:

  1. 层1:计算所有单项式
  2. 层2:通过乘法组合单项式形成任意单项式
  3. 层3:通过加法组合单项式形成任意多项式
  4. 层4:应用非线性输出变换

7.3 与深度的关系

深度表达能力
1层次数 的多项式
2层任意单项式
3层任意多项式(有限项)
4层任意有界次多项式

8. 实际应用:关联召回

8.1 MQAR任务

多查询关联召回(Multi-Query Associative Recall, MQAR)是测试序列模型的关键任务:

任务定义:给定键值对序列:

学习检索:给定查询 ,输出对应的

8.2 Mamba的优势

定理 8.1(MQA召回):S6层可以动态对抗记忆衰减4

关键在于选择性机制允许:

  • 对相关键产生大的注意力权重
  • 对无关键产生小的注意力权重

这通过非线性门控实现,而非简单的线性衰减。

8.3 线性注意力的劣势

线性注意力的递归形式:

这导致:

  • 所有历史键被同等对待
  • 无法动态选择相关信息
  • 记忆以固定速度衰减

9. 表达能力的权衡

9.1 效率 vs 表达

模型时间复杂度空间复杂度多项式次数
完整注意力任意
线性注意力3
S6/Mamba
LTI SSM1

9.2 选择性设计

选择性的代价:

  • 参数数量增加(输入依赖)
  • 计算开销略增(需要额外的投影)
  • 但仍然是线性复杂度

10. 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class SelectiveSSM(nn.Module):
    """
    Mamba/S6选择性状态空间模型
    """
    def __init__(self, d_model, d_state=16, d_conv=4, dt_min=0.001, dt_max=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # 输入投影
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1, bias=False)
        
        # 状态矩阵A(可学习参数)
        self.A = nn.Parameter(torch.randn(d_state, d_model))
        
        # D(跳连接)
        self.D = nn.Parameter(torch.ones(d_model))
        
        # 卷积
        self.d_conv = d_conv
        self.conv = nn.Conv1d(
            d_model, d_model, d_conv,
            padding=d_conv - 1, groups=d_model
        )
        
        # 输出投影
        self.out_proj = nn.Linear(d_model, d_model)
        
        # dt范围
        self.dt_min = dt_min
        self.dt_max = dt_max
    
    def forward(self, x):
        """
        x: (batch, seq_len, d_model)
        """
        batch, seq_len, d = x.shape
        device = x.device
        
        # 1. 输入投影得到B, C, Δ
        x_proj = self.x_proj(x)  # (B, L, d_state*2 + 1)
        
        B = x_proj[:, :, :self.d_state]  # (B, L, d_state)
        C = x_proj[:, :, self.d_state:self.d_state*2]  # (B, L, d_state)
        dt = x_proj[:, :, -1]  # (B, L)
        
        # 2. Δ: softplus + 裁剪
        dt = F.softplus(dt)
        dt = torch.clamp(dt, min=self.dt_min, max=self.dt_max)
        
        # 3. 离散化: A_bar = exp(Δ ⊗ A)
        # A: (d_state, d_model), dt: (B, L) -> ΔA: (B, L, d_state, d_model)
        # 使用近似: A_bar ≈ I + ΔA
        A_bar = torch.exp(dt.unsqueeze(-1) * self.A)  # (d_state, d_model)
        
        # 4. 投影 B, C
        B = B.unsqueeze(-1) * x.unsqueeze(2)  # (B, L, d_state, d_model) @ (B, L, d_model, 1)
        B = B.sum(dim=-1)  # (B, L, d_state)
        
        C = C.unsqueeze(-1) * x.unsqueeze(2)
        C = C.sum(dim=-1)  # (B, L, d_state)
        
        # 5. 扫描计算隐藏状态
        h = torch.zeros(batch, self.d_state, device=device)
        outputs = []
        
        for t in range(seq_len):
            # h_t = A_bar[t] * h_{t-1} + B[t] * x[t]
            h = A_bar[t] * h + B[:, t, :] * x[:, t, :]
            
            # y_t = C[t] * h + D * x[t]
            y_t = (C[:, t, :] * h).sum(dim=-1) + self.D * x[:, t, :]
            outputs.append(y_t)
        
        y = torch.stack(outputs, dim=1)  # (B, L, d_model)
        
        return self.out_proj(y)
 
 
class PolynomialComplexityAnalyzer:
    """
    分析SSM的多项式表达能力
    """
    @staticmethod
    def compute_polynomial_degree(ssm_layer):
        """
        估算SSM层输出的多项式次数
        """
        # 对于S6: 次数 ≈ 序列长度 + 3
        # 对于线性注意力: 次数 = 3
        # 对于LTI SSM: 次数 = 1
        return {
            's6': lambda seq_len: seq_len + 3,
            'linear_attention': lambda seq_len: 3,
            'lti_ssm': lambda seq_len: 1
        }
    
    @staticmethod
    def layers_for_degree(target_degree, model_type='s6'):
        """
        计算达到目标多项式次数所需的层数
        """
        if model_type == 's6':
            # 单层即可达到任意次数
            return 1
        elif model_type == 'linear_attention':
            # 需要log层达到L次
            import math
            return math.ceil(math.log2(target_degree / 3))
        elif model_type == 'lti_ssm':
            return target_degree
        else:
            raise ValueError(f"Unknown model type: {model_type}")
 
 
def test_polynomial_expressivity():
    """
    测试不同SSM变体的多项式表达能力
    """
    torch.manual_seed(42)
    
    seq_len = 16
    d_model = 32
    d_state = 8
    
    # 创建模型
    ssm = SelectiveSSM(d_model, d_state)
    ssm.eval()
    
    # 生成测试数据
    x = torch.randn(2, seq_len, d_model)
    
    with torch.no_grad():
        y = ssm(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")
    print(f"Output range: [{y.min():.4f}, {y.max():.4f}]")
    
    # 分析多项式次数
    analyzer = PolynomialComplexityAnalyzer()
    s6_degree = analyzer.compute_polynomial_degree('s6')(seq_len)
    la_degree = analyzer.compute_polynomial_degree('linear_attention')(seq_len)
    
    print(f"\nS6 polynomial degree: O({s6_degree})")
    print(f"Linear attention polynomial degree: O({la_degree})")
    print(f"Speedup factor: {s6_degree / la_degree:.1f}x more expressive")
 
 
if __name__ == '__main__':
    test_polynomial_expressivity()

11. 与现有理论的关系

11.1 与TC⁰电路复杂度

S6的多项式表达能力与其TC⁰电路复杂度上界一致:

  • TC⁰电路可以计算阈值函数
  • 多项式可以通过阈值函数构造
  • 因此表达能力等价

11.2 与状态跟踪

定理:S6无法高效跟踪任意复杂状态:

这意味着对于需要强计算能力的状态跟踪任务,SSM仍然受限。

11.3 与Transformer的关系

性质S6/MambaTransformer
时间复杂度
空间复杂度
多项式次数任意
TC⁰能力
长度泛化弱(某些任务)

12. 总结

从多项式表达能力的角度分析选择性SSM,揭示了以下关键发现:

  1. 指数级表达差距:S6单层可表达 次多项式,而线性注意力仅能表达3次
  2. 通用多项式表达:4层S6可以表达任意有界次多项式
  3. TC⁰上界:S6的计算能力在TC⁰复杂度类中
  4. 形式语言优势:S6在star-free语言和有序依赖任务上优于Transformer
  5. 实践权衡:选择性强带来了表达能力提升,但仍然是线性复杂度

这些发现为理解和设计下一代状态空间模型提供了理论基础。


Footnotes

  1. Cohen-Karlik, E., et al. “On the Expressivity of Selective State-Space Layers: A Multivariate Polynomial Approach.” arXiv 2025. https://arxiv.org/abs/2502.02209

  2. Merrill, W., et al. “Mamba: Linear-Time Sequence Modeling with Selective State Spaces.” arXiv 2023. https://arxiv.org/abs/2312.00752

  3. Sarrof, S., et al. “The Expressive Capacity of State Space Models: A Formal Characterization.” arXiv 2024. https://arxiv.org/abs/2405.17394

  4. 本理论与Mamba表达能力理论状态空间模型密切相关。