1. 引言
状态空间模型(State Space Models, SSM)特别是Mamba架构近年来引起了广泛关注。理解SSM的表达能力对于设计更好的序列模型至关重要。
本文从多项式表达能力的角度分析选择性SSM(S6/Mamba)的理论极限,揭示其与Transformer的根本差异。1
2. 从LTI SSM到选择性SSM
2.1 经典LTI SSM
经典的线性时不变(Linear Time-Invariant, LTI)SSM定义为:
其中:
- :输入
- :隐藏状态
- :输出
- :固定的参数矩阵
2.2 LTI SSM的局限性
LTI SSM的参数不随输入变化,导致:
- 表达能力受限:无法根据输入内容选择性地处理信息
- 线性核限制:只能学习线性时不变关系
- 与RNN类似:缺乏注意力机制的灵活性
2.3 Mamba:选择性SSM
Mamba引入了选择性机制,使参数成为输入的函数:
关键区别:所有参数现在都依赖于输入 。
3. 多项式表达能力分析
3.1 展开递归
将选择性SSM展开 步:
3.2 多项式形式
核心发现:展开后的S6输出是输入的多变量多项式。
具体地,对于单通道情况():
这是关于 的次数为 的多变量多项式。
3.3 定理
定理 3.1(多项式次数):S6层输出的多项式次数为序列长度 的函数:
相比之下,线性注意力的输出是3次多项式:
3.4 表达能力差距
| 模型 | 单层多项式次数 | 达到 次多项式所需层数 |
|---|---|---|
| 线性注意力 | 3 | |
| S6/Mamba | 1 | |
| LTI SSM | 1 |
4. 理论深度:为什么选择性强?
4.1 门控机制的作用
选择性的核心是输入依赖的门控:
这引入了乘法交互:
多次乘法导致多项式次数指数增长。
4.2 对比:线性注意力
线性注意力使用核函数近似:
其中 是特征映射。由于使用加法而非乘法,表达能力受限。
4.3 时间尺度的作用
的选择性控制信息的时间尺度:
- 小 :快速衰减,只关注近期信息
- 大 :缓慢衰减,保留长期依赖
这使得S6可以自适应地选择关注的时间范围。
5. 表达能力上界:TC⁰
5.1 电路复杂度回顾
| 复杂度类 | 电路深度 | 门类型 |
|---|---|---|
| AC⁰ | AND, OR, NOT | |
| TC⁰ | AND, OR, MAJORITY | |
| NC¹ | AND, OR, NOT | |
| P | 多项式 | 任意 |
5.2 主要定理
定理 5.1(TC⁰上界):S4和Mamba都在TC⁰电路复杂度类中。2
即:存在深度 的阈值电路模拟任何 步的SSM。
5.3 推论
- 多项式可解:SSM无法高效解决需要超过TC⁰计算的问题
- 与Transformer关系:Transformer(Softmax注意力)也在TC⁰中
- 表达能力相似:在电路复杂度层面,SSM和Transformer表达能力相当
6. 形式语言视角
6.1 无穷语言识别
定理 6.1(Star-Free语言):SSM可以实现所有star-free正则语言的长度泛化识别。3
Star-free语言包括:
- 有序语言:
- 布尔语言:补运算封闭的正则语言
- 模计数语言:
6.2 与Transformer的对比
| 语言类型 | Transformer | SSM |
|---|---|---|
| 有序语言 | 困难 | 容易 |
| 模计数语言 | 困难 | 容易 |
| 交换语言 | 容易 | 困难 |
| 排列不变性 | 容易 | 困难 |
6.3 选择性设计的意义
非选择性SSM(S4)由于其线性时间不变性,无法区分:
- 和
- 和
选择性机制使Mamba可以根据内容选择,从而实现有序依赖。
7. 通用多项式表达
7.1 主要结果
定理 7.1(通用性):4层堆叠的Mamba模型可以表示任意有界次数的多变量多项式。
即:对于任何多项式 ,存在一个4层Mamba使得:
7.2 构造性证明
4层的通用多项式表达能力来自:
- 层1:计算所有单项式
- 层2:通过乘法组合单项式形成任意单项式
- 层3:通过加法组合单项式形成任意多项式
- 层4:应用非线性输出变换
7.3 与深度的关系
| 深度 | 表达能力 |
|---|---|
| 1层 | 次数 的多项式 |
| 2层 | 任意单项式 |
| 3层 | 任意多项式(有限项) |
| 4层 | 任意有界次多项式 |
8. 实际应用:关联召回
8.1 MQAR任务
多查询关联召回(Multi-Query Associative Recall, MQAR)是测试序列模型的关键任务:
任务定义:给定键值对序列:
学习检索:给定查询 ,输出对应的 。
8.2 Mamba的优势
定理 8.1(MQA召回):S6层可以动态对抗记忆衰减。4
关键在于选择性机制允许:
- 对相关键产生大的注意力权重
- 对无关键产生小的注意力权重
这通过非线性门控实现,而非简单的线性衰减。
8.3 线性注意力的劣势
线性注意力的递归形式:
这导致:
- 所有历史键被同等对待
- 无法动态选择相关信息
- 记忆以固定速度衰减
9. 表达能力的权衡
9.1 效率 vs 表达
| 模型 | 时间复杂度 | 空间复杂度 | 多项式次数 |
|---|---|---|---|
| 完整注意力 | 任意 | ||
| 线性注意力 | 3 | ||
| S6/Mamba | |||
| LTI SSM | 1 |
9.2 选择性设计
选择性的代价:
- 参数数量增加(输入依赖)
- 计算开销略增(需要额外的投影)
- 但仍然是线性复杂度
10. 代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelectiveSSM(nn.Module):
"""
Mamba/S6选择性状态空间模型
"""
def __init__(self, d_model, d_state=16, d_conv=4, dt_min=0.001, dt_max=0.1):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# 输入投影
self.x_proj = nn.Linear(d_model, d_state * 2 + 1, bias=False)
# 状态矩阵A(可学习参数)
self.A = nn.Parameter(torch.randn(d_state, d_model))
# D(跳连接)
self.D = nn.Parameter(torch.ones(d_model))
# 卷积
self.d_conv = d_conv
self.conv = nn.Conv1d(
d_model, d_model, d_conv,
padding=d_conv - 1, groups=d_model
)
# 输出投影
self.out_proj = nn.Linear(d_model, d_model)
# dt范围
self.dt_min = dt_min
self.dt_max = dt_max
def forward(self, x):
"""
x: (batch, seq_len, d_model)
"""
batch, seq_len, d = x.shape
device = x.device
# 1. 输入投影得到B, C, Δ
x_proj = self.x_proj(x) # (B, L, d_state*2 + 1)
B = x_proj[:, :, :self.d_state] # (B, L, d_state)
C = x_proj[:, :, self.d_state:self.d_state*2] # (B, L, d_state)
dt = x_proj[:, :, -1] # (B, L)
# 2. Δ: softplus + 裁剪
dt = F.softplus(dt)
dt = torch.clamp(dt, min=self.dt_min, max=self.dt_max)
# 3. 离散化: A_bar = exp(Δ ⊗ A)
# A: (d_state, d_model), dt: (B, L) -> ΔA: (B, L, d_state, d_model)
# 使用近似: A_bar ≈ I + ΔA
A_bar = torch.exp(dt.unsqueeze(-1) * self.A) # (d_state, d_model)
# 4. 投影 B, C
B = B.unsqueeze(-1) * x.unsqueeze(2) # (B, L, d_state, d_model) @ (B, L, d_model, 1)
B = B.sum(dim=-1) # (B, L, d_state)
C = C.unsqueeze(-1) * x.unsqueeze(2)
C = C.sum(dim=-1) # (B, L, d_state)
# 5. 扫描计算隐藏状态
h = torch.zeros(batch, self.d_state, device=device)
outputs = []
for t in range(seq_len):
# h_t = A_bar[t] * h_{t-1} + B[t] * x[t]
h = A_bar[t] * h + B[:, t, :] * x[:, t, :]
# y_t = C[t] * h + D * x[t]
y_t = (C[:, t, :] * h).sum(dim=-1) + self.D * x[:, t, :]
outputs.append(y_t)
y = torch.stack(outputs, dim=1) # (B, L, d_model)
return self.out_proj(y)
class PolynomialComplexityAnalyzer:
"""
分析SSM的多项式表达能力
"""
@staticmethod
def compute_polynomial_degree(ssm_layer):
"""
估算SSM层输出的多项式次数
"""
# 对于S6: 次数 ≈ 序列长度 + 3
# 对于线性注意力: 次数 = 3
# 对于LTI SSM: 次数 = 1
return {
's6': lambda seq_len: seq_len + 3,
'linear_attention': lambda seq_len: 3,
'lti_ssm': lambda seq_len: 1
}
@staticmethod
def layers_for_degree(target_degree, model_type='s6'):
"""
计算达到目标多项式次数所需的层数
"""
if model_type == 's6':
# 单层即可达到任意次数
return 1
elif model_type == 'linear_attention':
# 需要log层达到L次
import math
return math.ceil(math.log2(target_degree / 3))
elif model_type == 'lti_ssm':
return target_degree
else:
raise ValueError(f"Unknown model type: {model_type}")
def test_polynomial_expressivity():
"""
测试不同SSM变体的多项式表达能力
"""
torch.manual_seed(42)
seq_len = 16
d_model = 32
d_state = 8
# 创建模型
ssm = SelectiveSSM(d_model, d_state)
ssm.eval()
# 生成测试数据
x = torch.randn(2, seq_len, d_model)
with torch.no_grad():
y = ssm(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Output range: [{y.min():.4f}, {y.max():.4f}]")
# 分析多项式次数
analyzer = PolynomialComplexityAnalyzer()
s6_degree = analyzer.compute_polynomial_degree('s6')(seq_len)
la_degree = analyzer.compute_polynomial_degree('linear_attention')(seq_len)
print(f"\nS6 polynomial degree: O({s6_degree})")
print(f"Linear attention polynomial degree: O({la_degree})")
print(f"Speedup factor: {s6_degree / la_degree:.1f}x more expressive")
if __name__ == '__main__':
test_polynomial_expressivity()11. 与现有理论的关系
11.1 与TC⁰电路复杂度
S6的多项式表达能力与其TC⁰电路复杂度上界一致:
- TC⁰电路可以计算阈值函数
- 多项式可以通过阈值函数构造
- 因此表达能力等价
11.2 与状态跟踪
定理:S6无法高效跟踪任意复杂状态:
这意味着对于需要强计算能力的状态跟踪任务,SSM仍然受限。
11.3 与Transformer的关系
| 性质 | S6/Mamba | Transformer |
|---|---|---|
| 时间复杂度 | ||
| 空间复杂度 | ||
| 多项式次数 | 任意 | |
| TC⁰能力 | ✓ | ✓ |
| 长度泛化 | 强 | 弱(某些任务) |
12. 总结
从多项式表达能力的角度分析选择性SSM,揭示了以下关键发现:
- 指数级表达差距:S6单层可表达 次多项式,而线性注意力仅能表达3次
- 通用多项式表达:4层S6可以表达任意有界次多项式
- TC⁰上界:S6的计算能力在TC⁰复杂度类中
- 形式语言优势:S6在star-free语言和有序依赖任务上优于Transformer
- 实践权衡:选择性强带来了表达能力提升,但仍然是线性复杂度
这些发现为理解和设计下一代状态空间模型提供了理论基础。
Footnotes
-
Cohen-Karlik, E., et al. “On the Expressivity of Selective State-Space Layers: A Multivariate Polynomial Approach.” arXiv 2025. https://arxiv.org/abs/2502.02209 ↩
-
Merrill, W., et al. “Mamba: Linear-Time Sequence Modeling with Selective State Spaces.” arXiv 2023. https://arxiv.org/abs/2312.00752 ↩
-
Sarrof, S., et al. “The Expressive Capacity of State Space Models: A Formal Characterization.” arXiv 2024. https://arxiv.org/abs/2405.17394 ↩
-
本理论与Mamba表达能力理论和状态空间模型密切相关。 ↩