Titans:测试时记忆学习架构

1. 背景与问题

长期序列建模面临记忆与效率的根本矛盾

架构优势劣势
Transformer完美回忆、精确依赖建模复杂度,长上下文成本爆炸
RNN/LSTM线性复杂度信息压缩丢失,难以捕获长期依赖
Mamba SSM线性复杂度+状态空间固定状态容量,信息选择性有限

核心问题:如何在保持Transformer建模能力的同时,实现对超长上下文的高效处理?

2. Titans核心思想

Titans将记忆从被动存储转变为主动的可学习组件

核心洞察:让模型在测试时持续学习”什么值得记忆”

2.1 三层记忆结构

┌─────────────────────────────────────────────────────┐
│          Temporal Memory (时间记忆)                   │
│    滑动窗口注意力,处理当前局部上下文 $O(W)$          │
├─────────────────────────────────────────────────────┤
│       Persistent Memory (持久记忆)                   │
│    缓慢更新的全局知识,类似世界模型                  │
├─────────────────────────────────────────────────────┤
│     Neural Long-Term Memory (神经长期记忆)           │
│    深度MLP模块,测试时在线学习                       │
└─────────────────────────────────────────────────────┘

3. 神经记忆模块详解

3.1 记忆矩阵表示

将记忆建模为可学习的矩阵

// 记忆的核心数据结构
class NeuralMemory {
    Matrix memory;        // 可学习参数矩阵 d × d
    Vector momentum;      // 动量项,用于稳定更新
    float decay_rate;    // 遗忘率 α_t
    
    // 记忆写入:梯度下降 + 动量
    void update(Vector s_t, float lr, float momentum) {
        // 计算梯度
        grad = compute_gradient(s_t);
        // 动量更新
        momentum = momentum * momentum + grad;
        // 记忆更新:指数衰减旧记忆 + 新信息
        memory = (1 - decay_rate) * memory + lr * momentum;
    }
}

3.2 记忆更新算法

采用指数衰减+动量的记忆更新机制:

其中衰减项 通过带动量的梯度下降计算:

  • :自适应遗忘门控
  • :动量系数
  • :学习率

3.3 门控机制

自适应遗忘门控是记忆效率的关键:

// 门控实现
Tensor gated_update(Tensor residual, Tensor branch_out, Tensor gates) {
    // Sigmoid门控控制新信息流入比例
    return residual.lerp(branch_out, gates.sigmoid());
}

门控机制允许模型选择性遗忘,避免记忆饱和。

4. Titans架构变体

4.1 Memory as Context (MAC)

核心思想:将记忆矩阵展开为”伪token”,拼接在输入序列前

输入序列: [token_1, token_2, ..., token_n]
记忆展开:  [mem_1, mem_2, ..., mem_k, token_1, token_2, ..., token_n]

// 通过标准注意力访问记忆
output = MultiHeadAttention(query=[tokens], key=[记忆+tokens], value=[记忆+tokens])

特点

  • 最直接的集成方式
  • 可复用标准注意力实现
  • 记忆作为额外上下文参与计算

4.2 Memory as Gated Attention (MAG)

核心思想:通过门控机制控制记忆信息的流动,而非直接拼接

// MAG注意力计算
class MAGAttention {
    // 短期上下文(滑动窗口注意力)
    short_term = SlidingWindowAttention(query, key, value);
    
    // 长期记忆读取
    long_term = memory.read(query);  // O(1) 记忆读取
    
    // 门控融合
    output = gate * long_term + (1 - gate) * short_term;
}

特点

  • 更细粒度的记忆控制
  • 避免记忆token增加序列长度
  • 非线性融合

4.3 Persistent Memory (持久记忆)

核心思想:将某些”通用知识”编码为模型的固有参数

// 持久记忆作为固定偏置
persistent_memory = learned_static_embedding;  // 训练后不更新

// 在注意力中引入持久记忆
output = Attention(query, key, value) + persistent_memory

特点

  • 存储不随推理变化的通用知识
  • 如:常识推理规则、领域无关知识
  • 减少可学习参数的训练负担

5. 与Mamba的核心对比

