Sessa:选择性状态空间注意力

概述

Sessa(Selective State Space Attention)是一种新型序列建模架构,通过在反馈(循环)通路中嵌入自注意力来解决传统Transformer和状态空间模型(SSM)的长程依赖建模问题。

核心创新:Sessa引入**多跳路由(Multi-hop Routing)**机制,通过反馈注意力构建可变跳数的路径,理论上实现了比Transformer更慢的遗忘和比Mamba更灵活的选择性检索。

核心论文:arXiv:2604.185801

代码实现GitHub - LibratioAI/sessa


1. 背景与动机

长上下文建模的两大挑战

Sessa指出现有序列模型面临两个互补的长程依赖失效模式

1.1 注意力扩散(Attention Diffusion)

标准Transformer的自注意力在处理长序列时会出现token影响力稀释问题:

  • 注意力分数分布在越来越多的token上
  • 早期token的梯度信号衰减为
  • 当注意力”diffuse”(分散)时,检索变得不精确
# 注意力扩散示例
class AttentionDiffusion:
    """演示注意力扩散问题"""
    
    def compute_gradient_decay(self, seq_len, n_heads):
        """
        计算早期token的梯度衰减
        
        结论:当序列长度为ℓ时,早期token梯度 ~ O(1/ℓ)
        """
        decay_rate = 1.0 / seq_len  # O(1/ℓ) 衰减
        return decay_rate
    
    def problem_description(self):
        """
        问题:当token数增加时:
        1. 注意力权重变得更分散
        2. 单个token的平均影响力下降
        3. 精确检索变得困难
        """
        pass

1.2 指数遗忘(Exponential Forgetting)

Mamba等SSM面临的问题是指数级遗忘

  • 信息通过线性时不变(LTI)系统传递
  • 旧token的信号以指数速度衰减
  • 只有在”freeze time”(冻结时间)内才能保持长程依赖
// Mamba的指数遗忘问题
// S4/SSM的前向传播
template <typename T>
T mamba_forward(T x_t, const T& A, const T& B, const T& C) {
    // 状态更新:h_t = A * h_{t-1} + B * x_t
    // 输出:y_t = C * h_t
    // 
    // 问题:A的固有特性导致指数衰减
    // 对于|A| < 1,系统呈现指数遗忘
    T h_t = A * h_prev + B * x_t;  // h_t = A * h_{t-1} + B * x_t
    T y_t = C * h_t;              // y_t = C * h_t
    return y_t;
}

现有方法的局限

模型路径数跳数长程衰减选择性检索
Transformer11有限(diffuse时失效)
Mamba1指数有限(freeze时失效)
Sessa支持

2. 核心架构:反馈注意力

2.1 关键洞察

Sessa的核心洞察是:将自注意力嵌入到反馈(循环)通路中,构建一个下三角路由矩阵 ,使得信息可以通过多条不同跳数的路径传递。

2.2 架构图示

┌─────────────────────────────────────────────────────────────────┐
│                      Sessa 架构                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   前向注意力(Forward Attention)                               │
│   ┌─────────────────────────────────────────────────────────┐  │
│   │  Q_t = x_t @ W_Q                                        │  │
│   │  K_t = x_t @ W_K                                        │  │
│   │  V_t = x_t @ W_V                                        │  │
│   │  f_t = softmax(Q_t @ K_t^T) @ V_t  // 前向信号         │  │
│   └─────────────────────────────────────────────────────────┘  │
│                              ↓                                  │
│   反馈注意力(Feedback Attention)                              │
│   ┌─────────────────────────────────────────────────────────┐  │
│   │  // 反馈注意力构建下三角路由矩阵                         │  │
│   │  for t in range(1, T+1):                               │  │
│   │      B_fb[t, :t] = attention(Q_t, K_{:t}, V_{:t})     │  │
│   │  //  s_t 通过多跳路径聚合信息                           │  │
│   │  (I - B_fb) · s = f                                    │  │
│   └─────────────────────────────────────────────────────────┘  │
│                              ↓                                  │
│   输出: y_t = γ_t · s_t + (1-γ_t) · f_t                     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

