概述
本文档总结 f-GRPO 框架和基于散度(divergence)的强化学习方法用于LLM对齐的核心研究。这包括两篇关键论文的理论框架、数学推导和关键洞察。
1. 数学框架:f-散度与强化学习
1.1 f-散度定义
f-散度是衡量两个概率分布差异的通用框架:
其中 是凸函数且满足 。
1.2 常见f-散度类型
| 名称 | 特性 | |
|---|---|---|
| KL散度 | 鼓励探索,覆盖所有模式 | |
| 反向KL | 避免mode-seeking,追求单一模式 | |
| Pearson | 对高概率差异惩罚更重 | |
| Hellinger | 对称性更好 | |
| Jensen-Shannon | 有界, | |
| Total Variation | $\frac{1}{2} | t-1 |
1.3 变分表示
f-散度可以通过变分形式表示:
其中 是 的凸共轭。
2. GRPO到f-GRPO的泛化
2.1 标准GRPO回顾
Group Relative Policy Optimization (GRPO) 通过以下方式简化PPO:
- 组采样:对同一prompt生成G个响应
- 标准化优势函数:
- 优势加权策略更新
2.2 f-GRPO的核心创新
f-GRPO 将f-散度框架引入RLVR,设置以下对齐/不对齐分布:
- 对齐分布 :高于平均奖励的响应
- 不对齐分布 :低于平均奖励的响应
奖励诱导分布定义:
2.3 f-GRPO损失函数
其中 是基于f-散度的加权函数:
2.4 f-GRPO vs GRPO的关键区别
| 特性 | GRPO | f-GRPO |
|---|---|---|
| 更新方向 | 标准化优势 | f-散度驱动的方向 |
| 低于均值响应 | 非零概率 | 可通过 控制 |
| 理论保证 | 单调奖励改进 | 散度估计 + 奖励改进 |
| 正则化 | KL散度 | 多种f-散度可选 |
3. DPO与RL目标的关系
3.1 DPO的隐式奖励
Direct Preference Optimization (DPO) 引入隐式奖励概念:
其中 是归一化常数。
3.2 KL-正则化对齐目标
标准RLHF/RLVR优化:
3.3 DPO损失的推导
DPO将RL问题转化为监督学习:
4. GIFT:统一GRPO、DPO和UNA
4.1 GIFT的核心思想
Group-Relative Implicit Fine-Tuning (GIFT) 提出了奖励匹配而非奖励最大化的范式转变。
4.2 GIFT的三个关键元素
- GRPO风格的组采样和归一化
- DPO风格的隐式奖励
- UNA风格的MSE损失
4.3 归一化消除β和Z(x)
GIFT通过组归一化消除了DPO中不可计算的partition函数和超参数:
隐式奖励:
归一化后:
4.4 GIFT损失函数
4.5 GIFT vs GRPO优势
| 方面 | GRPO | GIFT |
|---|---|---|
| 优化范式 | 奖励最大化 | 奖励匹配 |
| 优化景观 | 噪声大,需裁剪 | 平滑MSE |
| 超参数 | β需调优 | β不敏感 |
| 收敛速度 | 较慢 | 更快 |
| 过拟合 | 较严重 | 较轻 |
5. 理论保证
5.1 f-GRPO理论结果
散度估计:当 ,
对齐一致性:固定点策略满足:
平均奖励改进:在正则条件下,
5.2 GIFT的理论特性
GIFT的种群最优解精确重合GRPO/RLHF解族:
6. 实验结果
6.1 数学推理(RLVR)
| 方法 | GSM8K | MATH500 | AMC23 | AIME24 | AIME25 |
|---|---|---|---|---|---|
| Base | 21.39 | 25.66 | 27.34 | 5.63 | 2.29 |
| GRPO | 71.67 | 60.60 | 43.28 | 8.54 | 3.13 |
| Pearson | 72.59 | 61.70 | 45.16 | 10.42 | 3.33 |
| JS | 72.47 | 61.93 | 40.94 | 7.50 | 4.79 |
6.2 散度选择的影响
| 散度 | 特性 | 适用场景 |
|---|---|---|
| KL | 覆盖所有模式 | 平衡探索/利用 |
| Reverse KL | mode-seeking | 精确控制,保守更新 |
| Pearson | 惩罚高差异 | 快速收敛 |
| JS | 有界,稳定 | 对话质量 |
7. 代码实现
7.1 f-GRPO核心实现
#include <bits/stdc++.h>
using namespace std;
struct Response {
vector<int> tokens;
double reward;
double log_prob;
};
double compute_f_divergence(const vector<double>& ratios, const string& div_type) {
if (div_type == "kl") {
return accumulate(ratios.begin(), ratios.end(), 0.0,
[](double sum, double r) { return sum + r * log(r); });
} else if (div_type == "reverse_kl") {
return accumulate(ratios.begin(), ratios.end(), 0.0,
[](double sum, double r) { return sum - log(r); });
} else if (div_type == "pearson") {
return accumulate(ratios.begin(), ratios.end(), 0.0,
[](double sum, double r) { double d = r - 1.0; return sum + d * d; });
} else if (div_type == "hellinger") {
return accumulate(ratios.begin(), ratios.end(), 0.0,
[](double sum, double r) { double s = sqrt(r) - 1.0; return sum + s * s; });
}
return 0.0;
}
double fgrpo_loss(
Policy& policy,
const vector<Prompt>& prompts,
const vector<Response>& responses,
const vector<double>& advantages,
double beta,
const string& div_type
) {
int G = responses.size();
double total_loss = 0.0;
// 计算组统计量
double mean_adv = accumulate(advantages.begin(), advantages.end(), 0.0) / G;
double std_adv = 0.0;
for (double a : advantages) {
std_adv += (a - mean_adv) * (a - mean_adv);
}
std_adv = sqrt(std_adv / G) + 1e-8;
// 计算f-散度加权的优势
vector<double> weighted_adv(G);
for (int i = 0; i < G; i++) {
double normalized_adv = (advantages[i] - mean_adv) / std_adv;
if (normalized_adv > 0) {
// 使用选定的f-散度
double r_theta = exp(responses[i].log_prob - old_log_prob[i]);
weighted_adv[i] = normalized_adv * compute_f_divergence({r_theta}, div_type);
} else {
// 低于均值的响应使用不同的处理
double r_theta = exp(responses[i].log_prob - old_log_prob[i]);
weighted_adv[i] = normalized_adv * compute_conjugate_divergence({r_theta}, div_type);
}
}
// 计算策略损失
for (int i = 0; i < G; i++) {
double ratio = exp(responses[i].log_prob - old_log_prob[i]);
double clipped = clamp(ratio, 1 - epsilon, 1 + epsilon);
double loss = -min(ratio * weighted_adv[i], clipped * weighted_adv[i]);
total_loss += loss;
}
// KL惩罚
double kl_penalty = beta * kl_divergence(policy, ref_policy);
return total_loss / G + kl_penalty;
}7.2 GIFT算法流程
struct GIFTConfig {
double lr = 3e-6;
int group_size = 16;
double beta = 1.0; // 对beta不敏感
};
double gift_loss(
Policy& policy,
Policy& ref_policy,
RewardModel& reward_model,
const Prompt& prompt,
const GIFTConfig& config
) {
// 1. 生成N个响应
vector<Response> responses;
for (int i = 0; i < config.group_size; i++) {
responses.push_back(policy.sample(prompt));
}
// 2. 计算显式奖励
vector<double> explicit_rewards;
for (const auto& r : responses) {
explicit_rewards.push_back(reward_model.evaluate(prompt, r));
}
// 3. 计算隐式奖励(DPO风格)
vector<double> implicit_rewards;
for (const auto& r : responses) {
double log_ratio = r.log_prob - ref_policy.log_prob(r);
implicit_rewards.push_back(log_ratio);
}
// 4. 组归一化(消除β和Z(x))
auto normalize = [](vector<double>& x) {
double mean = accumulate(x.begin(), x.end(), 0.0) / x.size();
double std = 0.0;
for (double v : x) std += (v - mean) * (v - mean);
std = sqrt(std / x.size()) + 1e-8;
for (double& v : x) v = (v - mean) / std;
};
normalize(explicit_rewards);
normalize(implicit_rewards);
// 5. MSE损失(奖励匹配)
double mse_loss = 0.0;
for (int i = 0; i < config.group_size; i++) {
double diff = explicit_rewards[i] - implicit_rewards[i];
mse_loss += diff * diff;
}
return mse_loss / config.group_size;
}8. 关键洞察
8.1 散度框架的统一性
- 偏好对齐:FDO目标直接估计chosen/rejected分布间的散度
- RLVR:通过奖励诱导对齐/不对齐分布,f-GRPO估计其散度
- 混合对齐:-HAL同时利用偏好和奖励信号
8.2 从奖励最大化到奖励匹配
- GRPO/DPO:直接最大化奖励
- GIFT:匹配隐式和显式奖励函数
- 优势:更稳定的优化,更少的超参数调优
8.3 组归一化的威力
组归一化可以消除:
- DPO中不可计算的partition函数
- β系数的影响
- 序列长度偏差