1. 引言

大语言推理模型(Large Reasoning Models,LRMs)通过生成扩展的思维链(Chain of Thought,CoT)输出,在数学推理、代码生成等复杂任务上取得了显著进展。1 这类模型,如 OpenAI o1/o3 和 DeepSeek R1,通过在生成最终答案前产生内部思考标记,实现了对复杂问题的深度分析。然而,这种”慢思考”(Slow Thinking)范式带来了一个关键挑战:推理轨迹的长度无法被有效控制

在实际部署场景中,推理时间的预算(Token 数量、延迟、计算资源)往往受到严格约束。传统方法在面对这些约束时存在明显不足:

  • 朴素截断:在固定 Token 数后直接截断输出,可能导致解决方案阶段被意外截断,产生不完整或无效的答案。
  • Budget Forcing(S1):通过提示模型在特定 Token 数后输出结束标记,虽然优于朴素截断,但仍未从根本上解决资源分配问题。
  • Long2Short 方法:通过强化学习减少推理长度,但往往需要大量训练资源且性能会有明显下降。

Elastic Reasoning(弹性推理)应运而生,它提供了一种原则性且实用的解决方案,使大语言推理模型能够在严格约束下生成准确、高效的思维链输出。

2. 问题定义

2.1 推理输出的结构

给定输入提示 ,LRM 生成输出序列 ,其中:

  • :包含中间推理步骤(包裹在 <think></think> 标记之间),这部分通常占据超过 90% 的 Token 数量。
  • :包含最终解决方案,是对推理过程的总结与答案输出。

整体生成结构可表示为:

2.2 资源约束下的生成问题

为总生成预算(Token 数量上限),传统方法要求:

其中 表示生成的 Token 总数。这种约束面临的核心问题是:如何确保解决方案阶段不被截断,同时保持推理质量?

3. Elastic Reasoning 框架

Elastic Reasoning 的核心思想是显式分离推理过程,将思维链明确划分为两个具有独立预算的阶段。

3.1 分离预算机制(Separate Budgeting)

分离预算机制将总预算 显式划分为两个独立部分:

其中:

  • :思考阶段(Thinking Phase)的预算
  • :解决方案阶段(Solution Phase)的预算

推理过程

  1. 模型在 <think> 块内开始生成推理内容
  2. 情况一:模型在达到预算 之前自然输出 </think>,则立即进入解决方案阶段
  3. 情况二:预算 耗尽但模型仍未输出 </think>,则强制插入 </think> 标记终止思考阶段
  4. 模型在解决方案阶段继续生成,最多使用 个 Token

这种方法确保了解决方案阶段始终拥有保底的 Token 配额,避免了被意外截断的风险。

3.2 关键洞察

一个关键观察是:即使思考阶段被强制终止,模型仍然具备生成连贯且正确解决方案的能力。这意味着我们可以充分利用有限资源,在不完整的推理轨迹下仍能产出有效答案。

实验结果证明,分离预算机制在各种生成预算下都显著优于朴素截断和 S1 方法。

4. 预算约束 rollout 训练策略

4.1 训练目标

虽然分离预算机制确保了解决方案阶段的完整性,但在复杂任务(尤其是代码生成)上,不完整的思考过程仍可能导致显著的性能下降。为解决这一问题,Elastic Reasoning 引入了预算约束 rollout(Budget-Constrained Rollout)训练策略。

该策略是一种强化学习微调程序,通过在推理预算约束下训练模型,使其能够在有限的思考预算内产生更有效、更简洁的推理。

4.2 基于 GRPO 的优化

Elastic Reasoning 采用 GRPO(Group Relative Policy Optimization) 作为强化学习算法。

为由参数 化的语言模型策略,生成响应 ,满足总预算约束

Budget-Constrained Rollout 过程

  1. 策略 rollout 推理段 ,最多使用 个 Token
  2. 若模型在达到 前自然输出 </think>,则正常进入解决方案生成
  3. 否则,在达到 后强制插入 </think>
  4. 模型使用剩余 个 Token 生成解决方案段

为任务特定的奖励函数,训练目标为最大化期望奖励:

4.3 GRPO 梯度估计

使用 GRPO 进行优化的梯度估计器为:

其中优势函数 定义为:

4.4 训练效率

在训练设置中,预算对固定为 ,即各 1024 个 Token。这种设置兼顾了简单性和效率。

令人惊讶的发现是,经过训练的策略能够泛化到广泛的未见预算配置,无需额外微调。这意味着 Elastic Reasoning 促使模型内化了一种灵活的推理策略,能够适应不同的资源约束。

