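Self-Improvement Techniques for Large Reasoning Models (HSIR)
1. Background
1.1 The Rise of Self-Improvement
Self-improvement training lets Large Reasoning Models (LRMs):
- use their own generated reasoning trajectories as training data
- improve their capabilities without external supervision
- make substantial gains on reasoning tasks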
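1.2 Limitations of Existing Methods
Although self-improvement shows promise, studies have found serious problems with it:
| Problem | Symptom | Impact |
|---|---|---|
| Data imbalance | Most training samples are easy; challenging samples are scarce | Poor performance on complex tasks |
| Overthinking | Samples with redundant reasoning steps are used for training | Risk of model collapse |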
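1.3 Problem Analysis
1.3.1 Data Imbalance
In self-generated data, most samples come from easy queries, while verified solutions to hard queries are scarce. As a result, the model overfits to simple tasks and generalizes poorly to complex ones.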
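1.3.2 Overthinking
Many samples contain redundant reasoning steps, typically following this progression:
- Early correctness: the correct answer is reached after only a few steps
- Continued reasoning: the model keeps generating unnecessary steps
- Error accumulation: the redundant steps can introduce errors
- Training contamination: such samples pollute the training data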
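The progression above can be made concrete with a small sketch. It assumes, purely for illustration, that each reasoning step exposes an intermediate answer (or `None`); the helper names `first_correct_step` and `redundant_fraction` are hypothetical and do not come from the HSIR method itself.

```python
from typing import List, Optional


def first_correct_step(step_answers: List[Optional[str]], gold: str) -> Optional[int]:
    """Return the index of the first step whose intermediate answer equals gold."""
    for i, ans in enumerate(step_answers):
        if ans == gold:
            return i
    return None


def redundant_fraction(step_answers: List[Optional[str]], gold: str) -> float:
    """Fraction of steps generated after the first correct answer appeared."""
    idx = first_correct_step(step_answers, gold)
    if idx is None or not step_answers:
        return 0.0
    return (len(step_answers) - idx - 1) / len(step_answers)


# Hypothetical trace: "42" first appears at step index 2, then the model
# keeps going and briefly drifts to a wrong answer.
trace = [None, None, "42", "42", "41", "42"]
print(first_correct_step(trace, "42"))  # 2
print(redundant_fraction(trace, "42"))  # 0.5
```

A high redundant fraction is only a rough proxy for overthinking; the ID score in Section 2.3 is the measure the method actually relies on.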
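2. The HSIR Method
2.1 Core Idea
HSIR (Harnessing Self-Improvement in large Reasoning models) addresses the problems above with two simple but effective strategies:
- Verify-then-Exit sampling: mitigates data imbalance
- Intrinsic Diversity scoring: quantifies and filters out overthinking samples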
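2.2 Verify-then-Exit Sampling
2.2.1 Motivation
Conventional sampling draws the same number of samples for every query, so it cannot collect enough diverse solutions for complex queries.
2.2.2 Method Design
Verify-then-Exit sampling, for each query Q:
1. Sample an initial batch of N solutions.
2. Verify the correctness of each solution:
       for solution in sampled_solutions:
           if is_correct(solution):
               add solution to the training set
           else:
               discard it and keep sampling
3. Repeat until enough correct solutions have been collected or the maximum number of samples is reached.
Key insight: spend more of the sampling budget on complex queries so that more correct answers are collected for them.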
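The stopping rule can be sketched in a few lines. This is a toy simulation under stated assumptions, not the paper's sampler: `make_generator` is a hypothetical stand-in for a model that answers correctly with a fixed probability, and `target_correct` and `max_samples` are illustrative parameter names.

```python
import random
from typing import Callable, List, Tuple


def verify_then_exit(
    generate: Callable[[], str],
    is_correct: Callable[[str], bool],
    target_correct: int,
    max_samples: int,
) -> Tuple[List[str], int]:
    """Sample until target_correct verified solutions are found or the budget runs out."""
    kept: List[str] = []
    used = 0
    while used < max_samples and len(kept) < target_correct:
        candidate = generate()
        used += 1
        if is_correct(candidate):
            kept.append(candidate)
    return kept, used


def make_generator(p_correct: float, rng: random.Random) -> Callable[[], str]:
    """Toy stand-in for a model: returns "right" with probability p_correct."""
    return lambda: "right" if rng.random() < p_correct else "wrong"


def check(solution: str) -> bool:
    """Toy verifier (assumption): a solution is correct iff it is the string "right"."""
    return solution == "right"


rng = random.Random(0)
easy_kept, easy_used = verify_then_exit(make_generator(0.9, rng), check, 4, 32)
hard_kept, hard_used = verify_then_exit(make_generator(0.1, rng), check, 4, 32)
print("easy query used", easy_used, "samples; hard query used", hard_used)
```

Because the loop exits as soon as the target is met, an easy query usually stops after a handful of samples, while a hard one tends to run much closer to the cap.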
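2.2.3 Formalization
Let $S_n(Q) = \{s_1, \dots, s_n\}$ be the set of $n$ solutions sampled for query $Q$, and let $V(s) \in \{0, 1\}$ indicate whether the verifier accepts solution $s$.
Sampling strategy: consistent with the procedure in 2.2.2, sampling for $Q$ can be read as stopping at the smallest budget that yields $k$ verified solutions, capped at $N_{\max}$ (the inner minimum is taken as $\infty$ when no such $n$ exists):
$$n^*(Q) = \min\Big(N_{\max},\ \min\big\{\, n : \textstyle\sum_{s \in S_n(Q)} V(s) \ge k \,\big\}\Big)$$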
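2.3 Intrinsic Diversity Score
2.3.1 Quantifying Overthinking
Overthinking samples share several traits:
- The reasoning trajectory is overly long
- It contains repeated or circular reasoning steps
- Substantial redundancy remains after the key decision points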
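2.3.2 ID Score Definition
The Intrinsic Diversity (ID) score, in the form used by the implementation in Section 5:
$$\mathrm{ID}(s) = \frac{K(s)}{L(s)} \cdot C(s)$$
where:
- $K(s)$: number of key nodes in the reasoning at which the answer changes
- $L(s)$: length of the full reasoning trajectory
- $C(s)$: degree of agreement with the majority-vote answer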
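As a worked example of the three quantities above, the sketch below computes a score as key points divided by trajectory length, times agreement with the majority vote. Treating agreement as a binary 0/1 value is an assumption made here for simplicity; the function names are illustrative.

```python
from collections import Counter
from typing import List


def majority_answer(answers: List[str]) -> str:
    """Most frequent answer in the group (ties go to the first one seen)."""
    return Counter(answers).most_common(1)[0][0]


def id_score(key_points: int, total_steps: int, answer: str, group_answers: List[str]) -> float:
    """Key-point density scaled by agreement with the majority vote.

    Assumption: agreement is 1.0 if the answer matches the majority, else 0.0.
    """
    if total_steps == 0:
        return 0.0
    consistency = 1.0 if answer == majority_answer(group_answers) else 0.0
    return (key_points / total_steps) * consistency


group = ["7", "7", "9"]
print(id_score(3, 4, "7", group))   # concise trace: 0.75
print(id_score(3, 12, "7", group))  # same key points, three times longer: 0.25
print(id_score(3, 4, "9", group))   # disagrees with the majority: 0.0
```

The second call shows why padding a trajectory with redundant steps lowers the score even when the answer is unchanged.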
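2.3.3 Overthinking Filter
Samples are filtered with a threshold $\tau$:
$$\text{keep}(s) = \mathbb{1}\big[\mathrm{ID}(s) > \tau\big]$$
Samples whose score does not exceed the threshold are treated as overthinking and are not used for training.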
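A minimal sketch of the filter follows. It uses hypothetical scores and the value 0.3, which is the default `id_threshold` in the Section 5 code; `filter_overthinking` is an illustrative name.

```python
from typing import List, Tuple


def filter_overthinking(
    scored: List[Tuple[str, float]], tau: float
) -> Tuple[List[str], float]:
    """Keep samples whose ID score exceeds tau; also report the dropped fraction."""
    kept = [name for name, score in scored if score > tau]
    dropped_fraction = 1 - len(kept) / len(scored) if scored else 0.0
    return kept, dropped_fraction


# Hypothetical (sample name, ID score) pairs.
scored = [("concise", 0.75), ("verbose", 0.25), ("looping", 0.10)]
kept, dropped = filter_overthinking(scored, tau=0.3)
print(kept)  # ['concise']
```

Tracking the dropped fraction alongside the kept samples makes it easier to notice when a threshold is so strict that little training data survives.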
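3. The H-GRPO Algorithm
3.1 GRPO Recap
GRPO (Group Relative Policy Optimization) is a reinforcement learning algorithm widely used in reasoning training. It samples a group of responses per query and scores each response relative to the others in its group.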
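The group-relative idea can be illustrated with plain Python: each reward is centered on the group mean and scaled by the group standard deviation. This is a sketch of the normalization step only, with a small `eps` added as an assumption to avoid division by zero.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-8) -> List[float]:
    """Normalize rewards within one group of responses to the same query."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std, so a single-sample group gives 0, not an error
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four responses to one query: two correct (reward 1), two incorrect (reward 0).
adv = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
print([round(a, 3) for a in adv])  # [1.0, -1.0, -1.0, 1.0]
```

If every response in a group receives the same reward, all advantages are zero and that group contributes no learning signal.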
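3.2 The H-GRPO Extension
H-GRPO builds on standard GRPO by introducing intrinsic diversity as an external reward:
$$R = R_{\text{task}} + \lambda \cdot \mathrm{ID}$$
where:
- $R_{\text{task}}$: original task reward (correctness)
- $\lambda$: diversity weight coefficient
- $\mathrm{ID}$: intrinsic diversity score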
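A short numeric sketch of the combined reward, assuming a binary correctness reward and the weight 0.5 (the default `id_weight` in the Section 5 code); `h_grpo_reward` is an illustrative name.

```python
def h_grpo_reward(is_correct: bool, id_score: float, lam: float = 0.5) -> float:
    """Task reward (1 for correct, 0 otherwise) plus a weighted ID bonus."""
    task_reward = 1.0 if is_correct else 0.0
    return task_reward + lam * id_score


print(h_grpo_reward(True, 0.75))   # correct and concise: 1.375
print(h_grpo_reward(True, 0.25))   # correct but verbose: 1.125
print(h_grpo_reward(False, 0.75))  # concise but wrong:   0.375
```

With a weight below 1 and ID scores in [0, 1], a correct answer always outranks an incorrect one, and the ID term only separates responses with the same correctness.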
3.3 Training Objective
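One way to write the objective, read off the `H_GRPO.update` code in Section 5 rather than taken from the paper, is the clipped surrogate below. The symbols are assumptions introduced here: $G'$ is the set of samples in a group that pass the ID filter, $o_i$ is the $i$-th sampled response to query $q$, $\delta$ is a small constant for numerical stability, and $\epsilon$ is the clipping range.

```latex
\[
R_i = R_{\text{task},i} + \lambda\,\mathrm{ID}_i, \qquad
\hat{A}_i = \frac{R_i - \operatorname{mean}_{j \in G'}(R_j)}{\operatorname{std}_{j \in G'}(R_j) + \delta}, \qquad
\rho_i(\theta) = \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)}
\]
\[
\mathcal{L}(\theta) = -\frac{1}{|G'|} \sum_{i \in G'}
\min\!\Big( \rho_i(\theta)\,\hat{A}_i,\;
\operatorname{clip}\big(\rho_i(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_i \Big)
\]
```

Minimizing $\mathcal{L}$ raises the probability of responses with above-average combined reward, while the clip keeps each update close to the sampling policy.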
4. Experimental Results
4.1 Main Results
Performance on several reasoning benchmarks:
| Benchmark | Baseline | GRPO | HSIR | Gain vs. baseline (points) |
|---|---|---|---|---|
| MATH | 72.3 | 76.8 | 82.1 | +9.8 |
| AIME | 45.2 | 48.1 | 52.3 | +7.1 |
| ARC-C | 89.5 | 91.2 | 93.8 | +4.3 |
| Average (reported) | - | - | - | +10.9% |
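4.2 Inference Efficiency
HSIR lowers inference overhead relative to both the baseline and standard GRPO:
| Model | Reasoning steps | Relative overhead |
|---|---|---|
| Baseline | 128 | 1.0× |
| Standard GRPO | 156 | 1.22× |
| HSIR | 90 | 0.70× |
Reported inference efficiency gain: 42.4%. Of the comparisons the table supports, this is closest to the step-count reduction relative to standard GRPO (156 to 90 steps).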
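The ratios in the table can be checked directly from the step counts. The snippet below uses only the three numbers given above; the function names are illustrative.

```python
steps = {"baseline": 128, "grpo": 156, "hsir": 90}


def relative_overhead(model: str, ref: str = "baseline") -> float:
    """Step count of model divided by step count of the reference."""
    return steps[model] / steps[ref]


def step_reduction(model: str, ref: str) -> float:
    """Fraction of reasoning steps saved by model compared with the reference."""
    return 1 - steps[model] / steps[ref]


print(round(relative_overhead("grpo"), 2))           # 1.22
print(round(relative_overhead("hsir"), 2))           # 0.7
print(round(step_reduction("hsir", "baseline"), 3))  # 0.297
print(round(step_reduction("hsir", "grpo"), 3))      # 0.423
```

Measured against the baseline the saving is about 30%, and measured against standard GRPO it is about 42%, so the choice of reference matters when quoting a single efficiency figure.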
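4.3 Ablation Studies
4.3.1 Effect of Verify-then-Exit
| Method | Accuracy on complex queries | Accuracy on simple queries |
|---|---|---|
| Standard sampling | 45.2% | 91.5% |
| Verify-then-Exit | 58.7% | 90.8% |
4.3.2 Effect of Intrinsic Diversity
| Method | Share of overthinking samples | Training stability |
|---|---|---|
| No filtering | 38.5% | Low |
| ID filtering | 12.1% | High |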
5. PyTorch Implementation
```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch


@dataclass
class ReasoningSample:
    """A reasoning sample with its metadata."""
    query: str
    solution: str
    answer: str
    is_correct: bool
    reasoning_steps: List[str]
    key_decision_points: int
    id_score: float


class VerifyThenExitSampler:
    """
    Verify-then-Exit sampling strategy.
    Collects more correct solutions for difficult queries.

    Assumed interfaces: model.generate(query, return_reasoning=True) returns
    (solution, answer, steps); verifier.check(query, solution, answer) returns bool.
    """

    def __init__(
        self,
        model,
        verifier,
        max_samples_per_query: int = 32,
        min_correct_ratio: float = 0.3,
        early_exit_threshold: int = 8,
    ):
        self.model = model
        self.verifier = verifier
        self.max_samples = max_samples_per_query
        self.min_correct_ratio = min_correct_ratio
        # Stored for configuration only; the exit rule in sample() uses
        # max_samples * min_correct_ratio as its target.
        self.early_exit_threshold = early_exit_threshold

    def sample(
        self,
        query: str,
        difficulty_hint: Optional[str] = None,
        n_samples: Optional[int] = None,
    ) -> List[ReasoningSample]:
        """
        Sample solutions with verify-then-exit strategy.

        n_samples overrides the sampling budget (used for quick evaluation).
        """
        samples: List[ReasoningSample] = []
        correct_count = 0
        target_correct = int(self.max_samples * self.min_correct_ratio)

        # Adaptive sampling based on difficulty, unless a budget is given
        if n_samples is None:
            n_samples = self.max_samples
            if difficulty_hint == "hard":
                n_samples = int(self.max_samples * 1.5)
            elif difficulty_hint == "easy":
                n_samples = int(self.max_samples * 0.7)

        for _ in range(n_samples):
            # Generate solution
            solution, answer, steps = self.model.generate(
                query,
                return_reasoning=True
            )

            # Verify correctness
            is_correct = self.verifier.check(query, solution, answer)

            # Calculate ID score
            id_score = self._calculate_id_score(solution, steps)

            sample = ReasoningSample(
                query=query,
                solution=solution,
                answer=answer,
                is_correct=is_correct,
                reasoning_steps=steps,
                key_decision_points=self._count_key_decisions(steps),
                id_score=id_score
            )
            samples.append(sample)

            if is_correct:
                correct_count += 1

            # Early exit if we have enough correct samples
            if correct_count >= target_correct:
                break

        return samples

    def _calculate_id_score(
        self,
        solution: str,
        steps: List[str]
    ) -> float:
        """
        Calculate Intrinsic Diversity score.
        """
        total_steps = len(steps)
        if total_steps == 0:
            return 0.0

        # Count key decision points
        key_points = self._count_key_decisions(steps)

        # Calculate answer consistency (simplified)
        # In practice, compare with majority vote
        answer_consistency = 1.0  # Placeholder

        id_score = (key_points / total_steps) * answer_consistency
        return id_score

    def _count_key_decisions(self, steps: List[str]) -> int:
        """
        Count key decision points in reasoning steps.
        """
        # Key decision indicators (keyword heuristic)
        indicators = [
            "therefore", "so", "thus",
            "conclude", "since", "because",
            "if and only if", "equivalent to",
            "the answer is", "final answer"
        ]
        count = 0
        for step in steps:
            step_lower = step.lower()
            # Count each step at most once, even if several indicators match
            if any(indicator in step_lower for indicator in indicators):
                count += 1
        return max(count, 1)  # At least 1 decision point


class H_GRPO:
    """
    H-GRPO: Enhanced GRPO with Intrinsic Diversity reward.
    """

    def __init__(
        self,
        model,
        id_weight: float = 0.5,
        id_threshold: float = 0.3,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        self.model = model
        self.id_weight = id_weight
        self.id_threshold = id_threshold
        # If no optimizer is given, update() falls back to plain SGD with `lr`
        self.optimizer = optimizer

    def compute_reward(
        self,
        sample: ReasoningSample,
        baseline_samples: List[ReasoningSample]
    ) -> Tuple[float, bool]:
        """
        Compute combined reward with ID filtering.

        Returns:
            reward: Combined reward value
            keep: Whether to keep sample for training
        """
        # Task reward
        task_reward = 1.0 if sample.is_correct else 0.0

        # ID reward
        id_reward = self.id_weight * sample.id_score

        # Combined reward
        combined_reward = task_reward + id_reward

        # Filtering decision
        keep = sample.id_score > self.id_threshold
        return combined_reward, keep

    def update(
        self,
        query: str,
        samples: List[ReasoningSample],
        old_log_probs: torch.Tensor,
        clip_eps: float = 0.2,
        lr: float = 1e-5
    ) -> Optional[float]:
        """
        Update policy using H-GRPO.

        old_log_probs is expected to hold one entry per element of `samples`.
        Returns the loss value, or None if every sample was filtered out.
        """
        # Compute rewards and filter samples, remembering which indices survive
        kept_indices = []
        filtered_samples = []
        rewards = []
        for i, sample in enumerate(samples):
            reward, keep = self.compute_reward(sample, samples)
            if keep:
                kept_indices.append(i)
                filtered_samples.append(sample)
                rewards.append(reward)

        if len(filtered_samples) == 0:
            return None  # No valid samples to train on

        device = old_log_probs.device
        rewards = torch.tensor(rewards, dtype=torch.float32, device=device)

        # Normalize rewards within the group; the biased std avoids NaN
        # when only one sample survives the filter
        rewards = (rewards - rewards.mean()) / (rewards.std(unbiased=False) + 1e-8)

        # Align old log-probs with the filtered samples
        idx = torch.tensor(kept_indices, dtype=torch.long, device=device)
        old_log_probs = old_log_probs[idx].detach()

        # Compute policy gradient loss
        new_log_probs = self.model.get_log_probs(query, filtered_samples)

        # GRPO-style importance sampling
        ratio = torch.exp(new_log_probs - old_log_probs)
        clipped_ratio = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)

        # PPO-style objective with clipping
        loss = -torch.min(ratio * rewards, clipped_ratio * rewards).mean()

        # Update model
        if self.optimizer is not None:
            self.optimizer.zero_grad()
        else:
            self.model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        if self.optimizer is not None:
            self.optimizer.step()
        else:
            # Manual SGD step; skip parameters that received no gradient
            with torch.no_grad():
                for param in self.model.parameters():
                    if param.grad is not None:
                        param -= lr * param.grad

        return loss.item()


class SelfImprovementTrainer:
    """
    Complete self-improvement training pipeline.
    """

    def __init__(
        self,
        model,
        verifier,
        optimizer,
        id_weight: float = 0.5,
        id_threshold: float = 0.3
    ):
        self.model = model
        self.sampler = VerifyThenExitSampler(model, verifier)
        self.hgrpo = H_GRPO(model, id_weight, id_threshold, optimizer)
        self.optimizer = optimizer

    def train_step(self, batch_queries: List[str]) -> Dict[str, float]:
        """
        Single training step.
        """
        total_loss = 0.0
        total_updates = 0
        correct_before = 0
        correct_after = 0
        n_before = 0
        n_after = 0

        for query in batch_queries:
            # Sample with verify-then-exit
            samples = self.sampler.sample(query)
            if len(samples) == 0:
                continue

            # Count correct before training
            correct_before += sum(1 for s in samples if s.is_correct)
            n_before += len(samples)

            # Update with H-GRPO
            with torch.no_grad():
                old_log_probs = self.model.get_log_probs(
                    query, samples
                )
            loss = self.hgrpo.update(query, samples, old_log_probs)
            if loss is not None:
                total_loss += loss
                total_updates += 1

            # Verify after training with a small evaluation budget
            new_samples = self.sampler.sample(query, n_samples=4)
            correct_after += sum(1 for s in new_samples if s.is_correct)
            n_after += len(new_samples)

        # The two phases draw different numbers of samples, so compare rates
        acc_before = correct_before / max(n_before, 1)
        acc_after = correct_after / max(n_after, 1)

        return {
            'avg_loss': total_loss / max(total_updates, 1),
            'correct_before': correct_before,
            'correct_after': correct_after,
            'improvement': acc_after - acc_before
        }
```
6. Key Insights
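6.1 Addressing Data Imbalance
Core idea: complex queries need a larger sampling budget
- Adaptive sampling: adjust the number of samples to query difficulty
- Verification first: check correctness before deciding whether to keep sampling
- Targeted collection: prioritize gathering correct answers for training
6.2 Identifying and Filtering Overthinking
Core idea: longer does not mean more accurate
- Intrinsic Diversity score: quantifies reasoning efficiency
- Key decision points: identify the reasoning steps that actually matter
- Adaptive threshold: tune the filtering strength to the task
6.3 Strengthening Reinforcement Learning
Core idea: make efficiency part of the reward
- Task reward: favors correctness
- Diversity reward: favors concision
- Combined objective: balances effectiveness and efficiency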
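7. Summary and Outlook
7.1 Core Contributions
- Problem diagnosis: a systematic analysis of data imbalance and overthinking in self-improvement
- Solution: Verify-then-Exit sampling plus Intrinsic Diversity scoring
- Algorithmic enhancement: H-GRPO makes efficiency part of the reinforcement learning reward
7.2 Performance Gains
- Effectiveness: +10.9% reasoning performance on average
- Efficiency: 42.4% lower inference overhead
7.3 Future Directions
- Automatic tuning of the ID threshold
- Combination with other reasoning-enhancement methods
- Validation on larger-scale models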