Titans:测试时记忆学习架构
1. 背景与问题
长期序列建模面临记忆与效率的根本矛盾:
| 架构 | 优势 | 劣势 |
|---|---|---|
| Transformer | 完美回忆、精确依赖建模 | 复杂度,长上下文成本爆炸 |
| RNN/LSTM | 线性复杂度 | 信息压缩丢失,难以捕获长期依赖 |
| Mamba SSM | 线性复杂度+状态空间 | 固定状态容量,信息选择性有限 |
核心问题:如何在保持Transformer建模能力的同时,实现对超长上下文的高效处理?
2. Titans核心思想
Titans将记忆从被动存储转变为主动的可学习组件:
核心洞察:让模型在测试时持续学习”什么值得记忆”
2.1 三层记忆结构
┌─────────────────────────────────────────────────────┐
│ Temporal Memory (时间记忆) │
│ 滑动窗口注意力,处理当前局部上下文 $O(W)$ │
├─────────────────────────────────────────────────────┤
│ Persistent Memory (持久记忆) │
│ 缓慢更新的全局知识,类似世界模型 │
├─────────────────────────────────────────────────────┤
│ Neural Long-Term Memory (神经长期记忆) │
│ 深度MLP模块,测试时在线学习 │
└─────────────────────────────────────────────────────┘
3. 神经记忆模块详解
3.1 记忆矩阵表示
将记忆建模为可学习的矩阵 :
// 记忆的核心数据结构
class NeuralMemory {
Matrix memory; // 可学习参数矩阵 d × d
Vector momentum; // 动量项,用于稳定更新
float decay_rate; // 遗忘率 α_t
// 记忆写入:梯度下降 + 动量
void update(Vector s_t, float lr, float momentum) {
// 计算梯度
grad = compute_gradient(s_t);
// 动量更新
momentum = momentum * momentum + grad;
// 记忆更新:指数衰减旧记忆 + 新信息
memory = (1 - decay_rate) * memory + lr * momentum;
}
}3.2 记忆更新算法
采用指数衰减+动量的记忆更新机制:
其中衰减项 通过带动量的梯度下降计算:
- :自适应遗忘门控
- :动量系数
- :学习率
3.3 门控机制
自适应遗忘门控是记忆效率的关键:
// 门控实现
Tensor gated_update(Tensor residual, Tensor branch_out, Tensor gates) {
// Sigmoid门控控制新信息流入比例
return residual.lerp(branch_out, gates.sigmoid());
}门控机制允许模型选择性遗忘,避免记忆饱和。
4. Titans架构变体
4.1 Memory as Context (MAC)
核心思想:将记忆矩阵展开为”伪token”,拼接在输入序列前
输入序列: [token_1, token_2, ..., token_n]
记忆展开: [mem_1, mem_2, ..., mem_k, token_1, token_2, ..., token_n]
// 通过标准注意力访问记忆
output = MultiHeadAttention(query=[tokens], key=[记忆+tokens], value=[记忆+tokens])
特点:
- 最直接的集成方式
- 可复用标准注意力实现
- 记忆作为额外上下文参与计算
4.2 Memory as Gated Attention (MAG)
核心思想:通过门控机制控制记忆信息的流动,而非直接拼接
// MAG注意力计算
class MAGAttention {
// 短期上下文(滑动窗口注意力)
short_term = SlidingWindowAttention(query, key, value);
// 长期记忆读取
long_term = memory.read(query); // O(1) 记忆读取
// 门控融合
output = gate * long_term + (1 - gate) * short_term;
}特点:
- 更细粒度的记忆控制
- 避免记忆token增加序列长度
- 非线性融合
4.3 Persistent Memory (持久记忆)
核心思想:将某些”通用知识”编码为模型的固有参数
// 持久记忆作为固定偏置
persistent_memory = learned_static_embedding; // 训练后不更新
// 在注意力中引入持久记忆
output = Attention(query, key, value) + persistent_memory
特点:
- 存储不随推理变化的通用知识
- 如:常识推理规则、领域无关知识
- 减少可学习参数的训练负担
5. 与Mamba的核心对比
| 维度 | Titans | Mamba |
|---|---|---|
| 记忆形式 | 可学习MLP权重 | 固定SSM隐藏状态 |
| 遗忘机制 | 自适应权重衰减 | 输入依赖门控 |
| 记忆深度 | 多层MLP() | 单层循环 |
| 测试时学习 | ✅ 在线更新记忆 | ❌ 权重固定 |
| 上下文长度 | 200万+ tokens | 受状态容量限制 |
| 时间复杂度 |
关键差异分析
Mamba的遗忘是”被动”的:
- 状态空间通过输入控制遗忘程度
- 新信息覆盖旧信息
- 容量受限于状态维度
Titans的遗忘是”主动”的:
- 通过学习决定遗忘速率
- 记忆可选择性保留
- 支持多层次记忆整合
6. 技术实现
6.1 PyTorch实现框架
#include <bits/stdc++.h>
using namespace std;
// 神经记忆层
class NeuralMemoryLayer {
public:
int dim; // 模型维度
int memory_depth; // MLP深度
// 可学习参数
vector<Matrix> memory_weights; // 记忆MLP权重
Vector decay_gates; // 遗忘门控
Vector momentum_buffer; // 动量缓冲区
NeuralMemoryLayer(int dim, int depth=2) : dim(dim), memory_depth(depth) {
// 初始化记忆MLP
for (int i = 0; i < depth; i++) {
memory_weights.push_back(random_matrix(dim, dim));
}
decay_gates = zeros(dim);
momentum_buffer = zeros(dim);
}
// 记忆前向传播(推理时)
Tensor forward(Tensor x, Tensor memory) {
// 通过MLP处理输入
Tensor h = x;
for (int i = 0; i < memory_depth; i++) {
h = gelu(memory_weights[i] @ h);
}
return h;
}
// 记忆更新(测试时学习)
void update_memory(Tensor grad, float lr=1e-3, float momentum=0.9) {
// 动量更新
momentum_buffer = momentum * momentum_buffer + grad;
// 更新记忆权重
for (auto& w : memory_weights) {
w = w - lr * outer_product(momentum_buffer, grad);
}
}
};6.2 完整Titans块
class TitansBlock {
NeuralMemoryLayer memory_layer; // 神经记忆
int window_size; // 滑动窗口大小
public:
Tensor forward(Tensor x, Tensor memory, Tensor mask) {
// 1. 短期上下文(滑动窗口注意力)
Tensor short_term = sliding_window_attention(x, window_size, mask);
// 2. 长期记忆读取
Tensor long_term = memory_layer.forward(x, memory);
// 3. 门控融合
Tensor gate = sigmoid(linear(x)); // 自适应门控
Tensor output = gate * long_term + (1 - gate) * short_term;
// 4. 前馈网络
output = output + feedforward(output);
return output;
}
};7. 实验结果
7.1 长上下文基准测试
| 数据集 | Titans | Mamba | Transformer-XL |
|---|---|---|---|
| WikiText-103 | 27.01 ppl | 30.83 ppl | 38.52 ppl |
| PG-19 | SOTA | 落后 | 落后 |
| BABILong (1M+) | 显著领先 | 显著落后 | 无法处理 |
7.2 消融实验
| 组件 | 贡献度 |
|---|---|
| 权重衰减 | 最大(~60%) |
| 动量 | 中等(~25%) |
| 卷积预处理 | 较小(~15%) |
7.3 上下文长度缩放
- Titans在2M tokens上下文上仍保持性能
- Mamba在128K后开始退化
- 标准Transformer受限于显存
8. 理论分析
8.1 记忆容量
Titans的记忆容量由MLP权重维度决定:
其中 是模型维度。对于 的模型,记忆容量约为 67M参数。
8.2 计算权衡
| 操作 | 复杂度 |
|---|---|
| 记忆写入 | (一次性) |
| 记忆读取 | (每步) |
| 滑动窗口注意力 |
总复杂度:,接近线性。
9. 应用场景
9.1 代码仓库理解
- 跨越数百万行代码的依赖分析
- 跨文件上下文保持
9.2 长文档分析
- 法律/金融文档
- 学术论文理解
9.3 AI Agent持久化
- 多会话记忆保持
- 个性化学习
9.4 科学数据
- 基因组序列分析
- 时间序列预测
10. 总结
Titans通过引入可学习的神经记忆模块,实现了:
- ✅ 测试时持续学习:记忆在推理中不断更新
- ✅ 超长上下文:支持200万+ tokens
- ✅ 线性推理复杂度:
- ✅ 选择性遗忘:自适应门控机制
这代表了从”上下文窗口”到”持续记忆”的范式转变。