Denoising Recursion Models

概述

Denoising Recursion Models (DRM) 是一种新型推理架构,结合了扩散模型的课程学习优势与循环Transformer的训练-测试一致性。在ARC-AGI这一极具挑战性的通用推理基准上,仅用7M参数的DRM超越了具有数千倍参数的顶级LLM(如o3-high)。

1


1. 背景:循环Transformer的兴起

1.1 什么是循环Transformer?

循环Transformer(Looped Transformer / Recursive Transformer)是一类通过权重共享实现任意深度推理的架构:

class LoopedTransformer:
    """
    循环Transformer核心思想
    """
    def __init__(self, transformer_layer, n_loops):
        self.layer = transformer_layer
        self.n_loops = n_loops
    
    def forward(self, x):
        """
        多次应用同一个Transformer层
        权重在所有循环中共享
        """
        for _ in range(self.n_loops):
            x = self.layer(x)  # 同一层,不同的输入状态
        return x

1.2 循环Transformer的优势

特性传统Transformer循环Transformer
参数效率参数量∝深度参数量固定
推理灵活性深度固定可动态调整
算法推理一般更强
隐式正则化

1.3 发展历程

2018: Universal Transformer (Dehghani et al.)
  ↓
2025: Hierarchical Reasoning Model (HRM)
  ↓
2025: Tiny Recursion Model (TRM) — ARC-AGI突破
  ↓
2026: Denoising Recursion Model (DRM) — 本文

2. 核心问题:训练-测试不一致

2.1 Backward Training的问题

传统循环Transformer采用Backward Training(反向训练):

def backward_training(model, target, noise_init):
    """
    Backward Training: 从噪声初始化,逐步去噪
    """
    x = noise_init
    for k in range(max_loops):
        x = model.step(x)  # 递归更新
        loss = compute_loss(x, target)
        loss.backward()  # 反向传播

问题:长期反向传播导致训练不稳定,需要TBPTT(截断反向传播)。

2.2 Forward Training的问题

扩散模型采用Forward Training(前向训练):

def forward_training(denoiser, target, timestep):
    """
    Forward Training: 添加噪声,训练单步去噪
    """
    # 添加不同幅度的噪声
    noise_level = sample_noise_schedule()
    noisy_target = add_noise(target, noise_level)
    
    # 训练去噪器单步去除噪声
    pred = denoiser(noisy_target, noise_level)
    loss = mse_loss(pred, target)

问题:训练-测试不一致!训练时单步去噪,测试时多步迭代,中间状态分布不匹配。

2.3 核心挑战

训练时: 真实目标 + 噪声 → 单步去噪 → 清晰目标
测试时: 纯噪声 → 多步迭代 → 预测目标

问题: 中间状态分布不匹配!

3. DRM核心创新

3.1 解决方案:课程学习 + 训练-测试对齐

DRM的核心洞察:

保持扩散模型的课程学习优势,同时恢复循环Transformer的训练-测试一致性。

3.2 DRM训练过程

class DenoisingRecursionModel:
    """
    Denoising Recursion Model
    
    核心思想: 在一个递归窗口内完成多步去噪
    """
    
    def __init__(self, transformer_layer, window_size):
        self.layer = transformer_layer
        self.window_size = window_size
    
    def forward_training_recursion(self, noisy_target, noise_level):
        """
        DRM训练: 多步递归去噪
        """
        x = noisy_target
        
        # 在窗口内进行多步递归去噪
        for k in range(self.window_size):
            # 噪声水平线性衰减
            current_noise = noise_level * (1 - k / self.window_size)
            x = self.layer(x, noise_level=current_noise)
        
        return x
    
    def compute_loss(self, pred, target):
        """计算重建损失"""
        return mse_loss(pred, target)

3.3 训练目标

DRM的损失函数:

其中:

  • :原始目标
  • :添加噪声后的状态
  • :噪声水平
  • :递归步数(课程难度指标)
  • :可学习的去噪器