维度TitansMamba
记忆形式可学习MLP权重固定SSM隐藏状态
遗忘机制自适应权重衰减 输入依赖门控
记忆深度多层MLP(单层循环
测试时学习✅ 在线更新记忆❌ 权重固定
上下文长度200万+ tokens受状态容量限制
时间复杂度

关键差异分析

Mamba的遗忘是”被动”的

  • 状态空间通过输入控制遗忘程度
  • 新信息覆盖旧信息
  • 容量受限于状态维度

Titans的遗忘是”主动”的

  • 通过学习决定遗忘速率
  • 记忆可选择性保留
  • 支持多层次记忆整合

6. 技术实现

6.1 PyTorch实现框架

#include <bits/stdc++.h>
using namespace std;
 
// 神经记忆层
class NeuralMemoryLayer {
public:
    int dim;          // 模型维度
    int memory_depth; // MLP深度
    
    // 可学习参数
    vector<Matrix> memory_weights;  // 记忆MLP权重
    Vector decay_gates;             // 遗忘门控
    Vector momentum_buffer;         // 动量缓冲区
    
    NeuralMemoryLayer(int dim, int depth=2) : dim(dim), memory_depth(depth) {
        // 初始化记忆MLP
        for (int i = 0; i < depth; i++) {
            memory_weights.push_back(random_matrix(dim, dim));
        }
        decay_gates = zeros(dim);
        momentum_buffer = zeros(dim);
    }
    
    // 记忆前向传播(推理时)
    Tensor forward(Tensor x, Tensor memory) {
        // 通过MLP处理输入
        Tensor h = x;
        for (int i = 0; i < memory_depth; i++) {
            h = gelu(memory_weights[i] @ h);
        }
        return h;
    }
    
    // 记忆更新(测试时学习)
    void update_memory(Tensor grad, float lr=1e-3, float momentum=0.9) {
        // 动量更新
        momentum_buffer = momentum * momentum_buffer + grad;
        
        // 更新记忆权重
        for (auto& w : memory_weights) {
            w = w - lr * outer_product(momentum_buffer, grad);
        }
    }
};

6.2 完整Titans块

class TitansBlock {
    NeuralMemoryLayer memory_layer;    // 神经记忆
    int window_size;                  // 滑动窗口大小
    
public:
    Tensor forward(Tensor x, Tensor memory, Tensor mask) {
        // 1. 短期上下文(滑动窗口注意力)
        Tensor short_term = sliding_window_attention(x, window_size, mask);
        
        // 2. 长期记忆读取
        Tensor long_term = memory_layer.forward(x, memory);
        
        // 3. 门控融合
        Tensor gate = sigmoid(linear(x));  // 自适应门控
        Tensor output = gate * long_term + (1 - gate) * short_term;
        
        // 4. 前馈网络
        output = output + feedforward(output);
        
        return output;
    }
};

7. 实验结果

7.1 长上下文基准测试

数据集TitansMambaTransformer-XL
WikiText-10327.01 ppl30.83 ppl38.52 ppl
PG-19SOTA落后落后
BABILong (1M+)显著领先显著落后无法处理

7.2 消融实验

组件贡献度
权重衰减 最大(~60%)
动量 中等(~25%)
卷积预处理较小(~15%)

7.3 上下文长度缩放

  • Titans在2M tokens上下文上仍保持性能
  • Mamba在128K后开始退化
  • 标准Transformer受限于显存

8. 理论分析

8.1 记忆容量

Titans的记忆容量由MLP权重维度决定:

其中 是模型维度。对于 的模型,记忆容量约为 67M参数

8.2 计算权衡

操作复杂度
记忆写入(一次性)
记忆读取(每步)
滑动窗口注意力

总复杂度:,接近线性。

9. 应用场景

9.1 代码仓库理解

  • 跨越数百万行代码的依赖分析
  • 跨文件上下文保持

9.2 长文档分析

  • 法律/金融文档
  • 学术论文理解

9.3 AI Agent持久化

  • 多会话记忆保持
  • 个性化学习

9.4 科学数据

  • 基因组序列分析
  • 时间序列预测

10. 总结

Titans通过引入可学习的神经记忆模块,实现了:

  1. 测试时持续学习:记忆在推理中不断更新
  2. 超长上下文:支持200万+ tokens
  3. 线性推理复杂度
  4. 选择性遗忘:自适应门控机制

这代表了从”上下文窗口”到”持续记忆”的范式转变。


参考资料