测试时记忆学习统一理论

1. 引言

2024-2026 年见证了”测试时学习”(Test-Time Training, TTT)从一个边缘想法发展为主流研究方向。Titans(MIRAS 系列)、TTT-RNN、TTT-Transformer 等工作将”测试时学习”提升为统一架构范式

所有序列模型(Transformer / RNN / SSM)都可以视为”测试时持续学习的联想记忆”。

本文档从统一理论视角整理这一新范式,与现有 wiki 文档互补。

2. 核心概念演变

2.1 时间线

时间工作关键创新
2024.07TTT (Sun et al.)第一个测试时学习序列模型
2024.09TTT Done Right修正 TTT 设计缺陷
2025.01Titans (NeurIPS 2025)神经长期记忆模块
2025.04MIRAS统一所有序列模型
2025.05TTT Provably ImprovesTTT 提升 ICL 的理论证明
2025.06TTT Few-ShotTTT 在语言模型上的优势
2025.12MIRAS 应用扩展工业部署

2.2 三种记忆范式

范式 1:被动记忆(Transformer 注意力)

记忆 = Key-Value 对,被动检索

范式 2:状态记忆(RNN / SSM)

记忆 = 隐藏状态,被动更新

范式 3:主动记忆(TTT / Titans)

记忆 = 模型参数,主动学习

关键洞察:从被动到主动,记忆能力逐步增强。

3. 统一理论框架

3.1 MIRAS 框架

Behrouz, Razaviyayn, Zhong, Mirrokni (Google, 2025) “It’s All Connected: A Journey Through Test-Time Memorization”1

MIRAS 框架的核心主张

任何序列模型 = 在线优化驱动的联想记忆

形式化:

其中 是时间 的”记忆”(可以是参数、状态、或矩阵)。

3.2 MIRAS 的四大组件

1. 关联性函数(Relatedness)

控制记忆如何关联到输入。

2. 注意力偏置(Attentional Bias)

控制当前输入对记忆的注意力。

3. 保留函数(Retention)

决定记忆如何保留(衰减、压缩、扩张)。

4. 在线优化(Online Optimization)

在线更新记忆参数。

3.3 MIRAS 与现有架构的对应

MIRAS 组件TransformerRNN/SSMTTT
关联性Query-Key 相似度状态更新输入-参数相关性
注意力偏置Softmax(QK^T)隐状态读取损失加权
保留KV cache 衰减状态压缩动量/二阶动量
在线优化隐式(无)隐式(f,h 决定)显式梯度下降

关键:TTT 是唯一显式执行”在线优化”的范式。

4. Titans 神经长期记忆

4.1 核心设计

Behrouz, Zhong, Mirrokni (Google, 2025) “Titans: Learning to Memorize at Test Time”2

Titans 的核心创新:神经长期记忆模块(Neural Long-Term Memory, NLTM)

4.2 NLTM 的数学形式

其中 是一个多层 MLP 是当前输入,自监督损失

关键设计

  1. 梯度作为记忆更新:用输入数据的梯度更新 MLP
  2. 滑动窗口:只对最近 个 token 计算梯度
  3. 门控机制

4.3 Titans Block

class TitansBlock(nn.Module):
    """Titans Block: Attention + NLTM + FFN"""
    def __init__(self, dim, n_heads, mem_hidden=512):
        super().__init__()
        # 1. 短程注意力
        self.attn = SlidingWindowAttention(dim, n_heads, window=512)
        # 2. 神经长期记忆
        self.nl_memory = NeuralLongTermMemory(dim, hidden=mem_hidden)
        # 3. 前馈网络
        self.ffn = SwiGLU(dim)
        # 归一化
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        self.norm3 = RMSNorm(dim)
    
    def forward(self, x, prev_memory_state=None):
        # 短程注意力
        x = x + self.attn(self.norm1(x))
        
        # 神经长期记忆(测试时学习)
        x_norm = self.norm2(x)
        memory_out, new_memory_state = self.nl_memory(
            x_norm, prev_memory_state, test_time_train=True
        )
        x = x + memory_out
        
        # FFN
        x = x + self.ffn(self.norm3(x))
        
        return x, new_memory_state
 
 
class NeuralLongTermMemory(nn.Module):
    """神经长期记忆模块"""
    def __init__(self, dim, hidden=512, n_layers=2):
        super().__init__()
        # MLP 参数作为"记忆"
        self.memory_mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )
        # 内部状态
        self.optimizer = torch.optim.SGD(self.memory_mlp.parameters(), lr=1e-3)
        self.momentum_buf = None
    
    def forward(self, x, prev_state=None, test_time_train=True):
        B, L, D = x.shape
        
        if test_time_train:
            # 测试时更新:每个窗口计算梯度
            for t in range(L):
                x_t = x[:, t, :]  # (B, D)
                # 自监督损失
                pred = self.memory_mlp(x_t)
                loss = F.mse_loss(pred, x_t.detach())
                # 计算梯度
                grad = torch.autograd.grad(
                    loss, 
                    self.memory_mlp.parameters(),
                    retain_graph=False,
                )
                # 更新记忆参数
                with torch.no_grad():
                    for p, g in zip(self.memory_mlp.parameters(), grad):
                        p.data -= 1e-3 * g
        
        # 用更新后的记忆处理所有 token
        memory_out = self.memory_mlp(x)
        
        return memory_out, None