训练效率对比

方法训练步数最大响应长度
Elastic Reasoning (E1)2002K
L1-Exact7004K
L1-Max8204K

E1 方法在仅需约 1/4 训练步数的情况下,即可达到与基线方法相当甚至更好的性能。

5. 数学与代码推理实验

5.1 实验设置

基础模型

  • E1-Math-1.5B:基于 DeepScaleR-1.5B-Preview 微调(来自 DeepSeekR1-Distill-Qwen-1.5B)
  • E1-Code-14B:基于 DeepCoder-14B-Preview 微调(来自 DeepSeekR1-Distill-Qwen-14B)

数学领域训练数据

  • AIME (1984-2023)
  • AMC
  • Omni-Math
  • STILL

代码领域训练数据

  • TACO
  • SYNTHETIC-1
  • LiveCodeBench (2023/05/01 - 2024/07/31)

评估基准

领域基准
数学推理AIME 2024, MATH500, AMC, Olympiad-Bench, Minerva Math
代码生成LiveCodeBench (2024/08/01 - 2025/02/01), Codeforces, HumanEval+

5.2 数学推理结果

在 AIME2024 验证集上的 Pass@1 准确率和奖励曲线显示,E1-Math-1.5B 在训练过程中快速收敛:奖励稳步增加,约 150 步后开始稳定;验证准确率从约 7% 提升至 20%。

核心发现

  • MATH500:E1-Math-1.5B 以 1619 Token/问题的使用量达到 83.6% Pass@1 准确率,而 L1-Exact 需要 1959 Token 才能达到 79.9%
  • 无约束性能:E1-Math-1.5B 在无推理预算约束时,在所有基准上均超越基线方法
  • 性能保持:在 AIME2024 上,E1-Math-1.5B 相比原始模型性能下降仅 6.0%,而 L1-Max 下降 12.9%、L1-Exact 下降 16.8%
  • Token 节省:相比 DeepScaleR-1.5B,E1-Math-1.5B 在各数据集上平均节省超过 30% 的 Token 使用量

5.3 代码推理结果

代码推理任务展示了更明显的效果:

  • 原始 DeepCoder-14B-Preview:当推理预算低于 4K 时,在不完整思考下准确率急剧下降,始终低于 10%
  • E1-Code-14B:展现出优异的可扩展性,性能随推理预算增加稳步提升

这充分证明了预算约束 rollout 训练策略在使模型自适应受限思考方面的有效性。

值得注意的是,E1-Code-14B 在 LiveCodeBench 上实现了 0.3% 的性能提升,即使在无约束设置下也是如此。

6. PyTorch 实现

以下是基于 GRPO 的 Elastic Reasoning 训练的核心 PyTorch 实现:

import torch
import torch.nn as nn
from torch.distributions import Categorical
from typing import Dict, List, Tuple, Optional
import math
 
