Sessa:选择性状态空间注意力
概述
Sessa(Selective State Space Attention)是一种新型序列建模架构,通过在反馈(循环)通路中嵌入自注意力来解决传统Transformer和状态空间模型(SSM)的长程依赖建模问题。
核心创新:Sessa引入**多跳路由(Multi-hop Routing)**机制,通过反馈注意力构建可变跳数的路径,理论上实现了比Transformer更慢的遗忘和比Mamba更灵活的选择性检索。
核心论文:arXiv:2604.185801
代码实现:GitHub - LibratioAI/sessa
1. 背景与动机
长上下文建模的两大挑战
Sessa指出现有序列模型面临两个互补的长程依赖失效模式:
1.1 注意力扩散(Attention Diffusion)
标准Transformer的自注意力在处理长序列时会出现token影响力稀释问题:
- 注意力分数分布在越来越多的token上
- 早期token的梯度信号衰减为
- 当注意力”diffuse”(分散)时,检索变得不精确
# 注意力扩散示例
class AttentionDiffusion:
"""演示注意力扩散问题"""
def compute_gradient_decay(self, seq_len, n_heads):
"""
计算早期token的梯度衰减
结论:当序列长度为ℓ时,早期token梯度 ~ O(1/ℓ)
"""
decay_rate = 1.0 / seq_len # O(1/ℓ) 衰减
return decay_rate
def problem_description(self):
"""
问题:当token数增加时:
1. 注意力权重变得更分散
2. 单个token的平均影响力下降
3. 精确检索变得困难
"""
pass1.2 指数遗忘(Exponential Forgetting)
Mamba等SSM面临的问题是指数级遗忘:
- 信息通过线性时不变(LTI)系统传递
- 旧token的信号以指数速度衰减
- 只有在”freeze time”(冻结时间)内才能保持长程依赖
// Mamba的指数遗忘问题
// S4/SSM的前向传播
template <typename T>
T mamba_forward(T x_t, const T& A, const T& B, const T& C) {
// 状态更新:h_t = A * h_{t-1} + B * x_t
// 输出:y_t = C * h_t
//
// 问题:A的固有特性导致指数衰减
// 对于|A| < 1,系统呈现指数遗忘
T h_t = A * h_prev + B * x_t; // h_t = A * h_{t-1} + B * x_t
T y_t = C * h_t; // y_t = C * h_t
return y_t;
}现有方法的局限
| 模型 | 路径数 | 跳数 | 长程衰减 | 选择性检索 |
|---|---|---|---|---|
| Transformer | 1 | 1 | 有限(diffuse时失效) | |
| Mamba | 1 | 多 | 指数 | 有限(freeze时失效) |
| Sessa | 多 | 多 | 支持 |
2. 核心架构:反馈注意力
2.1 关键洞察
Sessa的核心洞察是:将自注意力嵌入到反馈(循环)通路中,构建一个下三角路由矩阵 ,使得信息可以通过多条不同跳数的路径传递。
2.2 架构图示
┌─────────────────────────────────────────────────────────────────┐
│ Sessa 架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 前向注意力(Forward Attention) │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Q_t = x_t @ W_Q │ │
│ │ K_t = x_t @ W_K │ │
│ │ V_t = x_t @ W_V │ │
│ │ f_t = softmax(Q_t @ K_t^T) @ V_t // 前向信号 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ ↓ │
│ 反馈注意力(Feedback Attention) │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ // 反馈注意力构建下三角路由矩阵 │ │
│ │ for t in range(1, T+1): │ │
│ │ B_fb[t, :t] = attention(Q_t, K_{:t}, V_{:t}) │ │
│ │ // s_t 通过多跳路径聚合信息 │ │
│ │ (I - B_fb) · s = f │ │
│ └─────────────────────────────────────────────────────────┘ │
│ ↓ │
│ 输出: y_t = γ_t · s_t + (1-γ_t) · f_t │
│ │
└─────────────────────────────────────────────────────────────────┘
2.3 数学形式化
前向注意力(Forward Attention)
其中 是标准注意力query、key、value。
反馈注意力(Feedback Attention)
反馈注意力构建一个下三角矩阵 :
这个矩阵表示严格的过去注意——每个位置只能关注它之前的token。
核心方程
展开为:
通过前向替换求解:
def forward_substitution(B_fb, f):
"""
前向替换求解 (I - B_fb) s = f
Args:
B_fb: 下三角反馈注意力矩阵 [T, T]
f: 前向注意力输出 [T, d]
Returns:
s: 反馈状态序列 [T, d]
"""
T, d = f.shape
s = torch.zeros_like(f)
# s[0] = f[0]
s[0] = f[0]
# 逐时间步求解
for t in range(1, T):
# s[t] = f[t] + B_fb[t, :t] @ s[:t]
s[t] = f[t] + B_fb[t, :t] @ s[:t]
return s2.4 输出混合
最终输出通过可学习的反馈增益 混合:
其中 通过门控机制学习,控制前向和反馈路径的贡献比例。
3. 多跳路由机制
3.1 为什么需要多跳?
传统模型的问题在于单路径信息流:
- Transformer:信息通过一跳(one-hop)传递,每个token只直接与所有其他token交互一次
- Mamba:信息通过多个时间步传递,但仍是一条链(one chain)
Sessa通过反馈注意力实现了多跳路由,信息可以通过不同跳数的路径聚合:
class MultiHopRouting:
"""
Sessa的多跳路由机制
信息可以通过以下路径到达位置t:
- 跳数1:f_t(直接前向注意力)
- 跳数2:B_fb[t, t-1] @ s_{t-1}
- 跳数3:B_fb[t, t-2] @ s_{t-2}
- ...
- 跳数t:所有早期位置
"""
def analyze_paths(self, t):
"""
分析到达位置t的所有可能路径
路径数量 = t(与到位置t的距离成正比)
跳数范围 = [1, t]
"""
paths = []
for tau in range(t): # tau是回溯的距离
n_hops = tau + 1 # 跳数 = 回溯距离 + 1
paths.append({
'source': t - tau - 1,
'target': t,
'n_hops': n_hops,
'route': f"f_{t-tau} → ... → s_t"
})
return paths3.2 信息传递分析
def analyze_information_flow(T, beta=0.5):
"""
分析Sessa的信息传递特性
假设:注意力在严格过去上diffuse分布
定理:在diffuse假设下,到位置t的信息满足:
- 跳数分布:P(n_hops = k) ∝ (1-β)^k
- 最终影响:O(t^{-β}) 对于β ∈ (0, 1)
"""
print(f"序列长度T={T}的信息传递分析:")
print("-" * 50)
# 在diffuse假设下,注意力均匀分布在所有早期token上
# 经过k次反馈后的衰减
for beta in [0.3, 0.5, 0.7]:
print(f"\nβ = {beta}:")
for t in [100, 500, 1000]:
# 晚期token的影响 ~ O(t^(-β))
influence = t ** (-beta)
print(f" 位置{t}的影响: O({influence:.4f})")
print("\n结论:比Transformer的O(1/t)和Mamba的指数衰减都慢!")3.3 稳定性保证
Sessa通过限制反馈增益 保证BIBO(有界输入有界输出)稳定性:
class StabilityGuarantee:
"""
Sessa的BIBO稳定性分析
系统:(I - B_fb) s = f
等价于:s = (I - B_fb)^(-1) f
稳定性条件:
|γ_t| < 1, ∀t
这保证了:
1. 系统不会发散
2. 输入有界 → 输出有界
"""
@staticmethod
def check_stability(gamma):
"""
检查系统稳定性
Returns:
bool: 是否满足BIBO稳定性
"""
return torch.all(torch.abs(gamma) < 1.0)4. 与其他模型的对比
4.1 架构对比
| 特性 | Transformer | Mamba | Sessa |
|---|---|---|---|
| 信息流 | 前向 | 单链循环 | 多路径反馈 |
| 路径数 | 1 | 1 | T(T-1)/2 |
| 跳数 | 固定1 | 多(线性链) | 可变[1, T] |
| **长程衰减 | 指数 | ||
| 选择性检索 | 有限 | 有限 | 支持 |
| 时间复杂度 |
4.2 理论保证对比
class TheoreticalComparison:
"""
理论特性对比
"""
def compare_decay(self, t, model_type):
"""
比较不同模型的信息衰减率
Args:
t: 位置索引
model_type: 'transformer', 'mamba', 'sessa'
"""
if model_type == 'transformer':
# Transformer: O(1/t) 衰减
return 1.0 / t
elif model_type == 'mamba':
# Mamba: 指数衰减
decay_rate = 0.95 # 示例
return decay_rate ** t
elif model_type == 'sessa':
# Sessa: 幂律衰减,β ∈ (0, 1)
beta = 0.5
return t ** (-beta)
def print_comparison(self):
t = 1000
print(f"位置{t}处的信息影响力对比:")
print(f" Transformer: O({self.compare_decay(t, 'transformer'):.6f})")
print(f" Mamba: O({self.compare_decay(t, 'mamba'):.6f})")
print(f" Sessa(β=0.5): O({self.compare_decay(t, 'sessa'):.6f})")5. 完整PyTorch实现
5.1 核心模块
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math
class SessaAttention(nn.Module):
"""
Sessa: Selective State Space Attention
核心创新:在反馈通路中嵌入自注意力,实现多跳路由
"""
def __init__(
self,
d_model: int,
n_heads: int = 8,
dropout: float = 0.1,
gamma_init: float = 0.5, # 初始反馈增益
):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.scale = math.sqrt(self.d_head)
# QKV投影
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
# 反馈增益(可学习)
self.gamma = nn.Parameter(
torch.tensor(gamma_init * torch.ones(n_heads))
)
self.dropout = nn.Dropout(dropout)
# 可选的RoPE
self.rope = None # 可添加RotaryPositionEmbedding
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor = None,
return_states: bool = False,
) -> torch.Tensor:
"""
Args:
x: [batch, seq_len, d_model]
attention_mask: 可选的注意力掩码
return_states: 是否返回中间状态
Returns:
output: [batch, seq_len, d_model]
"""
B, T, C = x.shape
# QKV投影
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# 分解为多头
Q = rearrange(Q, 'b t (h d) -> b h t d', h=self.n_heads)
K = rearrange(K, 'b t (h d) -> b h t d', h=self.n_heads)
V = rearrange(V, 'b t (h d) -> b h t d', h=self.n_heads)
# 应用RoPE(如果使用)
if self.rope is not None:
Q, K = self.rope.rotate(Q), self.rope.rotate(K)
# ============== 前向注意力 ==============
# 计算注意力分数
attn_scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) / self.scale
# 因果掩码
causal_mask = torch.triu(
torch.ones(T, T, device=x.device, dtype=torch.bool),
diagonal=1
)
attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
# 应用外部掩码
if attention_mask is not None:
attn_scores = attn_scores.masked_fill(~attention_mask.unsqueeze(1), float('-inf'))
# softmax归一化
attn_probs = F.softmax(attn_scores, dim=-1)
attn_probs = self.dropout(attn_probs)
# 前向注意力输出
f = torch.einsum('bhqk,bhvd->bhqd', attn_probs, V)
# ============== 反馈注意力 ==============
# 构建反馈注意力矩阵 B_fb(严格过去注意)
# B_fb[t, tau] = attention(q_t, k_tau) for tau < t
# 这是下三角矩阵
# 重新组织用于反馈计算
Q_fb = rearrange(Q, 'b h t d -> b h t () d')
K_fb = rearrange(K, 'b h t d -> b h () t d')
# 计算反馈注意力矩阵(不做softmax,用于后续加权)
B_fb_raw = torch.einsum('bhqtd,bhstd->bhqt', Q_fb, K_fb) / self.scale
# 上三角置零(严格过去注意)
mask = torch.triu(
torch.ones(T, T, device=x.device, dtype=torch.bool),
diagonal=1
)
B_fb_raw = B_fb_raw.masked_fill(mask, 0.0)
# 沿key维度归一化(每个query在严格过去上归一化)
B_fb_sum = B_fb_raw.sum(dim=-1, keepdim=True) + 1e-8
B_fb = B_fb_raw / B_fb_sum
# ============== 前向替换求解 ==============
# s = (I - B_fb)^(-1) f
s = self._forward_substitution(B_fb, f)
# ============== 输出混合 ==============
# 限制gamma在(-1, 1)范围内
gamma = torch.tanh(self.gamma) # 确保有界
# 混合前向和反馈路径
output = gamma.unsqueeze(-1) * s + (1 - gamma.unsqueeze(-1)) * f
# 输出投影
output = rearrange(output, 'b h t d -> b t (h d)')
output = self.W_o(output)
if return_states:
return output, {'f': f, 's': s, 'gamma': gamma}
return output
def _forward_substitution(
self,
B_fb: torch.Tensor,
f: torch.Tensor
) -> torch.Tensor:
"""
前向替换求解 (I - B_fb) s = f
递归形式:
s[0] = f[0]
s[t] = f[t] + B_fb[t, :t] @ s[:t]
"""
B, H, T, D = f.shape
s = torch.zeros_like(f)
# s[:, :, 0, :] = f[:, :, 0, :]
s[:, :, 0, :] = f[:, :, 0, :]
# 前向替换
for t in range(1, T):
# s[t] = f[t] + B_fb[t, :t] @ s[:t]
# B_fb[:, :, t, :t] shape: [B, H, 1, t]
# s[:, :, :t, :] shape: [B, H, t, D]
# 结果 shape: [B, H, 1, D]
contribution = torch.einsum('bh1t,bhtd->bh1d', B_fb[:, :, t:t+1, :t], s[:, :, :t, :])
s[:, :, t:t+1, :] = f[:, :, t:t+1, :] + contribution
return s5.2 完整Sessa块
class SessaBlock(nn.Module):
"""
Sessa Transformer块
"""
def __init__(
self,
d_model: int,
n_heads: int = 8,
d_ff: int = None,
dropout: float = 0.1,
mlp_dropout: float = 0.1,
activation: str = 'gelu',
):
super().__init__()
d_ff = d_ff or 4 * d_model
# Sessa注意力
self.attention = SessaAttention(d_model, n_heads, dropout)
self.norm1 = nn.LayerNorm(d_model)
# FFN
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU() if activation == 'gelu' else nn.ReLU(),
nn.Dropout(mlp_dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(mlp_dropout),
)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, attention_mask=None):
# 预-norm残差连接
x = x + self.attention(self.norm1(x), attention_mask)
x = x + self.ffn(self.norm2(x))
return x5.3 FlashAttention版本(可选优化)
class SessaAttentionFlash(nn.Module):
"""
使用FlashAttention加速的Sessa实现
适用于长序列场景
"""
def __init__(self, d_model, n_heads=8, dropout=0.1):
super().__init__()
# ... 初始化同前 ...
self.flash_attn = True
def forward(self, x, attention_mask=None):
# 前向注意力使用FlashAttention
# ...
pass6. 实验结果
6.1 长程依赖基准
| 任务 | Transformer | Mamba | Sessa |
|---|---|---|---|
| PathFinder | 85.2% | 82.1% | 89.7% |
| Long Range Arena | 67.4% | 64.8% | 71.2% |
| SCAN (length) | 54.3% | 61.2% | 78.5% |
6.2 选择性检索任务
| 模型 | 精确检索 | Diffuse设置 | 非衰减检索 |
|---|---|---|---|
| Transformer | ✓ | ✗ | ✗ |
| Mamba | ✓ | ✓ | ✗ |
| Sessa | ✓ | ✓ | ✓ |
6.3 效率分析
| 模型 | 复杂度(标准) | 复杂度(稀疏) | 内存 |
|---|---|---|---|
| Transformer | 高 | ||
| Mamba | 低 | ||
| Sessa | 中 |
7. 总结
核心贡献
- 多跳路由机制:通过反馈注意力实现可变跳数的路径聚合
- 幂律衰减:理论上解决了注意力扩散和指数遗忘问题
- 稳定性保证:BIBO稳定性分析确保可靠训练和推理
- 选择性检索:支持在各种设置下的灵活信息检索
与现有工作的区别
| 特性 | Sessa | 竞争方法 |
|---|---|---|
| 路径结构 | 下三角反馈矩阵 | 单链/全连接 |
| 衰减速率 | 或指数 | |
| 检索灵活性 | 完全选择 | 受限 |
参考资料
相关专题:Mamba与SSM分析 | 线性注意力变体
Footnotes
-
Sessa: Selective State Space Attention. arXiv:2604.18580 ↩