2.3 数学形式化

前向注意力(Forward Attention)

其中 是标准注意力query、key、value。

反馈注意力(Feedback Attention)

反馈注意力构建一个下三角矩阵

这个矩阵表示严格的过去注意——每个位置只能关注它之前的token。

核心方程

展开为:

通过前向替换求解:

def forward_substitution(B_fb, f):
    """
    前向替换求解 (I - B_fb) s = f
    
    Args:
        B_fb: 下三角反馈注意力矩阵 [T, T]
        f: 前向注意力输出 [T, d]
    
    Returns:
        s: 反馈状态序列 [T, d]
    """
    T, d = f.shape
    s = torch.zeros_like(f)
    
    # s[0] = f[0]
    s[0] = f[0]
    
    # 逐时间步求解
    for t in range(1, T):
        # s[t] = f[t] + B_fb[t, :t] @ s[:t]
        s[t] = f[t] + B_fb[t, :t] @ s[:t]
    
    return s

2.4 输出混合

最终输出通过可学习的反馈增益 混合:

其中 通过门控机制学习,控制前向和反馈路径的贡献比例。


3. 多跳路由机制

3.1 为什么需要多跳?

传统模型的问题在于单路径信息流

  • Transformer:信息通过一跳(one-hop)传递,每个token只直接与所有其他token交互一次
  • Mamba:信息通过多个时间步传递,但仍是一条链(one chain)

Sessa通过反馈注意力实现了多跳路由,信息可以通过不同跳数的路径聚合:

class MultiHopRouting:
    """
    Sessa的多跳路由机制
    
    信息可以通过以下路径到达位置t:
    - 跳数1:f_t(直接前向注意力)
    - 跳数2:B_fb[t, t-1] @ s_{t-1}
    - 跳数3:B_fb[t, t-2] @ s_{t-2}
    - ...
    - 跳数t:所有早期位置
    """
    
    def analyze_paths(self, t):
        """
        分析到达位置t的所有可能路径
        
        路径数量 = t(与到位置t的距离成正比)
        跳数范围 = [1, t]
        """
        paths = []
        for tau in range(t):  # tau是回溯的距离
            n_hops = tau + 1  # 跳数 = 回溯距离 + 1
            paths.append({
                'source': t - tau - 1,
                'target': t,
                'n_hops': n_hops,
                'route': f"f_{t-tau} → ... → s_t"
            })
        return paths

3.2 信息传递分析

def analyze_information_flow(T, beta=0.5):
    """
    分析Sessa的信息传递特性
    
    假设:注意力在严格过去上diffuse分布
    
    定理:在diffuse假设下,到位置t的信息满足:
    - 跳数分布:P(n_hops = k) ∝ (1-β)^k
    - 最终影响:O(t^{-β}) 对于β ∈ (0, 1)
    """
    print(f"序列长度T={T}的信息传递分析:")
    print("-" * 50)
    
    # 在diffuse假设下,注意力均匀分布在所有早期token上
    # 经过k次反馈后的衰减
    for beta in [0.3, 0.5, 0.7]:
        print(f"\nβ = {beta}:")
        for t in [100, 500, 1000]:
            # 晚期token的影响 ~ O(t^(-β))
            influence = t ** (-beta)
            print(f"  位置{t}的影响: O({influence:.4f})")
    
    print("\n结论:比Transformer的O(1/t)和Mamba的指数衰减都慢!")

3.3 稳定性保证

Sessa通过限制反馈增益 保证BIBO(有界输入有界输出)稳定性

class StabilityGuarantee:
    """
    Sessa的BIBO稳定性分析
    
    系统:(I - B_fb) s = f
    等价于:s = (I - B_fb)^(-1) f
    
    稳定性条件:
    |γ_t| < 1, ∀t
    
    这保证了:
    1. 系统不会发散
    2. 输入有界 → 输出有界
    """
    
    @staticmethod
    def check_stability(gamma):
        """
        检查系统稳定性
        
        Returns:
            bool: 是否满足BIBO稳定性
        """
        return torch.all(torch.abs(gamma) < 1.0)

