概述

Encode-Think-Decode 是一种三阶段推理框架,通过编码(Encode)、思考(Think)、解码(Decode)三个步骤实现高效的测试时推理扩展。该框架的核心创新在于递归隐式思维(Recursive Latent Thoughts),允许模型在隐空间中动态扩展推理深度。1

与显式的 Chain-of-Thought 方法不同,Encode-Think-Decode 在隐空间中进行推理,避免了长序列 token 带来的计算和存储开销。

三阶段架构

阶段 1:编码(Encode)

编码阶段将输入问题转换为紧凑的隐表示:

其中 是原始输入, 是初始编码表示。

class Encoder(nn.Module):
    """
    问题编码器:将输入转换为隐表示
    """
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x: Tensor) -> Tensor:
        # token嵌入
        h = self.embedding(x)
        # 投影和归一化
        h = self.proj(h)
        return self.norm(h)

阶段 2:思考(Think)

思考阶段是框架的核心,通过递归隐式思维进行推理:

其中 表示递归深度,ThinkLayer 整合当前状态、原始问题和深度信息。

class ThinkLayer(nn.Module):
    """
    思考层:递归隐式推理
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        # 深度条件机制
        self.depth_emb = nn.Parameter(torch.randn(d_model))
        
        # 自状态注意力
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        
        # 交叉状态注意力(关注原始问题)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        
        # 前馈网络
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
    def forward(self, state: Tensor, problem: Tensor, depth: int) -> Tensor:
        # 深度条件
        depth_offset = depth * self.depth_emb
        
        # 自注意力
        h = self.norm1(state + depth_offset)
        h = h + self.self_attn(h, h, h)
        
        # 交叉注意力(回到原始问题)
        h = self.norm2(h)
        h = h + self.cross_attn(h, problem, problem)
        
        # 前馈
        h = self.norm3(h)
        h = h + self.ffn(h)
        
        return h

阶段 3:解码(Decode)

解码阶段将隐式思考结果转换为最终答案:

其中 是最终递归深度。

class Decoder(nn.Module):
    """
    答案解码器:从隐表示生成答案
    """
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
    def forward(self, state: Tensor) -> Tensor:
        # 投影到词汇空间
        logits = self.proj(state)
        return self.lm_head(logits)

完整推理流程

class EncodeThinkDecode:
    """
    完整的编码-思考-解码推理器
    """
    def __init__(self, encoder, thinker, decoder, max_depth: int = 8):
        self.encoder = encoder
        self.thinker = thinker
        self.decoder = decoder
        self.max_depth = max_depth
        
    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        # 阶段1:编码
        problem = self.encoder(x)
        state = problem
        
        # 阶段2:递归思考
        for depth in range(self.max_depth):
            state = self.thinker(state, problem, depth)
            
            # 可选:早期停止检查
            if self.check_confidence(state):
                break
        
        # 阶段3:解码
        return self.decoder(state)

数学分析

递归深度与计算复杂度

设递归深度为 ,每次思考的计算量为 ,则总计算量为:

对于固定的推理质量目标:

方法计算量上下文长度
标准自回归 tokens
Chain-of-Thought tokens
Encode-Think-Decode vectors

其中 是 CoT 序列长度, 是递归深度。

隐空间容量的影响

隐表示维度 限制了单步推理的信息量:

其中 是量化误差。更大的 允许更复杂的隐式推理。

与其他框架的对比

vs Latent Reasoning

特性Latent ReasoningEncode-Think-Decode
架构单递归块三阶段框架
问题接入初始嵌入每步交叉注意力
深度控制固定自适应
训练稳定性较高更高(分离设计)

vs Chain-of-Thought

特性Chain-of-ThoughtEncode-Think-Decode
推理表示显式 token隐式向量
可解释性中等
计算效率线性超线性
上下文需求增长固定

vs MatryoshkaThinking

特性MatryoshkaThinkingEncode-Think-Decode
核心机制嵌套聚合三阶段分离
递归形式自聚合交叉注意
停止策略置信度可配置
适用场景数学推理通用推理

实验结果

数学推理基准

模型GSM8KMATHAIME
GPT-4 (standard)67.1%42.3%3.3%
GPT-4 + CoT85.2%68.7%9.3%
Encode-Think-Decode ()91.3%79.4%24.7%
Encode-Think-Decode ()92.8%82.1%31.2%

效率对比

方法每样本计算量内存占用延迟
CoT ( steps)100%100%100%
Encode-Think-Decode ()35%28%42%
Encode-Think-Decode ()52%35%58%

实践建议

超参数配置

# 推荐配置
config = {
    'd_model': 4096,        # 隐表示维度
    'n_heads': 32,          # 注意力头数
    'max_depth': 8,        # 最大思考深度
    'depth_emb_scale': 0.1, # 深度嵌入缩放
    'use_early_stop': True, # 启用早期停止
}

训练策略

  1. 课程学习:从浅层递归逐步增加深度
  2. 深度dropout:随机丢弃部分递归层
  3. 问题难度感知:根据问题复杂度调整目标深度

相关工作

参考文献

Footnotes

  1. Anonymous. (2025). Encode, Think, Decode: Scaling test-time reasoning with recursive latent thoughts. arXiv:2510.07358. https://arxiv.org/abs/2510.07358