class ElasticReasoningGRPO:
    """
    Elastic Reasoning 训练器:基于 GRPO 的预算约束 rollout 优化
    
    核心思想:将推理过程分离为思考阶段与解决方案阶段,
    通过预算约束 rollout 训练模型在受限思考下生成高质量答案。
    """
    
    def __init__(
        self,
        model: nn.Module,
        thinking_budget: int = 1024,
        solution_budget: int = 1024,
        think_token_id: int = -1,
        end_think_token_id: int = -1,
        learning_rate: float = 1e-6,
        adv_temperature: float = 0.5,
    ):
        self.model = model
        self.thinking_budget = thinking_budget
        self.solution_budget = solution_budget
        self.think_token_id = think_token_id
        self.end_think_token_id = end_think_token_id
        self.learning_rate = learning_rate
        self.adv_temperature = adv_temperature
        
        # 使用 AdamW 优化器
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )
        
    def budget_constrained_rollout(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[List[int], str]:
        """
        预算约束 rollout:模拟分离预算推理过程
        
        Args:
            input_ids: 输入 token IDs [batch_size, seq_len]
            attention_mask: 注意力掩码
            
        Returns:
            generated_ids: 生成的 token IDs 列表
            phase_info: 阶段信息(用于调试)
        """
        self.model.eval()
        device = input_ids.device
        generated = input_ids[0].tolist()
        think_tokens = 0
        solution_tokens = 0
        in_thinking = True
        phase_info = {"think_tokens": 0, "solution_tokens": 0}
        
        # 特殊标记
        think_start = self.think_token_id  # <think> 的 token ID
        think_end = self.end_think_token_id  # </think> 的 token ID
        
        with torch.no_grad():
            max_length = self.thinking_budget + self.solution_budget + len(input_ids[0])
            
            while len(generated) < max_length:
                # 准备输入
                input_tensor = torch.tensor([generated], device=device)
                
                # 前向传播
                if attention_mask is not None:
                    # 扩展 attention mask
                    cur_len = len(generated)
                    pad_len = cur_len - attention_mask.shape[1]
                    extended_mask = torch.cat([
                        attention_mask,
                        torch.ones(1, pad_len, device=device, dtype=torch.long)
                    ], dim=1) if pad_len > 0 else attention_mask
                    outputs = self.model(input_tensor, attention_mask=extended_mask)
                else:
                    outputs = self.model(input_tensor)
                
                # 获取 logits 并采样
                logits = outputs.logits[0, -1, :] / self.adv_temperature
                probs = torch.softmax(logits, dim=-1)
                dist = Categorical(probs)
                next_token = dist.sample()
                
                # 强制插入 end_think 标记(如果预算耗尽且仍在思考)
                if in_thinking:
                    think_tokens += 1
                    if think_tokens >= self.thinking_budget:
                        # 预算耗尽,强制结束思考阶段
                        next_token = torch.tensor([think_end], device=device)
                        in_thinking = False
                        phase_info["think_tokens"] = think_tokens
                else:
                    solution_tokens += 1
                    if solution_tokens >= self.solution_budget:
                        # 解决方案预算耗尽
                        break
                
                generated.append(next_token.item())
                
                # 检测思考阶段结束
                if in_thinking and next_token.item() == think_end:
                    in_thinking = False
                    phase_info["think_tokens"] = think_tokens
                
                # 检测序列结束
                if next_token.item() == self.model.config.eos_token_id:
                    break
                    
        phase_info["solution_tokens"] = solution_tokens
        return generated, phase_info
    
    def compute_rewards(
        self,
        responses: List[List[int]],
        answers: List[str],
        is_correct: List[bool]
    ) -> torch.Tensor:
        """
        计算奖励函数
        
        Args:
            responses: 响应序列列表
            answers: 标准答案
            is_correct: 正确性标记
            
        Returns:
            rewards: 奖励张量
        """
        rewards = []
        for correct in is_correct:
            # 正确 = 1.0,错误 = 0.0
            rewards.append(1.0 if correct else 0.0)
        return torch.tensor(rewards, dtype=torch.float32)
    
    def grpo_update(
        self,
        prompts: List[str],
        responses: List[List[int]],
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor
    ) -> Dict[str, float]:
        """
        GRPO 更新步骤
        
        Args:
            prompts: 输入提示
            responses: 响应序列
            old_log_probs: 旧策略的对数概率
            advantages: 优势函数
            
        Returns:
            stats: 训练统计信息
        """
        self.model.train()
        
        # 重新计算对数概率
        batch_size = len(responses)
        log_probs_sum = torch.zeros(batch_size, device=old_log_probs.device)
        
        for i, response in enumerate(responses):
            if len(response) == 0:
                continue
            input_ids = torch.tensor([response[:-1]], device=old_log_probs.device)
            
            with torch.set_grad_enabled(True):
                outputs = self.model(input_ids)
                logits = outputs.logits[0]  # [seq_len, vocab_size]
                log_probs = torch.log_softmax(logits, dim=-1)
                
                # 计算响应部分的对数概率
                response_log_probs = log_probs[torch.arange(len(response)-1, device=log_probs.device), 
                                               torch.tensor(response[1:], device=log_probs.device)]
                log_probs_sum[i] = response_log_probs.sum()
        
        # 计算策略梯度
        ratio = torch.exp(log_probs_sum - old_log_probs)
        clipped_ratio = torch.clamp(ratio, 1 - 0.2, 1 + 0.2)
        policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
        
        # 添加 KL 散度正则化(可选)
        kl_loss = 0.01 * (log_probs_sum - old_log_probs).pow(2).mean()
        
        total_loss = policy_loss + kl_loss
        
        # 反向传播
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        
        return {
            "total_loss": total_loss.item(),
            "policy_loss": policy_loss.item(),
            "kl_loss": kl_loss.item(),
            "mean_ratio": ratio.mean().item()
        }
    
    def train_step(
        self,
        batch: Dict[str, any],
        reward_fn: callable
    ) -> Dict[str, float]:
        """
        单步训练
        
        Args:
            batch: 包含 prompts 和 answers 的批次
            reward_fn: 奖励函数
            
        Returns:
            stats: 训练统计
        """
        prompts = batch["prompts"]
        answers = batch["answers"]
        input_ids_list = batch["input_ids"]
        
        # Group 内的采样数量
        group_size = 4
        all_responses = []
        all_log_probs = []
        all_rewards = []
        
        for prompt, answer, input_ids in zip(prompts, answers, input_ids_list):
            # 使用当前策略进行 rollout
            group_responses = []
            group_log_probs = []
            
            for _ in range(group_size):
                response, phase_info = self.budget_constrained_rollout(input_ids)
                group_responses.append(response)
            
            # 计算奖励
            is_correct = reward_fn(group_responses, answer)
            rewards = self.compute_rewards(group_responses, answer, is_correct)
            
            all_responses.extend(group_responses)
            all_rewards.extend(rewards.tolist())
        
        # 转换为张量
        rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32)
        
        # 计算 GRPO 优势函数(组内相对优势)
        advantages = []
        for i in range(0, len(rewards_tensor), group_size):
            group_rewards = rewards_tensor[i:i+group_size]
            group_mean = group_rewards.mean()
            group_std = group_rewards.std() + 1e-8
            group_adv = (group_rewards - group_mean) / group_std
            advantages.extend(group_adv.tolist())
        
        advantages_tensor = torch.tensor(advantages, dtype=torch.float32)
        
        # 旧策略的对数概率(简化版本)
        old_log_probs = torch.zeros(len(all_responses), device=rewards_tensor.device)
        
        # 执行 GRPO 更新
        stats = self.grpo_update(prompts, all_responses, old_log_probs, advantages_tensor)
        stats["mean_reward"] = rewards_tensor.mean().item()
        stats["correct_rate"] = (rewards_tensor > 0).float().mean().item()
        
        return stats
 
 
