Self-Improvement Techniques for Large Reasoning Models (HSIR)

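1. Background

1.1 The Rise of Self-Improvement

Self-improvement training enables Large Reasoning Models (LRMs) to:

  • use their own generated reasoning trajectories as training data
  • improve their capabilities without external supervision
  • achieve substantial gains on reasoning tasks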

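1.2 Limitations of Existing Methods

Despite its promise, self-improvement has been found to suffer from two serious problems:

Problem        | Manifestation                                                   | Impact
Data imbalance | Most training samples are easy; challenging samples are scarce  | Poor performance on complex tasks
Overthinking   | Samples with redundant reasoning steps are used for training   | Risk of model collapse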

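1.3 Problem Analysis

1.3.1 Data Imbalance

In self-generated data, most correct solutions come from easy queries that the model already solves reliably, while correct solutions to challenging queries are scarce.

As a result, the model overfits to easy tasks and generalizes poorly to complex ones.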

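1.3.2 Overthinking

Many samples contain redundant reasoning steps:

  • Early correctness: the correct answer is already reached after a few steps
  • Continued reasoning: the model keeps generating unnecessary steps
  • Error accumulation: the redundant steps can introduce errors
  • Training contamination: such samples contaminate the training data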
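
To make this pattern concrete, here is a minimal sketch that finds the first step in which a trajectory already states its final answer and counts the steps generated after it. The function name `redundant_tail` and the substring-matching rule are illustrative assumptions, not part of HSIR; a real pipeline would compare parsed answers rather than raw strings.

```python
from typing import List, Optional, Tuple

def redundant_tail(steps: List[str], final_answer: str) -> Tuple[Optional[int], int]:
    """Return (index of the first step that already states the final answer,
    number of steps generated after that step).

    Crude heuristic (assumption): a step "states" the answer if the answer
    string occurs in it as a substring.
    """
    for i, step in enumerate(steps):
        if final_answer in step:
            return i, len(steps) - (i + 1)
    return None, 0  # answer never appears explicitly; nothing to flag

steps = [
    "Let x = 2.",
    "Then x + 1 = 3.",
    "Let me double-check: 2 + 1 = 3.",
    "Verifying again, the result is 3.",
    "The answer is 3.",
]
first, extra = redundant_tail(steps, "3")
print(first, extra, extra / len(steps))
```

Here the answer is already reached at step index 1, and the remaining 3 of 5 steps (60% of the trajectory) only restate it.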

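2. HSIR in Detail

2.1 Core Idea

HSIR (Harnessing Self-Improvement in large Reasoning models) addresses these problems with two simple but effective strategies:

  1. Verify-then-Exit sampling: mitigates data imbalance
  2. Intrinsic Diversity (ID) score: quantifies and filters out overthinking samples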

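2.2 Verify-then-Exit Sampling

2.2.1 Motivation

Conventional sampling draws the same number of solutions for every query. For hard queries most of those samples are wrong, so a fixed budget yields few correct (and few diverse) solutions exactly where they are needed most.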

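2.2.2 Method Design

Verify-then-Exit sampling:

For each query Q:
    1. Sample an initial batch of N solutions

    2. Verify the correctness of each solution
       for solution in sampled_solutions:
           if is_correct(solution):
               add solution to the training set

    3. While too few correct solutions have been collected,
       sample more solutions and verify them as in step 2;
       exit once enough correct solutions are collected
       or the maximum sampling budget is reached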

Key insight: spend more of the sampling budget on complex queries so that more correct solutions are collected for them.
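
The loop above can be sketched in plain Python as a one-sample-at-a-time variant, which makes the exit point explicit. `generate()` and `verify()` are assumed callables standing in for the model and the verifier, and `toy_model` is a hypothetical deterministic stand-in whose every k-th attempt is correct; the budget of 32 mirrors the default `max_samples_per_query` of the Section 5 sampler, while the target of 4 is arbitrary.

```python
import itertools
from typing import Callable, List, Tuple

def verify_then_exit(
    generate: Callable[[], str],
    verify: Callable[[str], bool],
    target_correct: int,
    max_samples: int,
) -> Tuple[List[str], int]:
    """Draw one solution at a time, verify it, and exit as soon as
    `target_correct` verified solutions are collected or the budget is spent.
    Returns (correct_solutions, samples_used)."""
    correct: List[str] = []
    used = 0
    while used < max_samples and len(correct) < target_correct:
        candidate = generate()
        used += 1
        if verify(candidate):
            correct.append(candidate)
    return correct, used

def fixed_budget(generate: Callable[[], str], verify: Callable[[str], bool], n: int) -> List[str]:
    """Baseline: always draw n solutions and keep the verified ones."""
    return [s for s in (generate() for _ in range(n)) if verify(s)]

def toy_model(correct_every: int) -> Callable[[], str]:
    """Hypothetical stand-in for an LRM: every `correct_every`-th attempt is right."""
    counter = itertools.count(1)
    return lambda: "right" if next(counter) % correct_every == 0 else "wrong"

def verify(s: str) -> bool:
    return s == "right"

for name, k in [("easy", 1), ("hard", 5), ("very hard", 20)]:
    fixed = len(fixed_budget(toy_model(k), verify, n=8))
    found, used = verify_then_exit(toy_model(k), verify, target_correct=4, max_samples=32)
    print(f"{name}: fixed(8) -> {fixed} correct; verify-then-exit -> {len(found)} correct in {used} samples")
```

The easy query exits after 4 samples, the hard one spends 20 samples to collect 4 correct solutions (a fixed budget of 8 finds only 1), and the very hard one exhausts the 32-sample budget with a single correct solution. The budget thus flows to the queries that need it.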

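2.2.3 Formalization

For a query $q$, let $\mathcal{Y}_N(q) = \{y_1, \dots, y_N\}$ be the set of $N$ sampled solutions, and let $v(q, y) \in \{0, 1\}$ be the verifier's judgment of solution $y$.

Verify-then-Exit makes $N$ query-dependent: sampling stops at

$$T(q) = \min\Big( N_{\max},\ \min\big\{ t : \textstyle\sum_{j=1}^{t} v(q, y_j) \ge M \big\} \Big),$$

where $M$ is the target number of correct solutions and $N_{\max}$ is the maximum sampling budget. The solutions added to the training set are $\mathcal{D}(q) = \{ y_j : j \le T(q),\ v(q, y_j) = 1 \}$.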

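2.3 Intrinsic Diversity Score

2.3.1 Quantifying Overthinking

Overthinking samples share common traits:

  • overly long reasoning trajectories
  • repeated or circular reasoning steps
  • a large amount of redundant reasoning after the key decision point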

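2.3.2 Definition of the ID Score

The Intrinsic Diversity (ID) score of a reasoning trajectory $y$ is

$$\mathrm{ID}(y) = \frac{K(y)}{L(y)} \cdot C(y)$$

where:

  • $K(y)$: the number of key decision points in the reasoning, i.e. the steps that change the answer
  • $L(y)$: the length of the full reasoning trajectory (number of steps)
  • $C(y)$: the degree of agreement with the majority-vote answer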
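
A minimal sketch of the score, assuming $K(y)$ and $L(y)$ have already been counted (for example by a heuristic like `_count_key_decisions` in Section 5). The source leaves $C(y)$ underspecified (the Section 5 code uses a placeholder of 1.0), so reading it as the share of the group that votes for $y$'s answer is an assumption here.

```python
from collections import Counter
from typing import List

def answer_consistency(answer: str, group_answers: List[str]) -> float:
    """C(y), assumed here to be the vote share of y's answer within its group,
    so answers that agree with the majority score highest."""
    if not group_answers:
        return 0.0
    return Counter(group_answers)[answer] / len(group_answers)

def id_score(key_points: int, total_steps: int, consistency: float) -> float:
    """ID(y) = K(y) / L(y) * C(y); an empty trajectory scores 0."""
    if total_steps == 0:
        return 0.0
    return key_points / total_steps * consistency

group = ["42", "42", "42", "17"]  # final answers of 4 samples for one query
concise = id_score(3, 6, answer_consistency("42", group))   # 0.5 * 0.75
verbose = id_score(3, 30, answer_consistency("42", group))  # 0.1 * 0.75
```

Both trajectories reach the majority answer with the same 3 key decisions, but the 30-step one dilutes them and scores 0.075 against 0.375 for the 6-step one.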

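2.3.3 Overthinking Filter

A sample is kept for training only if its ID score exceeds a threshold $\tau$:

$$\mathrm{keep}(y) = \mathbb{1}\big[\mathrm{ID}(y) > \tau\big]$$

Samples at or below the threshold are classified as overthinking and excluded from training.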
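
A minimal sketch of the filter on illustrative ID values. `tau = 0.3` mirrors the default `id_threshold` of the Section 5 implementation; reporting the "overthinking ratio" as the fraction of samples removed is an assumption, since the source does not say how its ablation measures that ratio.

```python
from typing import Dict, List, Tuple

def filter_overthinking(id_scores: Dict[str, float], tau: float) -> Tuple[List[str], float]:
    """Keep samples with ID > tau; return (kept sample ids, fraction removed)."""
    kept = [sid for sid, score in id_scores.items() if score > tau]
    ratio = 1.0 - len(kept) / len(id_scores) if id_scores else 0.0
    return kept, ratio

scores = {"a": 0.375, "b": 0.075, "c": 0.125, "d": 0.6}
kept, ratio = filter_overthinking(scores, tau=0.3)
print(kept, ratio)
```

Samples "b" and "c" fall below the threshold, so half of the group is flagged as overthinking and only "a" and "d" are used for training.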

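3. H-GRPO

3.1 GRPO Recap

GRPO (Group Relative Policy Optimization) is a reinforcement learning algorithm widely used for reasoning training. For each query it samples a group of responses and uses each response's reward, normalized by the mean and standard deviation of the rewards in that group, as its advantage, so no separate value network is needed.

3.2 The H-GRPO Extension

H-GRPO extends standard GRPO by adding the Intrinsic Diversity score as an additional reward term:

$$r(y) = r_{\text{task}}(y) + \lambda \cdot \mathrm{ID}(y)$$

where:

  • $r_{\text{task}}(y)$: the original task reward (correctness)
  • $\lambda$: the weight of the diversity term
  • $\mathrm{ID}(y)$: the Intrinsic Diversity score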
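
A dependency-free sketch of how the combined reward feeds GRPO's group-relative advantages for one query. The correctness flags and ID values are made up for illustration, `lam = 0.5` matches the default `id_weight` in Section 5, and the sample standard deviation is used to match `torch.Tensor.std()` in that code.

```python
from statistics import mean, stdev
from typing import List

def hgrpo_advantages(correct: List[bool], id_scores: List[float], lam: float = 0.5) -> List[float]:
    """Combined reward r = r_task + lam * ID, then group-relative normalization."""
    rewards = [(1.0 if c else 0.0) + lam * s for c, s in zip(correct, id_scores)]
    mu, sigma = mean(rewards), stdev(rewards)  # sample std over the group
    return [(r - mu) / (sigma + 1e-8) for r in rewards]

# Four samples for one query: three correct, one wrong
adv = hgrpo_advantages([True, True, False, True], [0.5, 0.1, 0.4, 0.3])
```

Among the three correct samples, the diversity term orders the advantages by ID (0.5 > 0.3 > 0.1), so the more concise correct solutions are reinforced most, while the wrong sample gets a negative advantage despite its moderate ID.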

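3.3 Training Objective

For a query $q$, let $\mathcal{G}$ be the sampled solutions that pass the ID filter. With the importance ratio $\rho_i(\theta) = \pi_\theta(y_i \mid q) / \pi_{\theta_{\text{old}}}(y_i \mid q)$ and the group-normalized advantage $\hat{A}_i = \big(r(y_i) - \operatorname{mean}_{j \in \mathcal{G}}\, r(y_j)\big) / \operatorname{std}_{j \in \mathcal{G}}\, r(y_j)$, the PPO-style clipped objective (as in the implementation in Section 5) is

$$\mathcal{J}(\theta) = \frac{1}{|\mathcal{G}|} \sum_{i \in \mathcal{G}} \min\Big( \rho_i(\theta)\, \hat{A}_i,\ \operatorname{clip}\big(\rho_i(\theta), 1 - \epsilon, 1 + \epsilon\big)\, \hat{A}_i \Big),$$

and training minimizes $-\mathcal{J}(\theta)$.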

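A single term of the clipped objective can be checked by hand. This sketch evaluates $\min(\rho A, \operatorname{clip}(\rho, 1-\epsilon, 1+\epsilon) A)$ for a few illustrative ratios with $\epsilon = 0.2$, the default `clip_eps` in Section 5; the function name is hypothetical.

```python
def clipped_surrogate(ratio: float, advantage: float, eps: float = 0.2) -> float:
    """One term of the PPO/GRPO clipped objective: min(rho*A, clip(rho)*A)."""
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
    return min(ratio * advantage, clipped * advantage)

# Positive advantage: the gain from raising the ratio is capped at 1 + eps (about 1.2)
print(clipped_surrogate(1.5, 1.0))
# Negative advantage: the pessimistic (smaller) term is kept (about -0.8)
print(clipped_surrogate(0.5, -1.0))
# Inside the trust region nothing is clipped (1.1)
print(clipped_surrogate(1.1, 1.0))
```

Taking the minimum makes the objective a pessimistic bound, so a single update cannot push the policy far from the one that generated the samples.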
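4. Experimental Results

4.1 Main Results

Performance on several reasoning benchmarks:

Benchmark | Baseline | GRPO | HSIR | Gain
MATH      | 72.3     | 76.8 | 82.1 | +10.9%
AIME      | 45.2     | 48.1 | 52.3 | +7.1%
ARC-C     | 89.5     | 91.2 | 93.8 | +4.3%
Average   | -        | -    | -    | +10.9%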

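4.2 Inference Efficiency

Inference overhead is substantially reduced:

Model         | Reasoning steps | Relative overhead
Baseline      | 128             | 1.0×
Standard GRPO | 156             | 1.22×
HSIR          | 90              | 0.70×

Inference overhead reduction: 42.4%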

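4.3 Ablation Studies

4.3.1 Effect of Verify-then-Exit

Method            | Accuracy on complex queries | Accuracy on simple queries
Standard sampling | 45.2%                       | 91.5%
Verify-then-Exit  | 58.7%                       | 90.8%

4.3.2 Effect of Intrinsic Diversity

Method       | Overthinking ratio | Training stability
No filtering | 38.5%              |
ID filtering | 12.1%              |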

5. PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass
 
@dataclass
class ReasoningSample:
    """A reasoning sample with its metadata."""
    query: str
    solution: str
    answer: str
    is_correct: bool
    reasoning_steps: List[str]
    key_decision_points: int
    id_score: float
 
 
class VerifyThenExitSampler:
    """
    Verify-then-Exit sampling strategy.
    Collects more correct solutions for difficult queries.
    """
    def __init__(
        self,
        model,
        verifier,
        max_samples_per_query: int = 32,
        min_correct_ratio: float = 0.3,
        early_exit_threshold: int = 8
    ):
        self.model = model
        self.verifier = verifier
        self.max_samples = max_samples_per_query
        self.min_correct_ratio = min_correct_ratio
        self.early_exit_threshold = early_exit_threshold
        
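    def sample(
        self,
        query: str,
        difficulty_hint: Optional[str] = None,
        n_samples: Optional[int] = None
    ) -> List[ReasoningSample]:
        """
        Sample solutions with the verify-then-exit strategy.

        Args:
            query: the input query
            difficulty_hint: optional "easy" / "hard" hint that scales the budget
            n_samples: optional explicit budget that overrides the
                difficulty-based one (used for the post-update check in
                SelfImprovementTrainer.train_step)
        """
        samples = []
        correct_count = 0
        # Exit early once this many correct solutions have been collected
        target_correct = int(self.max_samples * self.min_correct_ratio)

        # Adaptive sampling budget based on difficulty
        if n_samples is None:
            n_samples = self.max_samples
            if difficulty_hint == "hard":
                n_samples = int(self.max_samples * 1.5)
            elif difficulty_hint == "easy":
                n_samples = int(self.max_samples * 0.7)
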
        for i in range(n_samples):
            # Generate solution
            solution, answer, steps = self.model.generate(
                query, 
                return_reasoning=True
            )
            
            # Verify correctness
            is_correct = self.verifier.check(query, solution, answer)
            
            # Calculate ID score
            id_score = self._calculate_id_score(solution, steps)
            
            sample = ReasoningSample(
                query=query,
                solution=solution,
                answer=answer,
                is_correct=is_correct,
                reasoning_steps=steps,
                key_decision_points=self._count_key_decisions(steps),
                id_score=id_score
            )
            samples.append(sample)
            
            if is_correct:
                correct_count += 1
                
            # Early exit if we have enough correct samples
            if correct_count >= target_correct:
                break
                
        return samples
    
    def _calculate_id_score(
        self,
        solution: str,
        steps: List[str]
    ) -> float:
        """
        Calculate Intrinsic Diversity score.
        """
        total_steps = len(steps)
        if total_steps == 0:
            return 0.0
            
        # Count key decision points
        key_points = self._count_key_decisions(steps)
        
        # Calculate answer consistency (simplified)
        # In practice, compare with majority vote
        answer_consistency = 1.0  # Placeholder
        
        id_score = (key_points / total_steps) * answer_consistency
        return id_score
    
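    def _count_key_decisions(self, steps: List[str]) -> int:
        """
        Count key decision points in the reasoning steps (heuristic):
        a step counts once if it contains any indicator phrase as whole words.
        """
        # Key decision indicators
        indicators = [
            "therefore", "so", "thus",
            "conclude", "since", "because",
            "if and only if", "equivalent to",
            "the answer is", "final answer"
        ]

        count = 0
        for step in steps:
            # Replace punctuation with spaces and pad with spaces, so that
            # "so" does not match inside words like "also" or "solve"
            words = "".join(ch if ch.isalnum() else " " for ch in step.lower()).split()
            padded = " " + " ".join(words) + " "
            if any(f" {indicator} " in padded for indicator in indicators):
                count += 1

        return max(count, 1)  # At least 1 decision point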
 
 
class H_GRPO:
    """
    H-GRPO: GRPO enhanced with an Intrinsic Diversity (ID) reward.
    """
    def __init__(
        self,
        model,
        optimizer: torch.optim.Optimizer,
        id_weight: float = 0.5,
        id_threshold: float = 0.3
    ):
        self.model = model
        self.optimizer = optimizer
        self.id_weight = id_weight
        self.id_threshold = id_threshold

    def compute_reward(self, sample: ReasoningSample) -> Tuple[float, bool]:
        """
        Compute the combined reward and the ID-based filtering decision.

        Returns:
            reward: task reward + id_weight * id_score
            keep: whether to keep the sample for training
                  (id_score <= id_threshold is treated as overthinking)
        """
        # Task reward (correctness)
        task_reward = 1.0 if sample.is_correct else 0.0

        # ID reward
        id_reward = self.id_weight * sample.id_score

        # Combined reward
        combined_reward = task_reward + id_reward

        # Filtering decision
        keep = sample.id_score > self.id_threshold

        return combined_reward, keep

    def update(
        self,
        query: str,
        samples: List[ReasoningSample],
        old_log_probs: torch.Tensor,
        clip_eps: float = 0.2
    ) -> Optional[float]:
        """
        One H-GRPO policy update on the samples drawn for a single query.

        old_log_probs: sequence log-probs of *all* samples under the
            sampling policy, shape (len(samples),).
        Returns the loss, or None if too few samples survive filtering.
        """
        # Compute rewards and keep only samples that pass the ID filter
        kept_idx, filtered_samples, rewards = [], [], []
        for i, sample in enumerate(samples):
            reward, keep = self.compute_reward(sample)
            if keep:
                kept_idx.append(i)
                filtered_samples.append(sample)
                rewards.append(reward)

        # Group normalization needs at least two samples (std of one is NaN)
        if len(filtered_samples) < 2:
            return None

        device = old_log_probs.device
        rewards = torch.tensor(rewards, dtype=torch.float32, device=device)

        # Align the old log-probs with the filtered samples
        old_log_probs = old_log_probs[torch.tensor(kept_idx, device=device)]

        # Group-relative advantages: normalize rewards within the group
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # Log-probs under the current policy (with gradients),
        # assumed to have shape (len(filtered_samples),)
        new_log_probs = self.model.get_log_probs(query, filtered_samples)

        # Importance ratio and PPO-style clipped objective
        ratio = torch.exp(new_log_probs - old_log_probs)
        clipped_ratio = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
        loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

        # Gradient step with the configured optimizer
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        return loss.item()


class SelfImprovementTrainer:
    """
    Complete self-improvement training pipeline.
    """
    def __init__(
        self,
        model,
        verifier,
        optimizer: torch.optim.Optimizer,
        id_weight: float = 0.5,
        id_threshold: float = 0.3
    ):
        self.model = model
        self.sampler = VerifyThenExitSampler(model, verifier)
        self.hgrpo = H_GRPO(model, optimizer, id_weight, id_threshold)
        self.optimizer = optimizer

    def train_step(
        self,
        batch_queries: List[str],
        eval_samples: int = 4
    ) -> Dict[str, float]:
        """
        Single training step over a batch of queries.

        Accuracy is reported as a fraction of sampled solutions, so the
        before/after numbers are comparable despite different sample counts.
        The "before" accuracy comes from verify-then-exit sampling, whose
        early exit biases it slightly; treat the difference as a rough signal.
        """
        total_loss = 0.0
        n_updates = 0
        correct_before, sampled_before = 0, 0
        correct_after, sampled_after = 0, 0

        for query in batch_queries:
            # Sample with verify-then-exit
            samples = self.sampler.sample(query)
            if len(samples) == 0:
                continue

            # Accuracy before the update
            correct_before += sum(1 for s in samples if s.is_correct)
            sampled_before += len(samples)

            # Log-probs under the sampling (old) policy, without gradients
            with torch.no_grad():
                old_log_probs = self.model.get_log_probs(query, samples)

            # Update with H-GRPO
            loss = self.hgrpo.update(query, samples, old_log_probs)
            if loss is not None:
                total_loss += loss
                n_updates += 1

            # Quick check with a small fixed budget after the update
            new_samples = self.sampler.sample(query, n_samples=eval_samples)
            correct_after += sum(1 for s in new_samples if s.is_correct)
            sampled_after += len(new_samples)

        acc_before = correct_before / max(sampled_before, 1)
        acc_after = correct_after / max(sampled_after, 1)
        return {
            'avg_loss': total_loss / max(n_updates, 1),
            'acc_before': acc_before,
            'acc_after': acc_after,
            'improvement': acc_after - acc_before
        }

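6. Key Insights

6.1 Addressing Data Imbalance

Core idea: complex queries need a larger sampling budget

  • Adaptive sampling: adjust the number of samples to the difficulty of the query
  • Verification first: verify correctness before deciding whether to keep sampling
  • Targeted collection: prioritize collecting correct solutions for training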
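
Why hard queries need a larger budget can be made quantitative. Under the simplifying assumption that samples are independently correct with a per-query pass rate $p$, collecting $M$ correct solutions takes $M / p$ samples on average (the mean of a negative binomial distribution). The sketch below turns that into a capped budget; the function name is hypothetical, and in practice $p$ would have to be estimated, e.g. from an initial batch.

```python
import math

def expected_budget(target_correct: int, pass_rate: float, max_samples: int) -> int:
    """Average number of samples needed to collect `target_correct` correct
    solutions when each sample is independently correct with probability
    `pass_rate` (target_correct / pass_rate), capped at max_samples."""
    if pass_rate <= 0.0:
        return max_samples
    return min(max_samples, math.ceil(target_correct / pass_rate))

for p in (0.9, 0.25, 0.05):
    print(p, expected_budget(4, p, 64))
```

With a target of 4 correct solutions, an easy query (p = 0.9) needs about 5 samples, a hard one (p = 0.25) about 16, and a very hard one (p = 0.05) hits the cap of 64, which is where the verify-then-exit budget ends up being spent.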

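6.2 Identifying and Filtering Overthinking

Core idea: longer does not mean more accurate

  • Intrinsic Diversity score: quantifies reasoning efficiency
  • Key decision points: identify the reasoning steps that actually matter
  • Adaptive threshold: tune the filtering strength to the task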
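
The source does not say how the threshold is adapted per task. One simple possibility, an assumption rather than HSIR's documented procedure, is to derive $\tau$ from the empirical distribution of ID scores on a task, so that roughly a chosen fraction of samples is filtered. The helper name below is hypothetical.

```python
from typing import List

def quantile_threshold(id_scores: List[float], drop_fraction: float) -> float:
    """Pick tau so that about `drop_fraction` of the samples satisfy ID <= tau
    (ties at tau are dropped too). Returns -inf when nothing should be dropped."""
    k = min(len(id_scores), round(drop_fraction * len(id_scores)))
    if k <= 0:
        return float("-inf")
    return sorted(id_scores)[k - 1]

scores = [0.5, 0.1, 0.4, 0.2, 0.3]
tau = quantile_threshold(scores, 0.4)   # the 2nd-smallest score
kept = [s for s in scores if s > tau]   # same "ID > tau" rule as the filter
```

Dropping 40% of five samples sets tau to 0.2 and keeps 0.5, 0.4 and 0.3; a task whose ID scores are generally lower would get a correspondingly lower threshold.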

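6.3 Strengthening Reinforcement Learning

Core idea: build efficiency into the reward

  • Task reward: drives correctness
  • Diversity reward: drives conciseness
  • Combined objective: balances effectiveness and efficiency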

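7. Summary and Outlook

7.1 Core Contributions

  1. Problem diagnosis: a systematic analysis of data imbalance and overthinking in self-improvement
  2. Solutions: Verify-then-Exit sampling + the Intrinsic Diversity score
  3. Algorithmic extension: H-GRPO builds efficiency into the RL reward

7.2 Performance Gains

  • Effectiveness: +10.9% average reasoning performance
  • Efficiency: 42.4% lower inference overhead

7.3 Future Directions

  • Automatic tuning of the ID threshold
  • Combination with other reasoning-enhancement methods
  • Validation on larger-scale models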