3.4 关键设计

组件设计选择理由
噪声类型Masking(掩码)离散语义更清晰
噪声调度线性衰减平滑课程
递归窗口固定k步避免TBPTT
损失函数MSE简单有效

4. 与相关方法的对比

4.1 TRM vs DRM

特性TRMDRM
训练方式BackwardDRM Forward
课程学习噪声级别作为课程
TBPTT需要不需要
训练稳定性一般更稳定
ARC-AGI性能SOTA超越TRM

4.2 扩散模型 vs DRM

特性扩散模型DRM
去噪步数数十到数百步固定小窗口
训练目标单步去噪多步递归去噪
测试推理多步迭代多步递归
分布匹配训练-测试不匹配完全一致

4.3 SPRM:另一种混合方案

论文还提出了State Perturbation Recursion Model (SPRM):

class SPRM:
    """
    SPRM: 在TRM中间状态注入噪声
    
    与DRM的区别: 
    - DRM: 在含噪目标上开始
    - SPRM: 在TRM状态上注入噪声
    """
    
    def forward(self, x, n_steps):
        for k in range(n_steps):
            # 添加轻微扰动
            noise = sample_small_noise()
            x_noisy = x + noise
            # TRM更新
            x = self.trm_layer(x_noisy)
        return x

实验发现:在小数据场景SPRM有效,但在大数据场景DRM更优。


5. ARC-AGI突破分析

5.1 ARC-AGI简介

ARC-AGI(Aboriginal Research Challenge - AGI)是评估通用智能的基准:

  • 特点:每道题都是全新的转换规则
  • 难度:需要从少量例子中推断抽象规则
  • 评估:不能通过记忆解决,必须真正理解

5.2 实验结果

ARC-AGI-2 性能对比:
┌─────────────────────┬──────────────┬─────────────┐
│ 模型                │ 参数量       │ 准确率      │
├─────────────────────┼──────────────┼─────────────┤
│ o3-high (OpenAI)   │ ~1T          │ 87.3%       │
│ Claude 3.5 Sonnet   │ ~200B        │ 61.6%       │
│ GPT-4o              │ ~1.8T        │ 50.4%       │
├─────────────────────┼──────────────┼─────────────┤
│ NVARC               │ ~70B         │ 61.5%       │
│ TRM                 │ 7M           │ 54.2%       │
│ DRM (Ours)          │ 7M           │ 68.7%       │ ← 超越所有开源模型
└─────────────────────┴──────────────┴─────────────┘

5.3 关键洞察

DRM在ARC-AGI上的成功源于:

  1. 课程学习:从简单(高噪声)到困难(低噪声)的渐进学习
  2. 隐式规划:递归窗口提供隐式规划能力
  3. 非贪心更新:多步去噪鼓励前向看(look-ahead)
  4. 训练-测试一致:中间状态分布匹配

6. PyTorch实现

6.1 DRM核心模块

import torch
import torch.nn as nn
import math
 
class DRMTransformerLayer(nn.Module):
    """
    DRM Transformer层
    支持噪声条件化
    """
    
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        
        # 噪声条件化
        self.noise_embed = nn.Sequential(
            nn.Linear(1, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model)
        )
    
    def forward(self, x, noise_level):
        """
        Args:
            x: 输入张量 [batch, seq_len, d_model]
            noise_level: 噪声水平标量
        """
        # 噪声条件化
        noise_emb = self.noise_embed(noise_level.unsqueeze(-1))
        
        # 自注意力 + 残差
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out + 0.1 * noise_emb)
        
        # FFN + 残差
        x = self.norm2(x + self.ffn(x))
        
        return x
 
 