class SeparateBudgeting:
    """
    分离预算推理:用于推理阶段的预算分配
    
    这是 Elastic Reasoning 的推理时组件,
    可以在任意训练好的推理模型上使用。
    """
    
    def __init__(
        self,
        model: nn.Module,
        thinking_budget: int = 1024,
        solution_budget: int = 1024,
        think_start_token: str = "<think>",
        think_end_token: str = "</think>"
    ):
        self.model = model
        self.thinking_budget = thinking_budget
        self.solution_budget = solution_budget
        self.think_start_token = think_start_token
        self.think_end_token = think_end_token
        
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> Dict[str, any]:
        """
        使用分离预算策略生成响应
        
        Args:
            input_ids: 输入 token IDs
            temperature: 采样温度
            top_p: Nucleus 采样阈值
            
        Returns:
            result: 包含完整响应和阶段信息的字典
        """
        device = input_ids.device
        generated = input_ids[0].tolist()
        
        # 获取特殊 token IDs
        eos_token = self.model.config.eos_token_id
        
        # 标记当前阶段
        in_thinking = True
        think_tokens = 0
        solution_tokens = 0
        
        max_length = self.thinking_budget + self.solution_budget + len(input_ids[0])
        
        while len(generated) < max_length:
            # 准备输入
            input_tensor = torch.tensor([generated], device=device)
            
            # 前向传播
            outputs = self.model(input_tensor)
            logits = outputs.logits[0, -1, :] / temperature
            
            # Top-p 采样
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            
            # 保留概率和超过 top_p 的 token
            sorted_indices_to_remove = cum_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            
            indices_to_remove = sorted_indices_to_remove.scatter(
                0, sorted_indices, sorted_indices_to_remove
            )
            logits[indices_to_remove] = float('-inf')
            
            # 采样
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            
            # 强制终止逻辑
            if in_thinking:
                think_tokens += 1
                
                if think_tokens >= self.thinking_budget:
                    # 思考预算耗尽,强制插入 end_think 标记
                    # 简化处理:直接跳转到解决方案阶段
                    in_thinking = False
            
            if not in_thinking:
                solution_tokens += 1
                
                if solution_tokens >= self.solution_budget:
                    break
            
            generated.append(next_token)
            
            # 检测思考结束
            # 注意:实际实现中需要根据具体 tokenization 调整
            if next_token == eos_token:
                break
        
        return {
            "response": generated,
            "think_tokens": think_tokens,
            "solution_tokens": solution_tokens,
            "total_tokens": len(generated) - len(input_ids[0])
        }

6.1 使用示例

