概述

Mamba-3 在 Mamba-2 的基础上引入了三大核心创新——指数-梯形离散化、复数值状态空间和多输入多输出(MIMO)架构。为了支撑这些创新,需要在算法实现层面进行系统性优化。本文档从方法论角度深入解析 Mamba-3 的核心技术原理,包括离散化算法、硬件感知设计、并行扫描实现、梯度计算机制以及数值稳定性保障。1

1. 连续时间状态空间模型基础

1.1 标准 SSM 定义

连续时间线性状态空间模型定义为:

其中:

  • :输入信号
  • :隐藏状态
  • :输出信号
  • :状态转移矩阵
  • :输入矩阵
  • :输出矩阵
  • :跳跃连接矩阵

1.2 连续时间解

状态方程的解析解为:

这表明状态转移由矩阵指数 控制,具有重要的几何和代数性质。

1.3 矩阵指数的性质

矩阵指数满足以下关键性质:

对于一般矩阵,利用泰勒展开:

在离散化实现中,通过缩放和平滑来近似连续时间动态。

2. 离散化算法实现细节

2.1 零阶保持(ZOH)离散化

标准 ZOH 离散化假设输入在采样区间内保持恒定:

离散化后的状态更新公式为:

状态更新变为:

2.2 指数-梯形离散化

Mamba-3 提出的指数-梯形离散化利用梯形积分规则提高精度:

应用梯形规则到脉冲响应积分:

Mamba-3 的创新在于引入可学习的指数采样参数

其中 是可学习的标量参数,控制梯形近似的精度。

2.3 指数采样点的几何意义

指数采样点 的选择对离散化精度有重要影响:

均匀采样(ZOH):

中点采样

指数-梯形采样

2.4 离散化算法实现

import torch
import torch.nn as nn
import torch.nn.functional as F
 
def discretize_exp_trapezoidal(dt, A, B, gamma):
    """
    指数-梯形离散化实现
    
    参数:
        dt: 时间步长 (batch, seq_len) 或标量
        A: 状态转移矩阵 (d_state, d_state)
        B: 输入矩阵 (d_state, d_inner) 或 (d_inner,)
        gamma: 梯形参数 (标量或可学习)
    
    返回:
        dA: 离散化状态转移矩阵
        dB: 离散化输入矩阵
    """
    # 确保 dt 在合理范围
    dt = F.softplus(dt)  # dt > 0
    
    # 计算 e^(dt * A)
    dA = torch.matrix_exp(dt.unsqueeze(-1) * A)  # (..., d_state, d_state)
    
    # 计算指数-梯形项
    # (e^(dt*A) - I) / (dt * A)
    I = torch.eye(A.shape[-1], device=A.device, dtype=A.dtype)
    dtA_inv = (dA - I) * torch.where(
        dt.unsqueeze(-1) > 1e-6,
        1.0 / (dt.unsqueeze(-1) * A),
        A  # 近似:当 dt*A → 0 时
    )
    
    # 应用 gamma 幂次
    # 使用对数空间避免数值问题
    log_term = torch.log(dtA_inv + 1e-8) * gamma
    gamma_power = torch.exp(log_term)
    
    # 最终离散化 B
    dB = dt.unsqueeze(-1) * B * gamma_power * torch.matrix_exp(dt.unsqueeze(-1) * A / 2)
    
    return dA, dB

3. 硬件感知设计原理

3.1 计算瓶颈分析

现代 GPU 架构的特点:

特性描述影响
张量核心高吞吐矩阵运算compute-bound 操作高效
共享内存高速片上存储减少全局内存访问
HBM 带宽~2TB/s (H100)memory-bound 操作受限

算术强度是判断操作属于哪类的关键指标:

  • :compute-bound(计算密集)
  • :memory-bound(内存密集)

3.2 SSM 的算术强度

对于标准 SISO SSM:

其中

内存访问量(字节,假设 bfloat16):

FLOPs

算术强度

典型的

这接近 memory-bound 和 compute-bound 的边界。

3.3 MIMO 的算术强度提升

MIMO 架构将输入分组处理:

状态更新变为矩阵-矩阵形式:

内存访问量

FLOPs

算术强度

增大时,等效算术强度提升 倍。

3.4 硬件友好设计原则