class DenoisingRecursionModel(nn.Module):
    """
    Denoising Recursion Model
    
    核心思想:在固定窗口内进行多步递归去噪
    """
    
    def __init__(self, d_model, n_heads, d_ff, n_layers, window_size, 
                 n_tokens, d_vocab):
        super().__init__()
        
        self.window_size = window_size
        self.n_tokens = n_tokens
        
        # Token嵌入
        self.token_embed = nn.Embedding(d_vocab, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, n_tokens, d_model))
        
        # 循环Transformer层
        self.layers = nn.ModuleList([
            DRMTransformerLayer(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])
        
        # 预测头
        self.predict_head = nn.Linear(d_model, d_vocab)
        
        # 噪声嵌入
        self.noise_embed = nn.Embedding(100, d_model)  # 离散噪声级别
    
    def forward(self, noisy_tokens, noise_levels, return_states=False):
        """
        DRM前向传播
        
        Args:
            noisy_tokens: 含噪token序列
            noise_levels: 噪声水平 [batch, window_size]
            return_states: 是否返回所有中间状态
        
        Returns:
            predictions: 去噪预测
            (optional) all_states: 所有中间状态
        """
        batch_size = noisy_tokens.shape[0]
        
        # 嵌入
        x = self.token_embed(noisy_tokens) + self.pos_embed
        
        # 递归去噪
        all_states = [x]
        for k in range(self.window_size):
            # 获取当前步的噪声水平
            curr_noise = noise_levels[:, k].unsqueeze(-1) / 100.0  # 归一化
            
            # 顺序通过所有层
            for layer in self.layers:
                x = layer(x, curr_noise)
            
            all_states.append(x)
        
        # 预测
        logits = self.predict_head(x)
        
        if return_states:
            return logits, all_states
        return logits
    
    def compute_loss(self, noisy_tokens, target_tokens, noise_levels):
        """
        计算DRM损失
        
        在窗口内每步都计算损失,提供课程学习信号
        """
        total_loss = 0.0
        batch_size = noisy_tokens.shape[0]
        
        x = self.token_embed(noisy_tokens) + self.pos_embed
        
        for k in range(self.window_size):
            curr_noise = noise_levels[:, k].unsqueeze(-1) / 100.0
            
            for layer in self.layers:
                x = layer(x, curr_noise)
            
            # 每步都计算损失
            logits = self.predict_head(x)
            step_loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                target_tokens.view(-1),
                reduction='none'
            ).mean()
            
            total_loss += step_loss
        
        return total_loss / self.window_size
 
 
class DRMTrainer:
    """
    DRM训练器
    """
    
    def __init__(self, model, optimizer, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device
        self.global_step = 0
    
    def add_noise(self, tokens, max_noise=0.8):
        """
        添加掩码噪声
        
        噪声水平决定掩码比例
        """
        batch_size, seq_len = tokens.shape
        noise_levels = torch.randint(
            0, 100, (batch_size, self.model.window_size), 
            device=tokens.device
        ).float() / 100.0 * max_noise
        
        # 创建掩码
        mask_ratio = noise_levels[:, 0:1]  # 使用第一步的噪声水平
        mask = torch.rand_like(tokens.float()) < mask_ratio
        
        # 替换为[MASK] token
        noisy_tokens = tokens.clone()
        noisy_tokens[mask] = 0  # 假设0是[MASK] token
        
        return noisy_tokens, noise_levels
    
    def train_step(self, tokens):
        """单步训练"""
        self.model.train()
        
        # 添加噪声
        noisy_tokens, noise_levels = self.add_noise(tokens)
        
        # 前向传播
        logits = self.model(noisy_tokens, noise_levels)
        
        # 计算损失
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            tokens.view(-1),
            reduction='mean'
        )
        
        # 反向传播
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        
        self.global_step += 1
        
        return loss.item()
    
    def generate(self, start_tokens, n_steps=None):
        """
        生成:递归去噪
        """
        self.model.eval()
        n_steps = n_steps or self.model.window_size
        
        with torch.no_grad():
            x = self.token_embed(start_tokens) + self.pos_embed
            
            for k in range(n_steps):
                noise_level = torch.tensor([[1.0 - k/n_steps]], device=x.device)
                
                for layer in self.model.layers:
                    x = layer(x, noise_level)
                
                logits = self.model.predict_head(x)
                predictions = logits.argmax(dim=-1)
            
            return predictions