4. 与其他模型的对比

4.1 架构对比

特性TransformerMambaSessa
信息流前向单链循环多路径反馈
路径数11T(T-1)/2
跳数固定1多(线性链)可变[1, T]
**长程衰减指数
选择性检索有限有限支持
时间复杂度

4.2 理论保证对比

class TheoreticalComparison:
    """
    理论特性对比
    """
    
    def compare_decay(self, t, model_type):
        """
        比较不同模型的信息衰减率
        
        Args:
            t: 位置索引
            model_type: 'transformer', 'mamba', 'sessa'
        """
        if model_type == 'transformer':
            # Transformer: O(1/t) 衰减
            return 1.0 / t
        
        elif model_type == 'mamba':
            # Mamba: 指数衰减
            decay_rate = 0.95  # 示例
            return decay_rate ** t
        
        elif model_type == 'sessa':
            # Sessa: 幂律衰减,β ∈ (0, 1)
            beta = 0.5
            return t ** (-beta)
    
    def print_comparison(self):
        t = 1000
        print(f"位置{t}处的信息影响力对比:")
        print(f"  Transformer: O({self.compare_decay(t, 'transformer'):.6f})")
        print(f"  Mamba:      O({self.compare_decay(t, 'mamba'):.6f})")
        print(f"  Sessa(β=0.5): O({self.compare_decay(t, 'sessa'):.6f})")

5. 完整PyTorch实现

5.1 核心模块

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math
 
