概述
Encode-Think-Decode 是一种三阶段推理框架,通过编码(Encode)、思考(Think)、解码(Decode)三个步骤实现高效的测试时推理扩展。该框架的核心创新在于递归隐式思维(Recursive Latent Thoughts),允许模型在隐空间中动态扩展推理深度。1
与显式的 Chain-of-Thought 方法不同,Encode-Think-Decode 在隐空间中进行推理,避免了长序列 token 带来的计算和存储开销。
三阶段架构
阶段 1:编码(Encode)
编码阶段将输入问题转换为紧凑的隐表示:
其中 是原始输入, 是初始编码表示。
class Encoder(nn.Module):
"""
问题编码器:将输入转换为隐表示
"""
def __init__(self, d_model: int, vocab_size: int):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.proj = nn.Linear(d_model, d_model)
self.norm = nn.LayerNorm(d_model)
def forward(self, x: Tensor) -> Tensor:
# token嵌入
h = self.embedding(x)
# 投影和归一化
h = self.proj(h)
return self.norm(h)阶段 2:思考(Think)
思考阶段是框架的核心,通过递归隐式思维进行推理:
其中 表示递归深度,ThinkLayer 整合当前状态、原始问题和深度信息。
class ThinkLayer(nn.Module):
"""
思考层:递归隐式推理
"""
def __init__(self, d_model: int, n_heads: int):
super().__init__()
# 深度条件机制
self.depth_emb = nn.Parameter(torch.randn(d_model))
# 自状态注意力
self.self_attn = MultiHeadAttention(d_model, n_heads)
# 交叉状态注意力(关注原始问题)
self.cross_attn = MultiHeadAttention(d_model, n_heads)
# 前馈网络
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
def forward(self, state: Tensor, problem: Tensor, depth: int) -> Tensor:
# 深度条件
depth_offset = depth * self.depth_emb
# 自注意力
h = self.norm1(state + depth_offset)
h = h + self.self_attn(h, h, h)
# 交叉注意力(回到原始问题)
h = self.norm2(h)
h = h + self.cross_attn(h, problem, problem)
# 前馈
h = self.norm3(h)
h = h + self.ffn(h)
return h阶段 3:解码(Decode)
解码阶段将隐式思考结果转换为最终答案:
其中 是最终递归深度。
class Decoder(nn.Module):
"""
答案解码器:从隐表示生成答案
"""
def __init__(self, d_model: int, vocab_size: int):
super().__init__()
self.proj = nn.Linear(d_model, d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, state: Tensor) -> Tensor:
# 投影到词汇空间
logits = self.proj(state)
return self.lm_head(logits)完整推理流程
class EncodeThinkDecode:
"""
完整的编码-思考-解码推理器
"""
def __init__(self, encoder, thinker, decoder, max_depth: int = 8):
self.encoder = encoder
self.thinker = thinker
self.decoder = decoder
self.max_depth = max_depth
@torch.no_grad()
def forward(self, x: Tensor) -> Tensor:
# 阶段1:编码
problem = self.encoder(x)
state = problem
# 阶段2:递归思考
for depth in range(self.max_depth):
state = self.thinker(state, problem, depth)
# 可选:早期停止检查
if self.check_confidence(state):
break
# 阶段3:解码
return self.decoder(state)数学分析
递归深度与计算复杂度
设递归深度为 ,每次思考的计算量为 ,则总计算量为:
对于固定的推理质量目标:
| 方法 | 计算量 | 上下文长度 |
|---|---|---|
| 标准自回归 | tokens | |
| Chain-of-Thought | tokens | |
| Encode-Think-Decode | vectors |
其中 是 CoT 序列长度, 是递归深度。
隐空间容量的影响
隐表示维度 限制了单步推理的信息量:
其中 是量化误差。更大的 允许更复杂的隐式推理。
与其他框架的对比
vs Latent Reasoning
| 特性 | Latent Reasoning | Encode-Think-Decode |
|---|---|---|
| 架构 | 单递归块 | 三阶段框架 |
| 问题接入 | 初始嵌入 | 每步交叉注意力 |
| 深度控制 | 固定 | 自适应 |
| 训练稳定性 | 较高 | 更高(分离设计) |
vs Chain-of-Thought
| 特性 | Chain-of-Thought | Encode-Think-Decode |
|---|---|---|
| 推理表示 | 显式 token | 隐式向量 |
| 可解释性 | 高 | 中等 |
| 计算效率 | 线性 | 超线性 |
| 上下文需求 | 增长 | 固定 |
vs MatryoshkaThinking
| 特性 | MatryoshkaThinking | Encode-Think-Decode |
|---|---|---|
| 核心机制 | 嵌套聚合 | 三阶段分离 |
| 递归形式 | 自聚合 | 交叉注意 |
| 停止策略 | 置信度 | 可配置 |
| 适用场景 | 数学推理 | 通用推理 |
实验结果
数学推理基准
| 模型 | GSM8K | MATH | AIME |
|---|---|---|---|
| GPT-4 (standard) | 67.1% | 42.3% | 3.3% |
| GPT-4 + CoT | 85.2% | 68.7% | 9.3% |
| Encode-Think-Decode () | 91.3% | 79.4% | 24.7% |
| Encode-Think-Decode () | 92.8% | 82.1% | 31.2% |
效率对比
| 方法 | 每样本计算量 | 内存占用 | 延迟 |
|---|---|---|---|
| CoT ( steps) | 100% | 100% | 100% |
| Encode-Think-Decode () | 35% | 28% | 42% |
| Encode-Think-Decode () | 52% | 35% | 58% |
实践建议
超参数配置
# 推荐配置
config = {
'd_model': 4096, # 隐表示维度
'n_heads': 32, # 注意力头数
'max_depth': 8, # 最大思考深度
'depth_emb_scale': 0.1, # 深度嵌入缩放
'use_early_stop': True, # 启用早期停止
}训练策略
- 课程学习:从浅层递归逐步增加深度
- 深度dropout:随机丢弃部分递归层
- 问题难度感知:根据问题复杂度调整目标深度
相关工作
- 隐式推理:Latent Reasoning 架构
- MatryoshkaThinking:嵌套递归推理
- 链式推理:显式 token 推理
- 推测推理:并行推理加速
参考文献
Footnotes
-
Anonymous. (2025). Encode, Think, Decode: Scaling test-time reasoning with recursive latent thoughts. arXiv:2510.07358. https://arxiv.org/abs/2510.07358 ↩