概述
过程奖励模型(Process Reward Model, PRM)是一种用于评估多步推理过程中每个步骤质量的模型。与仅评估最终结果的结果奖励模型(Outcome Reward Model, ORM)不同,PRM能够对推理的中间步骤进行评分,从而实现更细粒度的过程监督和验证。1
核心价值:PRM使得对复杂推理任务的”对错”判断扩展为对推理”质量”的判断。
背景:ORM vs PRM
结果奖励模型(ORM)
ORM仅评估最终答案的正确性:
# ORM:只看结果
def orm_reward(response, ground_truth):
"""
ORM奖励:仅基于最终答案是否正确
"""
predicted_answer = extract_answer(response)
correct = predicted_answer == ground_truth
return 1.0 if correct else 0.0
# 问题:无法区分"巧合蒙对"和"正确推理"过程奖励模型(PRM)
PRM评估每个推理步骤:
# PRM:评估每一步
def prm_reward(response_steps, ground_truth):
"""
PRM奖励:评估每个步骤的质量
"""
step_rewards = []
for i, step in enumerate(response_steps):
# 评估当前步骤的逻辑正确性
step_quality = evaluate_step_quality(
step,
previous_steps=response_steps[:i]
)
step_rewards.append(step_quality)
# 最终答案正确性(作为辅助奖励)
final_correct = is_final_answer_correct(response_steps[-1], ground_truth)
return step_rewards, final_correct对比总结
| 特性 | ORM | PRM |
|---|---|---|
| 评估粒度 | 最终答案 | 每个推理步骤 |
| 训练信号 | 稀疏(仅最后一步) | 密集(每步都有) |
| 正确推理 | 全部正奖励 | 逐步累积正奖励 |
| 错误推理 | 全部负奖励 | 早期负奖励 + 纠正正奖励 |
| 标注成本 | 低(仅需最终标签) | 高(需逐步标注) |
| 可扩展性 | 容易 | 困难 |
PRM的设计
1. 输入格式
PRM接收推理过程(问题 + 步骤序列)并输出每步的分数:
class ProcessRewardModel(nn.Module):
"""
PRM架构
"""
def __init__(self, encoder, hidden_dim=1024):
super().__init__()
self.encoder = encoder # LLM backbone
self.step_scorer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, question, steps):
"""
Args:
question: 问题文本
steps: 推理步骤列表
Returns:
step_scores: 每个步骤的分数 (n_steps,)
"""
scores = []
# 累积上下文
context = question
for step in steps:
# 构建当前步骤的输入
step_input = context + "\n" + step
# 编码
hidden = self.encoder(step_input)
# 打分
score = self.step_scorer(hidden)
scores.append(score)
# 更新上下文
context += "\n" + step
return torch.stack(scores).squeeze(-1)2. 训练数据格式
# PRM训练数据格式
prm_training_data = [
{
"question": "计算 (15 + 23) × 2 - 17",
"steps": [
{"text": "先算括号内:15 + 23 = 38", "label": 1.0}, # 正确
{"text": "然后:38 × 2 = 76", "label": 1.0}, # 正确
{"text": "最后:76 - 17 = 59", "label": 1.0} # 正确
],
"final_answer": "59"
},
{
"question": "求 x:3x + 5 = 20",
"steps": [
{"text": "移项:3x = 20 + 5 = 25", "label": -1.0}, # 错误!应该是减5
{"text": "两边除以3:x = 25/3", "label": -0.5}, # 错误传播
{"text": "所以 x ≈ 8.33", "label": 0.0} # 无效
],
"final_answer": "5" # 正确答案
}
]3. 奖励设计
def compute_prm_reward(step_scores, final_correct):
"""
结合PRM分数和最终正确性计算奖励
"""
# 步骤分数(取平均或加权和)
step_reward = step_scores.mean()
# 最终正确性奖励(放大信号)
final_reward = 2.0 if final_correct else -1.0
# 组合奖励
total_reward = 0.4 * step_reward + 0.6 * final_reward
return total_reward
def compute_rl_loss(log_probs, advantages):
"""
PRM用于RL训练的损失计算
"""
# 策略梯度损失
policy_loss = -(log_probs * advantages).mean()
# 熵正则(鼓励探索)
entropy_loss = entropy(log_probs).mean()
return policy_loss - 0.01 * entropy_lossPRM的应用场景
1. 推理验证
PRM可以验证LLM生成的推理过程:
def verify_reasoning(model, prm, question, response):
"""
使用PRM验证推理过程
"""
# 解析响应中的步骤
steps = parse_steps(response)
# 获取每步分数
step_scores = prm(question, steps)
# 分析结果
verification = {
"step_scores": step_scores.tolist(),
"avg_score": step_scores.mean().item(),
"min_score": step_scores.min().item(),
"flags": identify_problematic_steps(step_scores)
}
return verification
def identify_problematic_steps(scores, threshold=0.0):
"""识别可能有问题的步骤"""
flags = []
for i, score in enumerate(scores):
if score < threshold:
flags.append({
"step": i,
"score": score.item(),
"severity": "high" if score < -0.5 else "medium"
})
return flags2. 搜索引导
PRM可以引导束搜索或MCTS探索更好的推理路径:
class PRMGuidedSearch:
"""
使用PRM引导搜索
"""
def __init__(self, generator, prm):
self.generator = generator
self.prm = prm
def search(self, question, beam_width=4, max_steps=10):
"""
束搜索 + PRM引导
"""
# 初始化:问题作为根节点
beams = [{"steps": [], "score": 0.0}]
for step_num in range(max_steps):
# 扩展所有beam
candidates = []
for beam in beams:
# 生成下一组候选步骤
candidates_step = self.generator.generate_next_steps(
question, beam["steps"]
)
# 用PRM评估
for step in candidates_step:
steps_with_new = beam["steps"] + [step]
score = self.prm(question, steps_with_new)[-1]
candidates.append({
"steps": steps_with_new,
"score": beam["score"] + score
})
# 剪枝:保留top-k
candidates.sort(key=lambda x: x["score"], reverse=True)
beams = candidates[:beam_width]
# 早期停止:如果最佳beam的分数很低
if beams[0]["score"] < -1.0:
break
# 返回最佳beam
return beams[0]3. 过程监督训练
PRM可以提供过程监督信号训练LLM:
def process_supervised_training(llm, prm, training_data):
"""
使用PRM进行过程监督训练
"""
optimizer = torch.optim.AdamW(llm.parameters(), lr=1e-5)
for batch in training_data:
# 前向传播
question = batch["question"]
steps = batch["steps"]
# 生成响应
response = llm.generate(question, steps)
# 获取PRM评分
step_scores = prm(question, steps)
# 计算优势
advantages = compute_advantages(step_scores)
# 计算策略梯度损失
log_probs = llm.get_log_probs(steps)
loss = -(log_probs * advantages).mean()
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()PRM的训练方法
1. 人工标注
最直接但成本最高的方法:
# PRM标注界面设计
def prm_annotation_guidelines():
return """
PRM步骤标注指南:
分数范围:-1 到 1
1.0 - 完全正确,对最终答案有贡献
0.5 - 部分正确,但有瑕疵
0.0 - 中性步骤
-0.5 - 有错误但后续可能纠正
-1.0 - 严重错误,导致推理失败
标注注意事项:
1. 关注逻辑正确性,而非表面形式
2. 考虑步骤在整体推理中的作用
3. 早期错误比晚期错误更严重
4. 如果步骤可被后续步骤纠正,可适当宽容
"""2. 过程监督的PPO
def ppo_with_prm(policy, ref_policy, prm, data):
"""
使用PRM的PPO训练
"""
for batch in data:
# 生成响应
responses = policy.generate(batch["questions"])
# 用PRM评估每步
for response in responses:
steps = parse_steps(response)
step_scores = prm(batch["question"], steps)
# 计算优势(使用GAE)
advantages = compute_gae(step_scores, gamma=0.99, lam=0.95)
# PPO更新
log_probs = policy.get_log_probs(responses)
ref_log_probs = ref_policy.get_log_probs(responses)
ratio = torch.exp(log_probs - ref_log_probs)
clipped = torch.clamp(ratio, 0.8, 1.2)
loss = -torch.min(ratio * advantages, clipped * advantages).mean()
loss.backward()3. Process-Supervised RLHF
OpenAI在MathShepherd工作中提出的方法:
# Process-Supervised RLHF
def process_rlhf_training(policy, prm, orm, data):
"""
结合PRM和ORM的过程监督训练
"""
for batch in data:
# 生成多个候选响应
candidates = [policy.generate(q) for q in batch["questions"]]
# PRM评分
prm_scores = [prm(q, c) for q, c in zip(batch["questions"], candidates)]
# ORM评分
orm_scores = [orm(q, c) for q, c in zip(batch["questions"], candidates)]
# 组合奖励
combined_rewards = [
0.3 * prm_s + 0.7 * orm_s
for prm_s, orm_s in zip(prm_scores, orm_scores)
]
# PPO更新
update_policy(policy, candidates, combined_rewards)PRM的挑战与解决方案
1. 标注成本高
问题:PRM需要逐步骤标注,比ORM成本高10-100倍。
解决方案:
| 方法 | 描述 | 效果 |
|---|---|---|
| 弱监督 | 使用ORM信号训练初始PRM | 降低成本 |
| 自动标注 | 用LLM生成步骤标注 | 效率提升 |
| 对比学习 | 偏好数据训练隐式PRM | 无需显式标注 |
| 过程蒸馏 | 从强模型蒸馏PRM | 知识迁移 |
2. 步骤粒度定义
问题:如何定义”一步”?
# 粒度选择策略
def define_step_granularity(task):
"""
根据任务类型选择步骤粒度
"""
if task == "math_proof":
# 数学证明:按推理链划分
return "logical_unit"
elif task == "code_generation":
# 代码生成:按函数/语句块划分
return "function_level"
elif task == "qa":
# 问答:按语义单元划分
return "semantic_unit"
else:
return "auto_granularity" # 模型自适应3. 错误传播
问题:早期步骤错误导致后续步骤评分不公平。
def handle_error_propagation(prm, steps, strategy="masked"):
"""
处理错误传播问题
"""
if strategy == "masked":
# 掩盖已错误步骤的影响
corrected_steps = []
has_error = False
for step in steps:
score = prm(step)
if score < 0 and not has_error:
has_error = True
corrected_steps.append(0.0) # 错误步骤得0分
elif has_error:
# 忽略错误步骤后的评分
corrected_steps.append(0.0)
else:
corrected_steps.append(score)
return corrected_steps
elif strategy == "correction_aware":
# 考虑自我纠正的评分
...评估基准
Process Benchmarks
| 基准 | 描述 | 规模 |
|---|---|---|
| PRM800K | 数学步骤标注数据 | 800K步骤 |
| ProcessBench | 8个数学竞赛问题的过程标注 | 12K步骤 |
| MathVista | 视觉数学推理 | 6K问题 |
| MATH | 数学竞赛题 | 12K问题 |
评估指标
def evaluate_prm(prm, test_data):
"""
评估PRM性能
"""
results = {
"step_accuracy": [], # 步骤正确性预测准确率
"step_auc": [], # 步骤正确性AUC
"verification_accuracy": [], # 验证最终答案的能力
}
for sample in test_data:
# 预测步骤分数
pred_scores = prm(sample["question"], sample["steps"])
# 与标注比较
results["step_accuracy"].append(
step_accuracy(pred_scores, sample["labels"])
)
results["step_auc"].append(
step_auc(pred_scores, sample["labels"])
)
# 验证能力
results["verification_accuracy"].append(
verify_final_answer(pred_scores, sample["final_answer"])
)
return {k: np.mean(v) for k, v in results.items()}与其他方法的关系
PRM vs ORM
┌─────────────────────────────────────┐
│ 奖励模型对比 │
├─────────────────────────────────────┤
│ │
│ 问题:求 x + y = 10, x = 4 │
│ │
│ ORM: 回答 y = 6 → ✓ 1.0分 │
│ 回答 y = 7 → ✗ 0.0分 │
│ │
│ PRM: 步骤1: 10 - 4 = 6 → 1.0 │
│ 步骤2: 所以 y = 6 → 1.0 │
│ │
│ 步骤1: 10 + 4 = 14 → -0.5 │
│ 步骤2: 所以 y = 14 → -0.8 │
│ │
└─────────────────────────────────────┘
PRM + MCTS
PRM和MCTS是互补的:
# PRM引导的MCTS
def prm_mcts_search(problem, generator, prm):
"""
PRM提供每步的价值估计,引导MCTS探索
"""
tree = MCTSTree()
# 根节点
root = tree.add_node(state=problem, depth=0)
while within_budget():
# 选择:基于UCB,使用PRM作为先验
node = tree.select(root, prior_fn=lambda n: prm.score(n))
# 扩展
children = generator.expand(node)
# 评估:用PRM评估新节点
for child in children:
prm_score = prm(child)
tree.backup(child, prm_score)
return tree.best_path(root)参考
相关主题
Footnotes
-
Lightman et al. “Let’s Verify Step by Step”. arXiv:2305.20050, 2023. ↩