class SessaAttention(nn.Module):
    """
    Sessa: Selective State Space Attention
    
    核心创新:在反馈通路中嵌入自注意力,实现多跳路由
    """
    
    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        dropout: float = 0.1,
        gamma_init: float = 0.5,  # 初始反馈增益
    ):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = math.sqrt(self.d_head)
        
        # QKV投影
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        # 反馈增益(可学习)
        self.gamma = nn.Parameter(
            torch.tensor(gamma_init * torch.ones(n_heads))
        )
        
        self.dropout = nn.Dropout(dropout)
        
        # 可选的RoPE
        self.rope = None  # 可添加RotaryPositionEmbedding
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        return_states: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_model]
            attention_mask: 可选的注意力掩码
            return_states: 是否返回中间状态
        
        Returns:
            output: [batch, seq_len, d_model]
        """
        B, T, C = x.shape
        
        # QKV投影
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # 分解为多头
        Q = rearrange(Q, 'b t (h d) -> b h t d', h=self.n_heads)
        K = rearrange(K, 'b t (h d) -> b h t d', h=self.n_heads)
        V = rearrange(V, 'b t (h d) -> b h t d', h=self.n_heads)
        
        # 应用RoPE(如果使用)
        if self.rope is not None:
            Q, K = self.rope.rotate(Q), self.rope.rotate(K)
        
        # ============== 前向注意力 ==============
        # 计算注意力分数
        attn_scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) / self.scale
        
        # 因果掩码
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), 
            diagonal=1
        )
        attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
        
        # 应用外部掩码
        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(~attention_mask.unsqueeze(1), float('-inf'))
        
        # softmax归一化
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)
        
        # 前向注意力输出
        f = torch.einsum('bhqk,bhvd->bhqd', attn_probs, V)
        
        # ============== 反馈注意力 ==============
        # 构建反馈注意力矩阵 B_fb(严格过去注意)
        # B_fb[t, tau] = attention(q_t, k_tau) for tau < t
        # 这是下三角矩阵
        
        # 重新组织用于反馈计算
        Q_fb = rearrange(Q, 'b h t d -> b h t () d')
        K_fb = rearrange(K, 'b h t d -> b h () t d')
        
        # 计算反馈注意力矩阵(不做softmax,用于后续加权)
        B_fb_raw = torch.einsum('bhqtd,bhstd->bhqt', Q_fb, K_fb) / self.scale
        
        # 上三角置零(严格过去注意)
        mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool),
            diagonal=1
        )
        B_fb_raw = B_fb_raw.masked_fill(mask, 0.0)
        
        # 沿key维度归一化(每个query在严格过去上归一化)
        B_fb_sum = B_fb_raw.sum(dim=-1, keepdim=True) + 1e-8
        B_fb = B_fb_raw / B_fb_sum
        
        # ============== 前向替换求解 ==============
        # s = (I - B_fb)^(-1) f
        s = self._forward_substitution(B_fb, f)
        
        # ============== 输出混合 ==============
        # 限制gamma在(-1, 1)范围内
        gamma = torch.tanh(self.gamma)  # 确保有界
        
        # 混合前向和反馈路径
        output = gamma.unsqueeze(-1) * s + (1 - gamma.unsqueeze(-1)) * f
        
        # 输出投影
        output = rearrange(output, 'b h t d -> b t (h d)')
        output = self.W_o(output)
        
        if return_states:
            return output, {'f': f, 's': s, 'gamma': gamma}
        
        return output
    
    def _forward_substitution(
        self, 
        B_fb: torch.Tensor, 
        f: torch.Tensor
    ) -> torch.Tensor:
        """
        前向替换求解 (I - B_fb) s = f
        
        递归形式:
        s[0] = f[0]
        s[t] = f[t] + B_fb[t, :t] @ s[:t]
        """
        B, H, T, D = f.shape
        s = torch.zeros_like(f)
        
        # s[:, :, 0, :] = f[:, :, 0, :]
        s[:, :, 0, :] = f[:, :, 0, :]
        
        # 前向替换
        for t in range(1, T):
            # s[t] = f[t] + B_fb[t, :t] @ s[:t]
            # B_fb[:, :, t, :t] shape: [B, H, 1, t]
            # s[:, :, :t, :] shape: [B, H, t, D]
            # 结果 shape: [B, H, 1, D]
            contribution = torch.einsum('bh1t,bhtd->bh1d', B_fb[:, :, t:t+1, :t], s[:, :, :t, :])
            s[:, :, t:t+1, :] = f[:, :, t:t+1, :] + contribution
        
        return s

5.2 完整Sessa块

class SessaBlock(nn.Module):
    """
    Sessa Transformer块
    """
    
    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        d_ff: int = None,
        dropout: float = 0.1,
        mlp_dropout: float = 0.1,
        activation: str = 'gelu',
    ):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        
        # Sessa注意力
        self.attention = SessaAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU() if activation == 'gelu' else nn.ReLU(),
            nn.Dropout(mlp_dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(mlp_dropout),
        )
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x, attention_mask=None):
        # 预-norm残差连接
        x = x + self.attention(self.norm1(x), attention_mask)
        x = x + self.ffn(self.norm2(x))
        return x

5.3 FlashAttention版本(可选优化)

class SessaAttentionFlash(nn.Module):
    """
    使用FlashAttention加速的Sessa实现
    适用于长序列场景
    """
    
    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        # ... 初始化同前 ...
        self.flash_attn = True
    
    def forward(self, x, attention_mask=None):
        # 前向注意力使用FlashAttention
        # ...
        pass

6. 实验结果

6.1 长程依赖基准

任务TransformerMambaSessa
PathFinder85.2%82.1%89.7%
Long Range Arena67.4%64.8%71.2%
SCAN (length)54.3%61.2%78.5%

6.2 选择性检索任务

模型精确检索Diffuse设置非衰减检索
Transformer
Mamba
Sessa

6.3 效率分析

模型复杂度(标准)复杂度(稀疏)内存
Transformer
Mamba
Sessa

7. 总结

核心贡献

  1. 多跳路由机制:通过反馈注意力实现可变跳数的路径聚合
  2. 幂律衰减:理论上解决了注意力扩散和指数遗忘问题
  3. 稳定性保证:BIBO稳定性分析确保可靠训练和推理
  4. 选择性检索:支持在各种设置下的灵活信息检索

与现有工作的区别

特性Sessa竞争方法
路径结构下三角反馈矩阵单链/全连接
衰减速率或指数
检索灵活性完全选择受限

参考资料


相关专题Mamba与SSM分析 | 线性注意力变体

Footnotes

  1. Sessa: Selective State Space Attention. arXiv:2604.18580