1. 引言
大语言推理模型(Large Reasoning Models,LRMs)通过生成扩展的思维链(Chain of Thought,CoT)输出,在数学推理、代码生成等复杂任务上取得了显著进展。1 这类模型,如 OpenAI o1/o3 和 DeepSeek R1,通过在生成最终答案前产生内部思考标记,实现了对复杂问题的深度分析。然而,这种”慢思考”(Slow Thinking)范式带来了一个关键挑战:推理轨迹的长度无法被有效控制。
在实际部署场景中,推理时间的预算(Token 数量、延迟、计算资源)往往受到严格约束。传统方法在面对这些约束时存在明显不足:
- 朴素截断:在固定 Token 数后直接截断输出,可能导致解决方案阶段被意外截断,产生不完整或无效的答案。
- Budget Forcing(S1):通过提示模型在特定 Token 数后输出结束标记,虽然优于朴素截断,但仍未从根本上解决资源分配问题。
- Long2Short 方法:通过强化学习减少推理长度,但往往需要大量训练资源且性能会有明显下降。
Elastic Reasoning(弹性推理)应运而生,它提供了一种原则性且实用的解决方案,使大语言推理模型能够在严格约束下生成准确、高效的思维链输出。
2. 问题定义
2.1 推理输出的结构
给定输入提示 ,LRM 生成输出序列 ,其中:
- :包含中间推理步骤(包裹在
<think>和</think>标记之间),这部分通常占据超过 90% 的 Token 数量。 - :包含最终解决方案,是对推理过程的总结与答案输出。
整体生成结构可表示为:
2.2 资源约束下的生成问题
设 为总生成预算(Token 数量上限),传统方法要求:
其中 表示生成的 Token 总数。这种约束面临的核心问题是:如何确保解决方案阶段不被截断,同时保持推理质量?
3. Elastic Reasoning 框架
Elastic Reasoning 的核心思想是显式分离推理过程,将思维链明确划分为两个具有独立预算的阶段。
3.1 分离预算机制(Separate Budgeting)
分离预算机制将总预算 显式划分为两个独立部分:
其中:
- :思考阶段(Thinking Phase)的预算
- :解决方案阶段(Solution Phase)的预算
推理过程:
- 模型在
<think>块内开始生成推理内容 - 情况一:模型在达到预算 之前自然输出
</think>,则立即进入解决方案阶段 - 情况二:预算 耗尽但模型仍未输出
</think>,则强制插入</think>标记终止思考阶段 - 模型在解决方案阶段继续生成,最多使用 个 Token
这种方法确保了解决方案阶段始终拥有保底的 Token 配额,避免了被意外截断的风险。
3.2 关键洞察
一个关键观察是:即使思考阶段被强制终止,模型仍然具备生成连贯且正确解决方案的能力。这意味着我们可以充分利用有限资源,在不完整的推理轨迹下仍能产出有效答案。
实验结果证明,分离预算机制在各种生成预算下都显著优于朴素截断和 S1 方法。
4. 预算约束 rollout 训练策略
4.1 训练目标
虽然分离预算机制确保了解决方案阶段的完整性,但在复杂任务(尤其是代码生成)上,不完整的思考过程仍可能导致显著的性能下降。为解决这一问题,Elastic Reasoning 引入了预算约束 rollout(Budget-Constrained Rollout)训练策略。
该策略是一种强化学习微调程序,通过在推理预算约束下训练模型,使其能够在有限的思考预算内产生更有效、更简洁的推理。
4.2 基于 GRPO 的优化
Elastic Reasoning 采用 GRPO(Group Relative Policy Optimization) 作为强化学习算法。
设 为由参数 化的语言模型策略,生成响应 ,满足总预算约束 。
Budget-Constrained Rollout 过程:
- 策略 rollout 推理段 ,最多使用 个 Token
- 若模型在达到 前自然输出
</think>,则正常进入解决方案生成 - 否则,在达到 后强制插入
</think> - 模型使用剩余 个 Token 生成解决方案段
设 为任务特定的奖励函数,训练目标为最大化期望奖励:
4.3 GRPO 梯度估计
使用 GRPO 进行优化的梯度估计器为:
其中优势函数 定义为:
4.4 训练效率
在训练设置中,预算对固定为 ,即各 1024 个 Token。这种设置兼顾了简单性和效率。
令人惊讶的发现是,经过训练的策略能够泛化到广泛的未见预算配置,无需额外微调。这意味着 Elastic Reasoning 促使模型内化了一种灵活的推理策略,能够适应不同的资源约束。
训练效率对比:
| 方法 | 训练步数 | 最大响应长度 |
|---|---|---|
| Elastic Reasoning (E1) | 200 | 2K |
| L1-Exact | 700 | 4K |
| L1-Max | 820 | 4K |
E1 方法在仅需约 1/4 训练步数的情况下,即可达到与基线方法相当甚至更好的性能。
5. 数学与代码推理实验
5.1 实验设置
基础模型:
- E1-Math-1.5B:基于 DeepScaleR-1.5B-Preview 微调(来自 DeepSeekR1-Distill-Qwen-1.5B)
- E1-Code-14B:基于 DeepCoder-14B-Preview 微调(来自 DeepSeekR1-Distill-Qwen-14B)
数学领域训练数据:
- AIME (1984-2023)
- AMC
- Omni-Math
- STILL
代码领域训练数据:
- TACO
- SYNTHETIC-1
- LiveCodeBench (2023/05/01 - 2024/07/31)
评估基准:
| 领域 | 基准 |
|---|---|
| 数学推理 | AIME 2024, MATH500, AMC, Olympiad-Bench, Minerva Math |
| 代码生成 | LiveCodeBench (2024/08/01 - 2025/02/01), Codeforces, HumanEval+ |
5.2 数学推理结果
在 AIME2024 验证集上的 Pass@1 准确率和奖励曲线显示,E1-Math-1.5B 在训练过程中快速收敛:奖励稳步增加,约 150 步后开始稳定;验证准确率从约 7% 提升至 20%。
核心发现:
- MATH500:E1-Math-1.5B 以 1619 Token/问题的使用量达到 83.6% Pass@1 准确率,而 L1-Exact 需要 1959 Token 才能达到 79.9%
- 无约束性能:E1-Math-1.5B 在无推理预算约束时,在所有基准上均超越基线方法
- 性能保持:在 AIME2024 上,E1-Math-1.5B 相比原始模型性能下降仅 6.0%,而 L1-Max 下降 12.9%、L1-Exact 下降 16.8%
- Token 节省:相比 DeepScaleR-1.5B,E1-Math-1.5B 在各数据集上平均节省超过 30% 的 Token 使用量
5.3 代码推理结果
代码推理任务展示了更明显的效果:
- 原始 DeepCoder-14B-Preview:当推理预算低于 4K 时,在不完整思考下准确率急剧下降,始终低于 10%
- E1-Code-14B:展现出优异的可扩展性,性能随推理预算增加稳步提升
这充分证明了预算约束 rollout 训练策略在使模型自适应受限思考方面的有效性。
值得注意的是,E1-Code-14B 在 LiveCodeBench 上实现了 0.3% 的性能提升,即使在无约束设置下也是如此。
6. PyTorch 实现
以下是基于 GRPO 的 Elastic Reasoning 训练的核心 PyTorch 实现:
import torch
import torch.nn as nn
from torch.distributions import Categorical
from typing import Dict, List, Tuple, Optional
import math
class ElasticReasoningGRPO:
"""
Elastic Reasoning 训练器:基于 GRPO 的预算约束 rollout 优化
核心思想:将推理过程分离为思考阶段与解决方案阶段,
通过预算约束 rollout 训练模型在受限思考下生成高质量答案。
"""
def __init__(
self,
model: nn.Module,
thinking_budget: int = 1024,
solution_budget: int = 1024,
think_token_id: int = -1,
end_think_token_id: int = -1,
learning_rate: float = 1e-6,
adv_temperature: float = 0.5,
):
self.model = model
self.thinking_budget = thinking_budget
self.solution_budget = solution_budget
self.think_token_id = think_token_id
self.end_think_token_id = end_think_token_id
self.learning_rate = learning_rate
self.adv_temperature = adv_temperature
# 使用 AdamW 优化器
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=learning_rate,
weight_decay=0.01
)
def budget_constrained_rollout(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Tuple[List[int], str]:
"""
预算约束 rollout:模拟分离预算推理过程
Args:
input_ids: 输入 token IDs [batch_size, seq_len]
attention_mask: 注意力掩码
Returns:
generated_ids: 生成的 token IDs 列表
phase_info: 阶段信息(用于调试)
"""
self.model.eval()
device = input_ids.device
generated = input_ids[0].tolist()
think_tokens = 0
solution_tokens = 0
in_thinking = True
phase_info = {"think_tokens": 0, "solution_tokens": 0}
# 特殊标记
think_start = self.think_token_id # <think> 的 token ID
think_end = self.end_think_token_id # </think> 的 token ID
with torch.no_grad():
max_length = self.thinking_budget + self.solution_budget + len(input_ids[0])
while len(generated) < max_length:
# 准备输入
input_tensor = torch.tensor([generated], device=device)
# 前向传播
if attention_mask is not None:
# 扩展 attention mask
cur_len = len(generated)
pad_len = cur_len - attention_mask.shape[1]
extended_mask = torch.cat([
attention_mask,
torch.ones(1, pad_len, device=device, dtype=torch.long)
], dim=1) if pad_len > 0 else attention_mask
outputs = self.model(input_tensor, attention_mask=extended_mask)
else:
outputs = self.model(input_tensor)
# 获取 logits 并采样
logits = outputs.logits[0, -1, :] / self.adv_temperature
probs = torch.softmax(logits, dim=-1)
dist = Categorical(probs)
next_token = dist.sample()
# 强制插入 end_think 标记(如果预算耗尽且仍在思考)
if in_thinking:
think_tokens += 1
if think_tokens >= self.thinking_budget:
# 预算耗尽,强制结束思考阶段
next_token = torch.tensor([think_end], device=device)
in_thinking = False
phase_info["think_tokens"] = think_tokens
else:
solution_tokens += 1
if solution_tokens >= self.solution_budget:
# 解决方案预算耗尽
break
generated.append(next_token.item())
# 检测思考阶段结束
if in_thinking and next_token.item() == think_end:
in_thinking = False
phase_info["think_tokens"] = think_tokens
# 检测序列结束
if next_token.item() == self.model.config.eos_token_id:
break
phase_info["solution_tokens"] = solution_tokens
return generated, phase_info
def compute_rewards(
self,
responses: List[List[int]],
answers: List[str],
is_correct: List[bool]
) -> torch.Tensor:
"""
计算奖励函数
Args:
responses: 响应序列列表
answers: 标准答案
is_correct: 正确性标记
Returns:
rewards: 奖励张量
"""
rewards = []
for correct in is_correct:
# 正确 = 1.0,错误 = 0.0
rewards.append(1.0 if correct else 0.0)
return torch.tensor(rewards, dtype=torch.float32)
def grpo_update(
self,
prompts: List[str],
responses: List[List[int]],
old_log_probs: torch.Tensor,
advantages: torch.Tensor
) -> Dict[str, float]:
"""
GRPO 更新步骤
Args:
prompts: 输入提示
responses: 响应序列
old_log_probs: 旧策略的对数概率
advantages: 优势函数
Returns:
stats: 训练统计信息
"""
self.model.train()
# 重新计算对数概率
batch_size = len(responses)
log_probs_sum = torch.zeros(batch_size, device=old_log_probs.device)
for i, response in enumerate(responses):
if len(response) == 0:
continue
input_ids = torch.tensor([response[:-1]], device=old_log_probs.device)
with torch.set_grad_enabled(True):
outputs = self.model(input_ids)
logits = outputs.logits[0] # [seq_len, vocab_size]
log_probs = torch.log_softmax(logits, dim=-1)
# 计算响应部分的对数概率
response_log_probs = log_probs[torch.arange(len(response)-1, device=log_probs.device),
torch.tensor(response[1:], device=log_probs.device)]
log_probs_sum[i] = response_log_probs.sum()
# 计算策略梯度
ratio = torch.exp(log_probs_sum - old_log_probs)
clipped_ratio = torch.clamp(ratio, 1 - 0.2, 1 + 0.2)
policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
# 添加 KL 散度正则化(可选)
kl_loss = 0.01 * (log_probs_sum - old_log_probs).pow(2).mean()
total_loss = policy_loss + kl_loss
# 反向传播
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
return {
"total_loss": total_loss.item(),
"policy_loss": policy_loss.item(),
"kl_loss": kl_loss.item(),
"mean_ratio": ratio.mean().item()
}
def train_step(
self,
batch: Dict[str, any],
reward_fn: callable
) -> Dict[str, float]:
"""
单步训练
Args:
batch: 包含 prompts 和 answers 的批次
reward_fn: 奖励函数
Returns:
stats: 训练统计
"""
prompts = batch["prompts"]
answers = batch["answers"]
input_ids_list = batch["input_ids"]
# Group 内的采样数量
group_size = 4
all_responses = []
all_log_probs = []
all_rewards = []
for prompt, answer, input_ids in zip(prompts, answers, input_ids_list):
# 使用当前策略进行 rollout
group_responses = []
group_log_probs = []
for _ in range(group_size):
response, phase_info = self.budget_constrained_rollout(input_ids)
group_responses.append(response)
# 计算奖励
is_correct = reward_fn(group_responses, answer)
rewards = self.compute_rewards(group_responses, answer, is_correct)
all_responses.extend(group_responses)
all_rewards.extend(rewards.tolist())
# 转换为张量
rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32)
# 计算 GRPO 优势函数(组内相对优势)
advantages = []
for i in range(0, len(rewards_tensor), group_size):
group_rewards = rewards_tensor[i:i+group_size]
group_mean = group_rewards.mean()
group_std = group_rewards.std() + 1e-8
group_adv = (group_rewards - group_mean) / group_std
advantages.extend(group_adv.tolist())
advantages_tensor = torch.tensor(advantages, dtype=torch.float32)
# 旧策略的对数概率(简化版本)
old_log_probs = torch.zeros(len(all_responses), device=rewards_tensor.device)
# 执行 GRPO 更新
stats = self.grpo_update(prompts, all_responses, old_log_probs, advantages_tensor)
stats["mean_reward"] = rewards_tensor.mean().item()
stats["correct_rate"] = (rewards_tensor > 0).float().mean().item()
return stats
class SeparateBudgeting:
"""
分离预算推理:用于推理阶段的预算分配
这是 Elastic Reasoning 的推理时组件,
可以在任意训练好的推理模型上使用。
"""
def __init__(
self,
model: nn.Module,
thinking_budget: int = 1024,
solution_budget: int = 1024,
think_start_token: str = "<think>",
think_end_token: str = "</think>"
):
self.model = model
self.thinking_budget = thinking_budget
self.solution_budget = solution_budget
self.think_start_token = think_start_token
self.think_end_token = think_end_token
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
temperature: float = 0.7,
top_p: float = 0.9
) -> Dict[str, any]:
"""
使用分离预算策略生成响应
Args:
input_ids: 输入 token IDs
temperature: 采样温度
top_p: Nucleus 采样阈值
Returns:
result: 包含完整响应和阶段信息的字典
"""
device = input_ids.device
generated = input_ids[0].tolist()
# 获取特殊 token IDs
eos_token = self.model.config.eos_token_id
# 标记当前阶段
in_thinking = True
think_tokens = 0
solution_tokens = 0
max_length = self.thinking_budget + self.solution_budget + len(input_ids[0])
while len(generated) < max_length:
# 准备输入
input_tensor = torch.tensor([generated], device=device)
# 前向传播
outputs = self.model(input_tensor)
logits = outputs.logits[0, -1, :] / temperature
# Top-p 采样
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# 保留概率和超过 top_p 的 token
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
0, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = float('-inf')
# 采样
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).item()
# 强制终止逻辑
if in_thinking:
think_tokens += 1
if think_tokens >= self.thinking_budget:
# 思考预算耗尽,强制插入 end_think 标记
# 简化处理:直接跳转到解决方案阶段
in_thinking = False
if not in_thinking:
solution_tokens += 1
if solution_tokens >= self.solution_budget:
break
generated.append(next_token)
# 检测思考结束
# 注意:实际实现中需要根据具体 tokenization 调整
if next_token == eos_token:
break
return {
"response": generated,
"think_tokens": think_tokens,
"solution_tokens": solution_tokens,
"total_tokens": len(generated) - len(input_ids[0])
}6.1 使用示例
# 初始化模型和训练器
model = load_pretrained_reasoning_model("DeepScaleR-1.5B")
trainer = ElasticReasoningGRPO(
model=model,
thinking_budget=1024, # 1K tokens for thinking
solution_budget=1024, # 1K tokens for solution
)
# 自定义奖励函数
def math_reward_fn(responses, answer):
results = []
for resp in responses:
# 提取答案并与标准答案比较
extracted = extract_answer(resp)
results.append(check_answer(extracted, answer))
return results
# 训练循环
for step in range(200):
batch = sample_batch(dataset, batch_size=8)
stats = trainer.train_step(batch, math_reward_fn)
if step % 10 == 0:
print(f"Step {step}: reward={stats['mean_reward']:.3f}, "
f"correct_rate={stats['correct_rate']:.3f}")
# 推理时使用分离预算
inferencer = SeparateBudgeting(
model=model,
thinking_budget=2048, # 可根据需求调整
solution_budget=1024
)
result = inferencer.generate(input_ids)
print(f"Thinking tokens: {result['think_tokens']}")
print(f"Solution tokens: {result['solution_tokens']}")7. 对数-线性缩放定律
实验发现,Elastic Reasoning 展现出清晰的对数-线性缩放模式:性能随生成推理 Token 数量的对数近似线性提升。
这一发现与 L1、S1、O1 等方法的观察一致,表明:
其中 和 为任务相关的常数。这意味着我们可以通过控制推理预算来预测性地调节模型性能。
8. 与相关工作的对比
| 方法 | 核心思想 | 优势 | 局限 |
|---|---|---|---|
| 朴素截断 | 直接在固定 Token 处截断 | 简单 | 解决方案易被截断 |
| S1 (Budget Forcing) | 强制在特定 Token 输出结束标记 | 保留解决方案 | 性能下降明显 |
| L1 | 强化学习 + 长度控制 | 性能较好 | 训练成本高 |
| Elastic Reasoning | 分离预算 + 预算约束 rollout | 高效、泛化强 | 需要特殊训练 |
Elastic Reasoning 的独特优势在于:
- 双重保障:分离预算确保解决方案不被截断,训练策略确保推理质量
- 高效训练:仅需 200 步即可收敛
- 良好泛化:训练于单一预算配置,可泛化至任意未见预算
- 无约束提升:即使在无预算约束下也能提升或保持性能
9. 实际应用场景
Elastic Reasoning 适用于多种实际部署场景:
9.1 延迟敏感的在线服务
在需要快速响应的在线应用中(如实时问答、对话系统),可以通过限制思考预算来控制响应延迟:
# 实时对话场景:限制总生成时间为 100ms
thinking_budget = 512
solution_budget = 256
inferencer = SeparateBudgeting(model, thinking_budget, solution_budget)9.2 成本受限的企业应用
在 Token 成本受限的场景下,Elastic Reasoning 可显著降低推理成本:
- 30%+ Token 节省:在保持或提升性能的同时大幅减少 Token 使用
- 可预测成本:通过固定预算实现精确的成本控制
9.3 边缘设备部署
在算力受限的边缘设备上,分离预算机制使得复杂推理任务成为可能:
# 边缘设备场景:严格限制内存和计算
thinking_budget = 256
solution_budget = 128
model = quantize_for_edge(model)
inferencer = SeparateBudgeting(model, thinking_budget, solution_budget)10. 总结与展望
Elastic Reasoning 为大语言推理模型提供了一种原则性且实用的可控推理框架。其核心贡献包括:
- 分离预算机制:显式分离思考阶段与解决方案阶段,确保解决方案完整性
- 预算约束 rollout:通过 GRPO 训练,使模型能够自适应受限思考
- 卓越的效率:仅需 200 训练步即可达到 SOTA 性能
- 良好的泛化性:单一预算训练,泛化至任意未见预算配置
- 无约束性能提升:即使在无预算约束下也能产生更简洁、高效的推理
未来方向
- 自适应预算分配:根据问题难度动态调整思考/解决方案预算比例
- 多阶段推理:扩展至多阶段分解的复杂推理任务
- 与其他推理增强技术结合:如 思维链提示、验证器驱动的自改进
Elastic Reasoning 为可扩展的推理系统开辟了新方向,是实现高效、可靠、受控推理的重要里程碑。
参考资料
Footnotes
-
Xu, Y., Dong, H., Wang, L., Sahoo, D., Li, J., & Xiong, C. (2025). Scalable Chain of Thoughts via Elastic Reasoning. arXiv:2505.05315. https://arxiv.org/abs/2505.05315 ↩