# 初始化模型和训练器
model = load_pretrained_reasoning_model("DeepScaleR-1.5B")
trainer = ElasticReasoningGRPO(
    model=model,
    thinking_budget=1024,  # 1K tokens for thinking
    solution_budget=1024,  # 1K tokens for solution
)
 
# 自定义奖励函数
def math_reward_fn(responses, answer):
    results = []
    for resp in responses:
        # 提取答案并与标准答案比较
        extracted = extract_answer(resp)
        results.append(check_answer(extracted, answer))
    return results
 
# 训练循环
for step in range(200):
    batch = sample_batch(dataset, batch_size=8)
    stats = trainer.train_step(batch, math_reward_fn)
    
    if step % 10 == 0:
        print(f"Step {step}: reward={stats['mean_reward']:.3f}, "
              f"correct_rate={stats['correct_rate']:.3f}")
 
# 推理时使用分离预算
inferencer = SeparateBudgeting(
    model=model,
    thinking_budget=2048,  # 可根据需求调整
    solution_budget=1024
)
 
result = inferencer.generate(input_ids)
print(f"Thinking tokens: {result['think_tokens']}")
print(f"Solution tokens: {result['solution_tokens']}")

7. 对数-线性缩放定律

实验发现,Elastic Reasoning 展现出清晰的对数-线性缩放模式:性能随生成推理 Token 数量的对数近似线性提升。

这一发现与 L1、S1、O1 等方法的观察一致,表明:

其中 为任务相关的常数。这意味着我们可以通过控制推理预算来预测性地调节模型性能。

8. 与相关工作的对比

方法核心思想优势局限
朴素截断直接在固定 Token 处截断简单解决方案易被截断
S1 (Budget Forcing)强制在特定 Token 输出结束标记保留解决方案性能下降明显
L1强化学习 + 长度控制性能较好训练成本高
Elastic Reasoning分离预算 + 预算约束 rollout高效、泛化强需要特殊训练

Elastic Reasoning 的独特优势在于:

  1. 双重保障:分离预算确保解决方案不被截断,训练策略确保推理质量
  2. 高效训练:仅需 200 步即可收敛
  3. 良好泛化:训练于单一预算配置,可泛化至任意未见预算
  4. 无约束提升:即使在无预算约束下也能提升或保持性能

9. 实际应用场景

Elastic Reasoning 适用于多种实际部署场景:

9.1 延迟敏感的在线服务

在需要快速响应的在线应用中(如实时问答、对话系统),可以通过限制思考预算来控制响应延迟:

# 实时对话场景:限制总生成时间为 100ms
thinking_budget = 512
solution_budget = 256
 
inferencer = SeparateBudgeting(model, thinking_budget, solution_budget)

9.2 成本受限的企业应用

在 Token 成本受限的场景下,Elastic Reasoning 可显著降低推理成本:

  • 30%+ Token 节省:在保持或提升性能的同时大幅减少 Token 使用
  • 可预测成本:通过固定预算实现精确的成本控制

9.3 边缘设备部署

在算力受限的边缘设备上,分离预算机制使得复杂推理任务成为可能:

# 边缘设备场景:严格限制内存和计算
thinking_budget = 256
solution_budget = 128
 
model = quantize_for_edge(model)
inferencer = SeparateBudgeting(model, thinking_budget, solution_budget)

10. 总结与展望

Elastic Reasoning 为大语言推理模型提供了一种原则性且实用的可控推理框架。其核心贡献包括:

  1. 分离预算机制:显式分离思考阶段与解决方案阶段,确保解决方案完整性
  2. 预算约束 rollout:通过 GRPO 训练,使模型能够自适应受限思考
  3. 卓越的效率:仅需 200 训练步即可达到 SOTA 性能
  4. 良好的泛化性:单一预算训练,泛化至任意未见预算配置
  5. 无约束性能提升:即使在无预算约束下也能产生更简洁、高效的推理

未来方向

  • 自适应预算分配:根据问题难度动态调整思考/解决方案预算比例
  • 多阶段推理:扩展至多阶段分解的复杂推理任务
  • 与其他推理增强技术结合:如 思维链提示、验证器驱动的自改进

Elastic Reasoning 为可扩展的推理系统开辟了新方向,是实现高效、可靠、受控推理的重要里程碑。

参考资料


相关主题:思维链推理推理模型测试时计算缩放

Footnotes

  1. Xu, Y., Dong, H., Wang, L., Sahoo, D., Li, J., & Xiong, C. (2025). Scalable Chain of Thoughts via Elastic Reasoning. arXiv:2505.05315. https://arxiv.org/abs/2505.05315