┌─────────────────────────────────────────────────────────────────────┐
│                    硬件友好设计原则                                    │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  1. 核融合 (Kernel Fusion)                                          │
│     ┌─────────────────────────────────────────────────────────┐    │
│     │ 分离操作: dA=h[0]*h[1] → h_new=h*dA → y=C@h_new          │    │
│     │ 融合内核: Single GPU kernel for all operations          │    │
│     │ 收益: 减少内存访问 ~40%                                  │    │
│     └─────────────────────────────────────────────────────────┘    │
│                                                                     │
│  2. 并行扫描 (Parallel Scan)                                        │
│     ┌─────────────────────────────────────────────────────────┐    │
│     │ 串行: O(T) → 并行: O(log T)                             │    │
│     │ 利用 GPU 多核并行性                                       │    │
│     └─────────────────────────────────────────────────────────┘    │
│                                                                     │
│  3. 梯度重计算 (Gradient Recomputation)                              │
│     ┌─────────────────────────────────────────────────────────┐    │
│     │ 前向: 保存中间状态                                        │    │
│     │ 反向: 不保存 → 重新计算前向                               │    │
│     │ 收益: 显存减少 ~50%, 计算增加 ~30%                       │    │
│     └─────────────────────────────────────────────────────────┘    │
│                                                                     │
│  4. 分块计算 (Chunked Computation)                                  │
│     ┌─────────────────────────────────────────────────────────┐    │
│     │ 长序列 → 多个块 → 块内并行 + 块间串行                     │    │
│     │ 平衡计算密度与内存使用                                     │    │
│     └─────────────────────────────────────────────────────────┘    │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

4. 并行扫描算法

4.1 关联函数定义

并行扫描的核心是将状态更新定义为关联操作

定义归约操作(Reduction):

其中

4.2 扫描的结合律

扫描操作必须满足结合律才能并行化:

验证:

发现问题:两项不相等!需要重新定义关联函数。

4.3 正确关联函数定义

正确的关联函数应跟踪累积状态转移累积输入贡献

其中:

  • :累积状态转移矩阵
  • :累积输入贡献

定义关联操作:

4.4 并行扫描算法实现

import torch
import torch.nn as nn
 
def parallel_scan(xs, As):
    """
    并行扫描算法
    
    参数:
        xs: 输入序列 (batch, seq_len, d_state)
        As: 状态转移矩阵序列 (batch, seq_len, d_state, d_state)
    
    返回:
        hs: 隐藏状态序列 (batch, seq_len, d_state)
    """
    batch, seq_len, d_state = xs.shape
    
    # 初始化:第一个元素
    h = torch.zeros(batch, d_state, device=xs.device, dtype=xs.dtype)
    outputs = []
    
    # 对数级并行扫描
    for offset in reversed(range(seq_len)):
        # 计算需要更新的部分
        # 使用掩码避免竞态条件
        h_new = As[:, offset] @ h + xs[:, offset]
        
        # 存储(后续会被覆盖)
        outputs.append(h_new)
        h = h_new
    
    # 反转得到正确顺序
    return torch.stack(list(reversed(outputs)), dim=1)
 
 
def parallel_scan_hillis_steele(xs, As):
    """
    Hillis-Steele 并行扫描算法
    
    时间复杂度: O(log T)
    工作量: O(T log T)
    
    参数:
        xs: 输入 (batch, seq_len, d_state)
        As: 转移矩阵 (batch, seq_len, d_state, d_state)
    """
    batch, seq_len, d_state = xs.shape
    device = xs.device
    
    # 初始化状态
    M = torch.eye(d_state, device=device).unsqueeze(0).unsqueeze(0).expand(batch, seq_len, -1, -1)
    M = M * As  # 累积转移
    v = xs
    
    # log T 轮迭代
    import math
    num_steps = math.ceil(math.log2(seq_len))
    
    for step in range(num_steps):
        offset = 2 ** step
        
        # 使用前一位置的累积结果
        M_left = M[:, :-offset]  # (batch, seq_len-offset, d_state, d_state)
        v_left = v[:, :-offset]  # (batch, seq_len-offset, d_state)
        
        M_right = M[:, offset:]  # (batch, seq_len-offset, d_state, d_state)
        v_right = v[:, offset:]  # (batch, seq_len-offset, d_state)
        
        # 关联组合:M_new = M_right @ M_left, v_new = M_right @ v_left + v_right
        M_combined = torch.bmm(
            M_right.view(-1, d_state, d_state),
            M_left.view(-1, d_state, d_state)
        ).view(batch, seq_len - offset, d_state, d_state)
        
        v_combined = torch.bmm(
            M_right.view(-1, d_state, d_state),
            v_left.view(-1, d_state)
        ).view(batch, seq_len - offset, d_state) + v_right
        
        # 更新
        M[:, offset:] = M_combined
        v[:, offset:] = v_combined
    
    return v