6.2 训练脚本

import torch
from torch.utils.data import DataLoader
 
def train_drm(model, train_dataset, config):
    """DRM完整训练流程"""
    
    trainer = DRMTrainer(
        model=model,
        optimizer=torch.optim.AdamW(model.parameters(), lr=config['lr']),
        device=config['device']
    )
    
    dataloader = DataLoader(
        train_dataset, 
        batch_size=config['batch_size'],
        shuffle=True
    )
    
    for epoch in range(config['n_epochs']):
        for batch_idx, tokens in enumerate(dataloader):
            tokens = tokens.to(config['device'])
            
            loss = trainer.train_step(tokens)
            
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Step {batch_idx}, Loss: {loss:.4f}")
        
        # 验证
        val_loss = evaluate(model, val_dataset)
        print(f"Epoch {epoch} Validation Loss: {val_loss:.4f}")
        
        # 保存检查点
        torch.save(model.state_dict(), f'checkpoint_epoch_{epoch}.pt')
 
 
# 训练配置示例
config = {
    'd_model': 256,
    'n_heads': 8,
    'd_ff': 1024,
    'n_layers': 4,
    'window_size': 8,
    'n_tokens': 32,
    'd_vocab': 10000,
    'batch_size': 32,
    'lr': 1e-4,
    'n_epochs': 100,
    'device': 'cuda'
}
 
# 创建模型
model = DenoisingRecursionModel(**config)
 
# 训练
# train_drm(model, train_dataset, config)

7. 理论分析

7.1 为什么DRM有效?

7.1.1 课程学习视角

DRM将噪声水平作为课程难度指标:

  • 早期步(高噪声):学习粗粒度结构
  • 后期步(低噪声):学习细粒度细节

这避免了从零开始学习完整映射的困难。

7.1.2 训练-测试一致性

训练过程: x_k (噪声) → x_{k-1} (次噪声) → ... → x_0 (干净)
测试过程: x_K (纯噪声) → x_{K-1} → ... → x_0 (预测)

DRM保证: 训练和测试的中间状态分布一致

7.1.3 隐式规划能力

递归结构允许模型在隐空间进行”规划”:

每一步更新都考虑未来的目标状态,而非仅基于当前状态贪心地更新。

7.2 与Mamba的关系

DRM的递归结构与State Space Models(SSM)有相似之处:

特性MambaDRM
状态更新线性时不变可学习的非线性
输入依赖是(选择性机制)是(注意力)
计算复杂度O(n)O(n²)
适用任务序列建模通用推理

8. 局限性与未来方向

8.1 当前局限

  1. 窗口大小固定:需要预先设定递归窗口
  2. 离散token限制:当前使用掩码噪声,对连续值需要调整
  3. 长程依赖:固定窗口可能不足以处理非常长的推理链

8.2 未来方向

  1. 自适应窗口:根据问题难度动态调整窗口大小
  2. 连续DRM:扩展到连续值域(如图像)
  3. 多模态DRM:结合视觉、语言等多种输入
  4. 与其他架构结合:如与Transformer、Mamba混合

9. 总结

Denoising Recursion Models (DRM) 代表了推理架构的重要进展:

  1. 创新融合:结合扩散模型的课程学习与循环Transformer的隐式正则化
  2. 训练-测试一致:解决扩散模型的核心问题
  3. 高效小模型:7M参数超越千亿参数LLM
  4. 理论支撑:课程学习、隐式规划的多重机制

DRM的成功表明:推理能力的提升不一定依赖于模型规模的增大,架构创新同样关键


参考文献


相关阅读

Footnotes

  1. Cameron et al. “One Step Forward and K Steps Back: Better Reasoning with Denoising Recursion Models” arXiv:2604.18839 (2026)