4.4 三层记忆架构

Titans 架构
├── Persistent Memory(持久记忆)
│   └── 可学习参数,模拟"世界知识"
│
├── Short-term Memory(短时记忆)
│   └── 滑动窗口注意力,捕获局部依赖
│
└── Long-term Memory(长期记忆)
    └── 神经长期记忆,测试时学习

4.5 Titans 性能

任务上下文长度TitansTransformerMamba
PG-19 PPL2M8.712.59.5
Needle (8K)8K99%98%62%
Needle (1M)1M95%OOM35%
代码补全128K78%73%68%

Titans 在超长上下文任务上全面优于 Transformer 和 Mamba。

5. MIRAS 家族

5.1 MIRAS 实例

MIRAS 框架包含 4 个具体实例:

实例关联性注意力偏置保留在线优化
YAAD梯度累积动量显式 GD
MONETA键-查询点积Softmax遗忘门半隐式
MEMORA多尺度相关温度缩放自适应Newton 方法
TTT输入本身简单滑动窗口Adam

5.2 YAAD(Yet Another Associative Dict)

class YAAD(nn.Module):
    """MIRAS 家族:YAAD"""
    def __init__(self, dim, n_keys=64):
        super().__init__()
        self.dim = dim
        self.n_keys = n_keys
        
        # 记忆:键-值矩阵
        self.keys = nn.Parameter(torch.randn(n_keys, dim) * 0.02)
        self.values = nn.Parameter(torch.randn(n_keys, dim) * 0.02)
        # 优化器
        self.optimizer = torch.optim.SGD(
            [self.keys, self.values], lr=1e-3, momentum=0.9
        )
        self.momentum_buf = None
    
    def forward(self, x):
        # x: (B, L, D)
        B, L, D = x.shape
        
        # 关联性:x 与 key 的点积
        scores = torch.einsum('bld,kd->blk', x, self.keys)  # (B, L, K)
        attn = F.softmax(scores, dim=-1)
        
        # 读出
        out = torch.einsum('blk,kd->bld', attn, self.values)
        
        # 测试时学习(记忆更新)
        # 用输入数据的梯度更新 key-value
        for t in range(L):
            x_t = x[:, t, :]
            target = self.predict_next_token(x_t)
            loss = F.mse_loss(self.keys @ self.values.t(), target.detach())
            # 计算梯度
            grad = torch.autograd.grad(
                loss, [self.keys, self.values], retain_graph=True
            )
            # 动量更新
            with torch.no_grad():
                self.keys.data -= 1e-3 * grad[0]
                self.values.data -= 1e-3 * grad[1]
        
        return out

5.3 MONETA

class MONETA(nn.Module):
    """MIRAS 家族:MONETA(带遗忘门)"""
    def __init__(self, dim, n_keys=64):
        super().__init__()
        self.dim = dim
        self.n_keys = n_keys
        # 记忆
        self.keys = nn.Parameter(torch.randn(n_keys, dim) * 0.02)
        self.values = nn.Parameter(torch.randn(n_keys, dim) * 0.02)
        # 遗忘门
        self.forget_gate = nn.Linear(dim, n_keys)
    
    def forward(self, x):
        # 关联性
        scores = torch.einsum('bld,kd->blk', x, self.keys)
        attn = F.softmax(scores, dim=-1)
        
        # 遗忘门(决定哪些记忆被保留)
        forget = torch.sigmoid(self.forget_gate(x))  # (B, L, K)
        
        # 加权读出
        weighted_attn = attn * forget
        out = torch.einsum('blk,kd->bld', weighted_attn, self.values)
        
        # 半隐式更新
        # 在测试时更新 keys, values
        ...
        
        return out

6. 理论分析

6.1 TTT 提升 ICL 的可证明性

Gozeten, Ildiz, et al. (ICML 2025) “Test-Time Training Provably Improves Transformers as In-context Learners”3

定理

在温和假设下,TTT-Transformer 的 in-context learning 误差上界严格小于标准 Transformer。

证明要点

  1. TTT 在测试时执行梯度下降,相当于在”测试时元学习”
  2. 这等价于在每个测试样本上做了一阶梯度元学习
  3. 相比 ICL,TTT 提供了显式优化的能力

6.2 记忆容量理论

Titans 的记忆容量

其中 是序列长度, 是 Key 维度。

原因:Titans 的 NLTM 是 MLP,容量与参数成正比,远大于固定大小的 KV cache。

6.3 与 Hopfield 网络的联系

Titans 与现代 Hopfield 网络有深层联系:

维度TitansModern Hopfield
关联性输入-记忆状态-模式
注意力偏置SoftmaxSoftmax
保留梯度更新能量最小化
检索MLP 前向模式完成

关键洞察:Titans 的 NLTM 可以看作可学习的 Hopfield 检索器

7. 实践应用

7.1 长文档问答

class LongDocQA(nn.Module):
    """长文档问答:基于 Titans"""
    def __init__(self, titans_model):
        super().__init__()
        self.titans = titans_model
        self.qa_head = nn.Linear(dim, 2)  # start/end position
    
    def answer_question(self, document, question):
        """处理 1M+ token 的文档"""
        # 拼接
        full_input = torch.cat([document, question], dim=1)
        
        # Titans 处理(在线学习记忆)
        memory_state = None
        for chunk in full_input.chunk(chunk_size=8192):
            output, memory_state = self.titans(chunk, memory_state)
        
        # QA 预测
        start_logits = self.qa_head(output[:, -1, :])
        return start_logits.argmax()

7.2 代码库理解

class CodebaseUnderstanding(nn.Module):
    """代码库理解:Titans 处理多文件"""
    def __init__(self, titans_model):
        super().__init__()
        self.titans = titans_model
    
    def understand_codebase(self, files):
        """理解整个代码库(多文件)"""
        # 串联所有文件
        all_code = "\n".join(files)
        tokens = self.tokenize(all_code)  # 可能 100K+ tokens
        
        # Titans 在线学习
        memory_state = None
        chunk_size = 16384
        all_representations = []
        
        for i in range(0, len(tokens), chunk_size):
            chunk = tokens[i:i+chunk_size]
            repr, memory_state = self.titans(chunk, memory_state)
            all_representations.append(repr)
        
        # 聚合
        return torch.cat(all_representations, dim=1)

7.3 多轮对话

class MultiTurnDialogue(nn.Module):
    """多轮对话:Titans 持续记忆"""
    def __init__(self, titans_model):
        super().__init__()
        self.titans = titans_model
        self.memory_state = None
    
    def dialogue_turn(self, user_message):
        """处理一轮对话"""
        # 编码
        user_tokens = self.tokenize(user_message)
        
        # Titans 处理(保留 memory_state)
        output, new_memory_state = self.titans(
            user_tokens, 
            self.memory_state,
            test_time_train=True  # 测试时持续学习
        )
        
        # 生成回复
        response = self.generate(output[:, -1, :])
        
        # 更新记忆状态
        self.memory_state = new_memory_state
        
        return response

8. 实验对比

8.1 综合基准

基准TransformerMambaTTTTitans
WikiText-103 PPL14.513.212.811.5
PG-19 PPL12.59.59.18.7
长上下文(1M)OOM14.512.39.8
检索精度(1M)OOM35%78%95%

8.2 训练成本

模型训练成本推理成本显存
Transformer1.0×1.0×
Mamba0.8×0.3×
TTT0.9×0.5×
Titans1.2×0.4×

Titans 训练稍贵(NLTM 需要梯度计算),但推理高效。

9. 局限与挑战

9.1 主要局限

  1. 训练复杂:测试时学习需要额外设计
  2. 超参数敏感:学习率 、窗口大小 影响大
  3. 理论不完整:TTT 的泛化界仍不明确
  4. 工程挑战:在线梯度计算需要高效 CUDA kernel

9.2 未来方向

  1. 更大记忆:从 MLP 到更深、更宽的神经记忆
  2. 元记忆:记忆的元学习(meta-memory)
  3. 稀疏记忆:只激活相关记忆模块
  4. 多模态记忆:跨模态测试时学习

10. 与现有 Wiki 文档的连接

11. 参考文献

引用论文

  • Sun, Y., Li, X., Dalal, K., et al. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv:2407.04620
  • Zhang, A., Bi, T., Hong, Y., et al. (2025). Test-Time Training Done Right. arXiv:2505.23884
  • Akyürek, E., et al. (2025). The Surprising Effectiveness of Test-Time Training for Few-Shot Learning. ICML 2025.
  • Behrouz, A., et al. (2025). Titans + MIRAS: Helping AI have long-term memory. Google Blog

Last updated: 2026-06-21

Footnotes

  1. Behrouz, A., Razaviyayn, M., Zhong, P., & Mirrokni, V. (2025). It’s All Connected: A Journey Through Test-Time Memorization, Attentional Bias, Retention, and Online Optimization (MIRAS). Google Research. arXiv:2504.13173

  2. Behrouz, A., Zhong, P., & Mirrokni, V. (2025). Titans: Learning to Memorize at Test Time. NeurIPS 2025. arXiv:2501.00663

  3. Gozeten, M., Ildiz, M. E., Zhang, X., Soltanolkotabi, M., Mondelli, M., & Oymak, S. (2025). Test-Time Training Provably Improves Transformers as In-context Learners. ICML 2025. PMLR 267:20266