概述

过程奖励模型(Process Reward Model, PRM)是一种用于评估多步推理过程中每个步骤质量的模型。与仅评估最终结果的结果奖励模型(Outcome Reward Model, ORM)不同,PRM能够对推理的中间步骤进行评分,从而实现更细粒度的过程监督和验证。1

核心价值:PRM使得对复杂推理任务的”对错”判断扩展为对推理”质量”的判断。

背景:ORM vs PRM

结果奖励模型(ORM)

ORM仅评估最终答案的正确性:

# ORM:只看结果
def orm_reward(response, ground_truth):
    """
    ORM奖励:仅基于最终答案是否正确
    """
    predicted_answer = extract_answer(response)
    correct = predicted_answer == ground_truth
    
    return 1.0 if correct else 0.0
 
# 问题:无法区分"巧合蒙对"和"正确推理"

过程奖励模型(PRM)

PRM评估每个推理步骤:

# PRM:评估每一步
def prm_reward(response_steps, ground_truth):
    """
    PRM奖励:评估每个步骤的质量
    """
    step_rewards = []
    
    for i, step in enumerate(response_steps):
        # 评估当前步骤的逻辑正确性
        step_quality = evaluate_step_quality(
            step, 
            previous_steps=response_steps[:i]
        )
        step_rewards.append(step_quality)
    
    # 最终答案正确性(作为辅助奖励)
    final_correct = is_final_answer_correct(response_steps[-1], ground_truth)
    
    return step_rewards, final_correct

对比总结

特性ORMPRM
评估粒度最终答案每个推理步骤
训练信号稀疏(仅最后一步)密集(每步都有)
正确推理全部正奖励逐步累积正奖励
错误推理全部负奖励早期负奖励 + 纠正正奖励
标注成本低(仅需最终标签)高(需逐步标注)
可扩展性容易困难

PRM的设计

1. 输入格式

PRM接收推理过程(问题 + 步骤序列)并输出每步的分数:

class ProcessRewardModel(nn.Module):
    """
    PRM架构
    """
    def __init__(self, encoder, hidden_dim=1024):
        super().__init__()
        self.encoder = encoder  # LLM backbone
        self.step_scorer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )
    
    def forward(self, question, steps):
        """
        Args:
            question: 问题文本
            steps: 推理步骤列表
        Returns:
            step_scores: 每个步骤的分数 (n_steps,)
        """
        scores = []
        
        # 累积上下文
        context = question
        
        for step in steps:
            # 构建当前步骤的输入
            step_input = context + "\n" + step
            
            # 编码
            hidden = self.encoder(step_input)
            
            # 打分
            score = self.step_scorer(hidden)
            scores.append(score)
            
            # 更新上下文
            context += "\n" + step
        
        return torch.stack(scores).squeeze(-1)

2. 训练数据格式

# PRM训练数据格式
prm_training_data = [
    {
        "question": "计算 (15 + 23) × 2 - 17",
        "steps": [
            {"text": "先算括号内:15 + 23 = 38", "label": 1.0},  # 正确
            {"text": "然后:38 × 2 = 76", "label": 1.0},         # 正确
            {"text": "最后:76 - 17 = 59", "label": 1.0}         # 正确
        ],
        "final_answer": "59"
    },
    {
        "question": "求 x:3x + 5 = 20",
        "steps": [
            {"text": "移项:3x = 20 + 5 = 25", "label": -1.0},   # 错误!应该是减5
            {"text": "两边除以3:x = 25/3", "label": -0.5},      # 错误传播
            {"text": "所以 x ≈ 8.33", "label": 0.0}             # 无效
        ],
        "final_answer": "5"  # 正确答案
    }
]

3. 奖励设计

def compute_prm_reward(step_scores, final_correct):
    """
    结合PRM分数和最终正确性计算奖励
    """
    # 步骤分数(取平均或加权和)
    step_reward = step_scores.mean()
    
    # 最终正确性奖励(放大信号)
    final_reward = 2.0 if final_correct else -1.0
    
    # 组合奖励
    total_reward = 0.4 * step_reward + 0.6 * final_reward
    
    return total_reward
 
 
def compute_rl_loss(log_probs, advantages):
    """
    PRM用于RL训练的损失计算
    """
    # 策略梯度损失
    policy_loss = -(log_probs * advantages).mean()
    
    # 熵正则(鼓励探索)
    entropy_loss = entropy(log_probs).mean()
    
    return policy_loss - 0.01 * entropy_loss

PRM的应用场景

1. 推理验证

PRM可以验证LLM生成的推理过程:

def verify_reasoning(model, prm, question, response):
    """
    使用PRM验证推理过程
    """
    # 解析响应中的步骤
    steps = parse_steps(response)
    
    # 获取每步分数
    step_scores = prm(question, steps)
    
    # 分析结果
    verification = {
        "step_scores": step_scores.tolist(),
        "avg_score": step_scores.mean().item(),
        "min_score": step_scores.min().item(),
        "flags": identify_problematic_steps(step_scores)
    }
    
    return verification
 
 
def identify_problematic_steps(scores, threshold=0.0):
    """识别可能有问题的步骤"""
    flags = []
    
    for i, score in enumerate(scores):
        if score < threshold:
            flags.append({
                "step": i,
                "score": score.item(),
                "severity": "high" if score < -0.5 else "medium"
            })
    
    return flags

2. 搜索引导

PRM可以引导束搜索或MCTS探索更好的推理路径:

class PRMGuidedSearch:
    """
    使用PRM引导搜索
    """
    
    def __init__(self, generator, prm):
        self.generator = generator
        self.prm = prm
    
    def search(self, question, beam_width=4, max_steps=10):
        """
        束搜索 + PRM引导
        """
        # 初始化:问题作为根节点
        beams = [{"steps": [], "score": 0.0}]
        
        for step_num in range(max_steps):
            # 扩展所有beam
            candidates = []
            
            for beam in beams:
                # 生成下一组候选步骤
                candidates_step = self.generator.generate_next_steps(
                    question, beam["steps"]
                )
                
                # 用PRM评估
                for step in candidates_step:
                    steps_with_new = beam["steps"] + [step]
                    score = self.prm(question, steps_with_new)[-1]
                    candidates.append({
                        "steps": steps_with_new,
                        "score": beam["score"] + score
                    })
            
            # 剪枝:保留top-k
            candidates.sort(key=lambda x: x["score"], reverse=True)
            beams = candidates[:beam_width]
            
            # 早期停止:如果最佳beam的分数很低
            if beams[0]["score"] < -1.0:
                break
        
        # 返回最佳beam
        return beams[0]

3. 过程监督训练

PRM可以提供过程监督信号训练LLM:

def process_supervised_training(llm, prm, training_data):
    """
    使用PRM进行过程监督训练
    """
    optimizer = torch.optim.AdamW(llm.parameters(), lr=1e-5)
    
    for batch in training_data:
        # 前向传播
        question = batch["question"]
        steps = batch["steps"]
        
        # 生成响应
        response = llm.generate(question, steps)
        
        # 获取PRM评分
        step_scores = prm(question, steps)
        
        # 计算优势
        advantages = compute_advantages(step_scores)
        
        # 计算策略梯度损失
        log_probs = llm.get_log_probs(steps)
        loss = -(log_probs * advantages).mean()
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

PRM的训练方法

1. 人工标注

最直接但成本最高的方法:

# PRM标注界面设计
def prm_annotation_guidelines():
    return """
    PRM步骤标注指南:
    
    分数范围:-1 到 1
    
    1.0 - 完全正确,对最终答案有贡献
    0.5 - 部分正确,但有瑕疵
    0.0 - 中性步骤
    -0.5 - 有错误但后续可能纠正
    -1.0 - 严重错误,导致推理失败
    
    标注注意事项:
    1. 关注逻辑正确性,而非表面形式
    2. 考虑步骤在整体推理中的作用
    3. 早期错误比晚期错误更严重
    4. 如果步骤可被后续步骤纠正,可适当宽容
    """

2. 过程监督的PPO

def ppo_with_prm(policy, ref_policy, prm, data):
    """
    使用PRM的PPO训练
    """
    for batch in data:
        # 生成响应
        responses = policy.generate(batch["questions"])
        
        # 用PRM评估每步
        for response in responses:
            steps = parse_steps(response)
            step_scores = prm(batch["question"], steps)
            
            # 计算优势(使用GAE)
            advantages = compute_gae(step_scores, gamma=0.99, lam=0.95)
            
            # PPO更新
            log_probs = policy.get_log_probs(responses)
            ref_log_probs = ref_policy.get_log_probs(responses)
            
            ratio = torch.exp(log_probs - ref_log_probs)
            clipped = torch.clamp(ratio, 0.8, 1.2)
            
            loss = -torch.min(ratio * advantages, clipped * advantages).mean()
            loss.backward()

3. Process-Supervised RLHF

OpenAI在MathShepherd工作中提出的方法:

# Process-Supervised RLHF
def process_rlhf_training(policy, prm, orm, data):
    """
    结合PRM和ORM的过程监督训练
    """
    for batch in data:
        # 生成多个候选响应
        candidates = [policy.generate(q) for q in batch["questions"]]
        
        # PRM评分
        prm_scores = [prm(q, c) for q, c in zip(batch["questions"], candidates)]
        
        # ORM评分
        orm_scores = [orm(q, c) for q, c in zip(batch["questions"], candidates)]
        
        # 组合奖励
        combined_rewards = [
            0.3 * prm_s + 0.7 * orm_s 
            for prm_s, orm_s in zip(prm_scores, orm_scores)
        ]
        
        # PPO更新
        update_policy(policy, candidates, combined_rewards)

PRM的挑战与解决方案

1. 标注成本高

问题:PRM需要逐步骤标注,比ORM成本高10-100倍。

解决方案

方法描述效果
弱监督使用ORM信号训练初始PRM降低成本
自动标注用LLM生成步骤标注效率提升
对比学习偏好数据训练隐式PRM无需显式标注
过程蒸馏从强模型蒸馏PRM知识迁移

2. 步骤粒度定义

问题:如何定义”一步”?

# 粒度选择策略
def define_step_granularity(task):
    """
    根据任务类型选择步骤粒度
    """
    if task == "math_proof":
        # 数学证明:按推理链划分
        return "logical_unit"
    elif task == "code_generation":
        # 代码生成:按函数/语句块划分
        return "function_level"
    elif task == "qa":
        # 问答:按语义单元划分
        return "semantic_unit"
    else:
        return "auto_granularity"  # 模型自适应

3. 错误传播

问题:早期步骤错误导致后续步骤评分不公平。

def handle_error_propagation(prm, steps, strategy="masked"):
    """
    处理错误传播问题
    """
    if strategy == "masked":
        # 掩盖已错误步骤的影响
        corrected_steps = []
        has_error = False
        
        for step in steps:
            score = prm(step)
            
            if score < 0 and not has_error:
                has_error = True
                corrected_steps.append(0.0)  # 错误步骤得0分
            elif has_error:
                # 忽略错误步骤后的评分
                corrected_steps.append(0.0)
            else:
                corrected_steps.append(score)
        
        return corrected_steps
    
    elif strategy == "correction_aware":
        # 考虑自我纠正的评分
        ...

评估基准

Process Benchmarks

基准描述规模
PRM800K数学步骤标注数据800K步骤
ProcessBench8个数学竞赛问题的过程标注12K步骤
MathVista视觉数学推理6K问题
MATH数学竞赛题12K问题

评估指标

def evaluate_prm(prm, test_data):
    """
    评估PRM性能
    """
    results = {
        "step_accuracy": [],  # 步骤正确性预测准确率
        "step_auc": [],       # 步骤正确性AUC
        "verification_accuracy": [],  # 验证最终答案的能力
    }
    
    for sample in test_data:
        # 预测步骤分数
        pred_scores = prm(sample["question"], sample["steps"])
        
        # 与标注比较
        results["step_accuracy"].append(
            step_accuracy(pred_scores, sample["labels"])
        )
        results["step_auc"].append(
            step_auc(pred_scores, sample["labels"])
        )
        
        # 验证能力
        results["verification_accuracy"].append(
            verify_final_answer(pred_scores, sample["final_answer"])
        )
    
    return {k: np.mean(v) for k, v in results.items()}

与其他方法的关系

PRM vs ORM

                ┌─────────────────────────────────────┐
                │           奖励模型对比               │
                ├─────────────────────────────────────┤
                │                                     │
                │   问题:求 x + y = 10, x = 4      │
                │                                     │
                │   ORM: 回答 y = 6 → ✓ 1.0分        │
                │       回答 y = 7 → ✗ 0.0分        │
                │                                     │
                │   PRM: 步骤1: 10 - 4 = 6 → 1.0    │
                │       步骤2: 所以 y = 6 → 1.0      │
                │                                     │
                │       步骤1: 10 + 4 = 14 → -0.5   │
                │       步骤2: 所以 y = 14 → -0.8   │
                │                                     │
                └─────────────────────────────────────┘

PRM + MCTS

PRM和MCTS是互补的:

# PRM引导的MCTS
def prm_mcts_search(problem, generator, prm):
    """
    PRM提供每步的价值估计,引导MCTS探索
    """
    tree = MCTSTree()
    
    # 根节点
    root = tree.add_node(state=problem, depth=0)
    
    while within_budget():
        # 选择:基于UCB,使用PRM作为先验
        node = tree.select(root, prior_fn=lambda n: prm.score(n))
        
        # 扩展
        children = generator.expand(node)
        
        # 评估:用PRM评估新节点
        for child in children:
            prm_score = prm(child)
            tree.backup(child, prm_score)
    
    return tree.best_path(root)

参考


相关主题

Footnotes

  1. Lightman et al. “Let’s Verify Step by Step”. arXiv:2305.20050, 2023.