4.5 分块并行扫描

对于超长序列,采用分块策略:

def chunked_parallel_scan(x, A, chunk_size=64):
    """
    分块并行扫描
    
    将长序列分成多个块:
    - 块内:并行扫描
    - 块间:顺序递推
    
    复杂度: O(T/C * C^2 + T/C * N) ≈ O(T * C) for small C
    """
    batch, seq_len, d_state = x.shape
    device = x.device
    
    num_chunks = (seq_len + chunk_size - 1) // chunk_size
    
    # 填充到块大小整数倍
    pad_len = num_chunks * chunk_size - seq_len
    if pad_len > 0:
        x = F.pad(x, (0, 0, 0, pad_len))
        A = F.pad(A, (0, 0, 0, 0, 0, pad_len))
    
    x = x.view(batch, num_chunks, chunk_size, d_state)
    A = A.view(batch, num_chunks, chunk_size, d_state, d_state)
    
    # 1. 块内并行扫描
    h_chunk = torch.zeros(batch, num_chunks, d_state, device=device)
    y_chunk = torch.zeros(batch, num_chunks, chunk_size, d_state, device=device)
    
    for i in range(num_chunks):
        y_chunk[:, i], final_state = parallel_scan_chunk(
            x[:, i], A[:, i], h_chunk[:, i]
        )
        h_chunk[:, i + 1] = final_state if i < num_chunks - 1 else torch.zeros_like(final_state)
    
    # 2. 块间递推
    for i in range(1, num_chunks):
        # 块间状态转移
        h_transfer = torch.bmm(
            torch.matrix_exp(A[:, i-1, -1:].squeeze(1) * chunk_size),
            h_chunk[:, i-1:i].squeeze(1)
        )
        h_chunk[:, i] = h_transfer + h_chunk[:, i]
    
    return y_chunk[:, :seq_len // chunk_size].reshape(batch, seq_len, d_state)

5. 梯度计算与反向传播

5.1 前向传播的矩阵形式

将 SSM 表示为矩阵-向量乘积:

其中 是半可分矩阵(semi-separable matrix):

5.2 反向传播梯度计算

目标:计算

损失对输入的梯度

这需要反向扫描

5.3 反向扫描算法

def backward_scan(dout, xs, As):
    """
    反向扫描:计算梯度
    
    参数:
        dout: 输出梯度 (batch, seq_len, d_output)
        xs: 输入序列 (batch, seq_len, d_input)
        As: 转移矩阵 (batch, seq_len, d_state, d_state)
    
    返回:
        dx: 输入梯度 (batch, seq_len, d_input)
        dAs: 转移矩阵梯度 (batch, seq_len, d_state, d_state)
    """
    batch, seq_len, d_out = dout.shape
    d_input = xs.shape[-1]
    d_state = As.shape[-1]
    
    device = xs.device
    
    # 初始化
    dx = torch.zeros_like(xs)
    dAs = torch.zeros_like(As)
    
    # 反向状态
    dh_next = torch.zeros(batch, d_state, device=device)
    
    # 反向扫描
    for t in reversed(range(seq_len)):
        # 输出梯度传播
        # dy_t = C_t @ h_t
        dh = dout[:, t]  # (batch, d_out) - 简化假设 d_out = d_state
        
        # 加上来自后续状态的梯度
        dh = dh + dh_next
        
        # 存储梯度
        dx[:, t] = dh @ As[:, t]  # 简化的梯度计算
        
        # 更新反向状态
        dh_next = dh.unsqueeze(-1) @ As[:, t].unsqueeze(1)
        dh_next = dh_next.squeeze(-2)
    
    return dx, dAs
 
 
def reverse_parallel_scan(grad_y, M):
    """
    反向并行扫描(使用关联函数反转)
    
    关联函数的反转:
    如果 (M_k, v_k) = (M_k, v_k) ⊗ (s_k, x_k)
    则反向传递计算 s_{k-1} 和 x_{k-1} 的梯度
    """
    batch, seq_len, d_state = grad_y.shape
    device = grad_y.device
    
    # 反向状态梯度
    grad_s = torch.zeros(batch, d_state, device=device)
    grad_x = torch.zeros(batch, seq_len, device=device)
    
    # 反向并行扫描
    # ... (实现细节省略)
    
    return grad_x

5.4 参数梯度计算

对于可学习参数

梯度重计算策略

class Mamba3WithRecomputation(nn.Module):
    """
    使用梯度重计算减少显存占用
    
    前向传播:不保存中间状态
    反向传播:重新计算前向 + 计算梯度
    """
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        # ... 参数初始化
    
    def forward(self, x, compute_loss=True):
        if not compute_loss or not self.training:
            return self._forward_no_recompute(x)
        
        # 记录输入用于重计算
        x_recorded = x.detach().requires_grad_(True)
        return self._forward_with_recompute(x_recorded)
    
    def _forward_with_recompute(self, x):
        # 前向传播,保存必要信息
        cache = []
        h = torch.zeros(x.shape[0], self.d_state, device=x.device)
        
        for t in range(x.shape[1]):
            # 计算当前步的输入
            dt = self.dt_proj(x[:, t])
            dA, dB = discretize_exp_trapezoidal(dt, self.A, self.B, self.gamma)
            
            # 状态更新
            h_new = dA * h + dB * x[:, t]
            
            # 输出
            y_t = self.C @ h_new
            
            # 缓存(用于反向)
            cache.append((h, h_new, dA, dB, x[:, t]))
            
            h = h_new
        
        return y, cache
    
    @torch.no_grad()
    def backward_with_recompute(self, y, cache, grad_y):
        """使用缓存的信息重计算梯度"""
        batch, seq_len = y.shape[:2]
        
        grad_params = {}
        grad_x = torch.zeros(batch, seq_len, self.d_model, device=y.device)
        
        dh_next = torch.zeros(batch, self.d_state, device=y.device)
        
        for t in reversed(range(seq_len)):
            h_prev, h_new, dA, dB, x_t = cache[t]
            
            # 重新计算前向以获取精确梯度
            # ... (省略细节)
            
            # 累积梯度
            # ... 
        
        return grad_params, grad_x

6. 数值稳定性考虑

6.1 指数溢出问题

矩阵指数 可能导致数值溢出:

def stable_matrix_exp(A, max_norm=10.0):
    """
    稳定的矩阵指数计算
    
    策略:
    1. 归一化:限制 A 的谱范数
    2. 分裂:将 A 分解为对角和残余
    3. Pade 近似:比泰勒展开更稳定
    """
    # 归一化
    A_norm = torch.linalg.norm(A, ord=2)
    if A_norm > max_norm:
        A = A * (max_norm / A_norm)
    
    # 使用 Pade 近似 (3,3)
    # e^A ≈ (I + 2A/3 + A^2/6) / (I - A/3)
    I = torch.eye(A.shape[-1], device=A.device, dtype=A.dtype)
    
    Q = I + (2/3) * A + (1/6) * (A @ A)
    P = I - (1/3) * A
    
    return torch.linalg.solve(P, Q)

6.2 对数空间计算

避免直接计算指数和对数:

def log_space_discretization(log_dt, log_A, B, gamma):
    """
    在对数空间进行离散化计算
    
    避免直接计算 exp 和 log
    """
    # log(e^(x) - 1) = log(1 - e^(-x)) + x  (更稳定)
    # 使用 softplus 的 log 形式
    
    # 计算 log(e^(dt*A) - I)
    # = log(dt*A) + log(e^(dt*A)/(dt*A) - I/dt*A)
    
    # 简化为数值稳定的形式
    log_dA = log_dt + log_A
    
    # 计算 (e^(dt*A) - I) / (dt*A) 的对数
    # = log(e^(dt*A) - I) - log(dt) - log(A)
    
    # 安全计算
    dtA = torch.exp(log_dt + log_A.unsqueeze(0))
    I = torch.eye(A.shape[-1], device=A.device)
    dtA_minus_I = dtA - I
    
    # 避免 log(0)
    dtA_minus_I = torch.where(
        dtA_minus_I.abs() < 1e-7,
        torch.full_like(dtA_minus_I, 1e-7),
        dtA_minus_I
    )
    
    log_dtA_minus_I_div = torch.log(dtA_minus_I) - log_dt - log_A.unsqueeze(0)
    
    # 应用 gamma
    log_term = log_dtA_minus_I_div * gamma
    
    # 转换回原始空间
    dA = torch.exp(log_dA)
    dB = torch.exp(log_dt) * B * torch.exp(log_term) * torch.exp(log_dA / 2)
    
    return dA, dB

6.3 梯度裁剪

def gradient_clipping(params, grads, max_norm=1.0):
    """
    梯度裁剪防止梯度爆炸
    """
    total_norm = 0.0
    for p, g in zip(params, grads):
        if g is not None:
            param_norm = torch.linalg.norm(g)
            total_norm += param_norm ** 2
    
    total_norm = total_norm ** 0.5
    
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            if g is not None:
                g.mul_(clip_coef)
    
    return total_norm

6.4 数值稳定性检查点

class StabilityChecker:
    """运行时数值稳定性检查"""
    
    @staticmethod
    def check_finite(tensor, name="tensor"):
        """检查是否为有限值"""
        if not torch.all(torch.isfinite(tensor)):
            warnings.warn(f"{name} contains inf or nan!")
            return False
        return True
    
    @staticmethod
    def check_range(tensor, name="tensor", max_val=1e6):
        """检查值是否在合理范围内"""
        if torch.abs(tensor).max() > max_val:
            warnings.warn(f"{name} has values exceeding {max_val}!")
            return False
        return True
    
    @staticmethod
    def check_conditioning(A, name="A", threshold=1e6):
        """检查矩阵条件数"""
        try:
            cond = torch.linalg.cond(A)
            if cond > threshold:
                warnings.warn(f"{name} has poor conditioning (cond={cond:.2e})!")
                return False
        except:
            pass
        return True

7. 代码实现要点

7.1 完整 PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
 
class Mamba3Core(nn.Module):
    """
    Mamba-3 核心 SSM 实现
    
    特点:
    1. 指数-梯形离散化
    2. 复数值状态空间
    3. 硬件感知并行扫描
    """
    
    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        gamma_min: float = 0.5,
        gamma_max: float = 2.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.d_inner = d_model * expand
        self.dt_min = dt_min
        self.dt_max = dt_max
        self.gamma_min = gamma_min
        self.gamma_max = gamma_max
        
        # 输入投影
        self.in_proj = nn.Linear(d_model, self.d_inner * 2 + self.d_state * 2, bias=False)
        
        # 卷积平滑
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
            bias=True,
        )
        
        # SSM 参数
        # 状态矩阵 A(复数用两个实数矩阵表示)
        self.A_log = nn.Parameter(torch.randn(d_state, d_state))
        self.A_imag = nn.Parameter(torch.randn(d_state, d_state))  # 复数部分
        
        # 输出投影
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        
        # 初始化
        self._init_weights()
    
    def _init_weights(self):
        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.normal_(self.A_log, mean=0, std=0.02)
        nn.init.normal_(self.A_imag, mean=0, std=0.02)
        nn.init.zeros_(self.D)
        
        # 卷积偏置初始化
        nn.init.constant_(self.conv1d.bias, 0)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        前向传播
        
        参数:
            x: (batch, seq_len, d_model)
        
        返回:
            y: (batch, seq_len, d_model)
        """
        batch, seq_len, d_model = x.shape
        
        # 输入投影
        xz = self.in_proj(x)  # (batch, seq_len, d_inner * 2 + d_state * 2)
        
        x_inner, z, B, C = xz.split(
            [self.d_inner, self.d_inner, self.d_state, self.d_state], dim=-1
        )
        
        # 卷积平滑
        x_conv = self.conv1d(rearrange(x_inner, 'b d s -> b s d'))[:, :seq_len]
        x_conv = rearrange(x_conv, 'b s d -> b d s')
        
        # 激活
        x_flat = F.silu(x_conv)
        z = F.silu(z)
        
        # 离散化参数
        dt = F.softplus(self.dt_proj(x_flat))  # (batch, seq_len, d_inner)
        dt = dt * (self.dt_max - self.dt_min) + self.dt_min
        
        # 指数-梯形离散化
        dA, dB = self._discretize(dt, B, C)
        
        # 并行扫描
        y_flat = self._parallel_scan(x_flat, dA, dB)
        
        # 门控和输出
        y = y_flat * z
        y = self.out_proj(y)
        
        return y
    
    def _discretize(self, dt, B, C):
        """
        指数-梯形离散化
        """
        # 复数状态矩阵
        A_real = F.softplus(self.A_log)
        A_complex = self.A_imag
        
        # 计算 e^(dt * A)
        # 对于复数矩阵,使用分块形式
        dA = torch.zeros(
            dt.shape[0], dt.shape[1], self.d_state, self.d_state,
            device=dt.device, dtype=dt.dtype
        )
        
        for i in range(self.d_state):
            for j in range(self.d_state):
                # A_ij = A_real[i,j] + i * A_complex[i,j]
                a_val = A_real[i, j]
                dA[:, :, i, j] = torch.exp(dt * a_val) * torch.cos(dt * A_complex[i, j])
        
        # 简化的 dB 计算
        dB = dt.unsqueeze(-1) * B.unsqueeze(-2)
        
        return dA, dB
    
    def _parallel_scan(self, x, dA, dB):
        """
        并行扫描实现(简化版)
        
        实际实现应使用 Triton/CUDA 内核
        """
        batch, seq_len, d_inner = x.shape
        d_state = self.d_state
        
        # 简化为矩阵-向量扫描
        h = torch.zeros(batch, d_state, device=x.device, dtype=x.dtype)
        outputs = []
        
        # 串行扫描(实际应用并行版本)
        for t in range(seq_len):
            # h_{t+1} = dA_t @ h_t + dB_t @ x_t
            h_new = torch.bmm(dA[:, t], h.unsqueeze(-1)).squeeze(-1) + \
                    torch.bmm(dB[:, t], x[:, t:t+1]).squeeze(-1)
            
            # y_t = C_t @ h_new
            y_t = torch.bmm(h_new.unsqueeze(1), C[:, t:t+1].unsqueeze(-1)).squeeze(-1)
            outputs.append(y_t)
            
            h = h_new
        
        return torch.stack(outputs, dim=1)
    
    def dt_proj(self, x):
        """DT 投影层"""
        dt = F.linear(x, self.dt_proj.weight)
        return dt
 
 
class Mamba3Block(nn.Module):
    """Mamba-3 Block with 残差连接"""
    
    def __init__(self, config):
        super().__init__()
        self.norm = nn.LayerNorm(config.d_model)
        self.mamba = Mamba3Core(config.d_model, config.d_state)
    
    def forward(self, x):
        x = x + self.mamba(self.norm(x))
        return x

7.2 Triton 实现要点

try:
    import triton
    import triton.language as tl
 
    @triton.jit
    def mamba3_triton_scan_kernel(
        x_ptr, y_ptr,  # 输入输出
        dA_ptr, dB_ptr, C_ptr,  # SSM 参数
        stride_xb, stride_xh, stride_xd,  # 内存步长
        T, N, D,  # 序列长度, 状态大小, 维度
        BLOCK_SIZE: tl.constexpr,
    ):
        """
        Triton 内核:并行扫描
        """
        pid_b = tl.program_id(0)
        pid_h = tl.program_id(1)
        
        # 初始化状态
        h = tl.zeros((N,), dtype=tl.float32)
        
        # 扫描循环
        for start in range(0, T, BLOCK_SIZE):
            offs = start + tl.arange(0, BLOCK_SIZE)
            mask = offs < T
            
            # 加载块数据
            x = tl.load(x_ptr + pid_b * stride_xb + offs * stride_xd + pid_h * stride_xh, mask=mask)
            dA = tl.load(dA_ptr + offs * N * N, mask=mask)
            dB = tl.load(dB_ptr + offs * N, mask=mask)
            
            # 状态更新
            h_new = dA * h + dB * x
            
            # 输出
            y = tl.sum(C * h_new)
            tl.store(y_ptr + pid_b * stride_xb + offs * stride_xd + pid_h * stride_xh, y, mask=mask)
            
            h = h_new
 
except ImportError:
    print("Triton not installed, using PyTorch fallback")

8. 数学公式汇总

8.1 连续时间 SSM

8.2 离散化公式

8.3 关联操作

8.4 算术强度

8.5 梯度公式


参考文献

相关链接


Last updated: 2026-05-13

Footnotes

  1. Aakash Lahoti et al., “Mamba-3: Improved Sequence Modeling using State Space Principles”, arXiv:2603.15569, 2026. https://arxiv.org/abs/2603.15569