概述

本文档总结 f-GRPO 框架和基于散度(divergence)的强化学习方法用于LLM对齐的核心研究。这包括两篇关键论文的理论框架、数学推导和关键洞察。


1. 数学框架:f-散度与强化学习

1.1 f-散度定义

f-散度是衡量两个概率分布差异的通用框架:

其中 是凸函数且满足

1.2 常见f-散度类型

名称特性
KL散度鼓励探索,覆盖所有模式
反向KL避免mode-seeking,追求单一模式
Pearson 对高概率差异惩罚更重
Hellinger对称性更好
Jensen-Shannon有界,
Total Variation$\frac{1}{2}t-1

1.3 变分表示

f-散度可以通过变分形式表示:

其中 的凸共轭。


2. GRPO到f-GRPO的泛化

2.1 标准GRPO回顾

Group Relative Policy Optimization (GRPO) 通过以下方式简化PPO:

  1. 组采样:对同一prompt生成G个响应
  2. 标准化优势函数
  3. 优势加权策略更新

2.2 f-GRPO的核心创新

f-GRPO 将f-散度框架引入RLVR,设置以下对齐/不对齐分布:

  • 对齐分布 :高于平均奖励的响应
  • 不对齐分布 :低于平均奖励的响应

奖励诱导分布定义

2.3 f-GRPO损失函数

其中 是基于f-散度的加权函数:

2.4 f-GRPO vs GRPO的关键区别

特性GRPOf-GRPO
更新方向标准化优势f-散度驱动的方向
低于均值响应非零概率可通过 控制
理论保证单调奖励改进散度估计 + 奖励改进
正则化KL散度多种f-散度可选

3. DPO与RL目标的关系

3.1 DPO的隐式奖励

Direct Preference Optimization (DPO) 引入隐式奖励概念:

其中 是归一化常数。

3.2 KL-正则化对齐目标

标准RLHF/RLVR优化:

3.3 DPO损失的推导

DPO将RL问题转化为监督学习:


4. GIFT:统一GRPO、DPO和UNA

4.1 GIFT的核心思想

Group-Relative Implicit Fine-Tuning (GIFT) 提出了奖励匹配而非奖励最大化的范式转变。

4.2 GIFT的三个关键元素

  1. GRPO风格的组采样和归一化
  2. DPO风格的隐式奖励
  3. UNA风格的MSE损失

4.3 归一化消除β和Z(x)

GIFT通过组归一化消除了DPO中不可计算的partition函数和超参数:

隐式奖励

归一化后

4.4 GIFT损失函数

4.5 GIFT vs GRPO优势

方面GRPOGIFT
优化范式奖励最大化奖励匹配
优化景观噪声大,需裁剪平滑MSE
超参数β需调优β不敏感
收敛速度较慢更快
过拟合较严重较轻

5. 理论保证

5.1 f-GRPO理论结果

散度估计:当

对齐一致性:固定点策略满足:

平均奖励改进:在正则条件下,

5.2 GIFT的理论特性

GIFT的种群最优解精确重合GRPO/RLHF解族:


6. 实验结果

6.1 数学推理(RLVR)

方法GSM8KMATH500AMC23AIME24AIME25
Base21.3925.6627.345.632.29
GRPO71.6760.6043.288.543.13
Pearson72.5961.7045.1610.423.33
JS72.4761.9340.947.504.79

6.2 散度选择的影响

散度特性适用场景
KL覆盖所有模式平衡探索/利用
Reverse KLmode-seeking精确控制,保守更新
Pearson惩罚高差异快速收敛
JS有界,稳定对话质量

7. 代码实现

7.1 f-GRPO核心实现

#include <bits/stdc++.h>
using namespace std;
 
struct Response {
    vector<int> tokens;
    double reward;
    double log_prob;
};
 
double compute_f_divergence(const vector<double>& ratios, const string& div_type) {
    if (div_type == "kl") {
        return accumulate(ratios.begin(), ratios.end(), 0.0,
            [](double sum, double r) { return sum + r * log(r); });
    } else if (div_type == "reverse_kl") {
        return accumulate(ratios.begin(), ratios.end(), 0.0,
            [](double sum, double r) { return sum - log(r); });
    } else if (div_type == "pearson") {
        return accumulate(ratios.begin(), ratios.end(), 0.0,
            [](double sum, double r) { double d = r - 1.0; return sum + d * d; });
    } else if (div_type == "hellinger") {
        return accumulate(ratios.begin(), ratios.end(), 0.0,
            [](double sum, double r) { double s = sqrt(r) - 1.0; return sum + s * s; });
    }
    return 0.0;
}
 
double fgrpo_loss(
    Policy& policy,
    const vector<Prompt>& prompts,
    const vector<Response>& responses,
    const vector<double>& advantages,
    double beta,
    const string& div_type
) {
    int G = responses.size();
    double total_loss = 0.0;
    
    // 计算组统计量
    double mean_adv = accumulate(advantages.begin(), advantages.end(), 0.0) / G;
    double std_adv = 0.0;
    for (double a : advantages) {
        std_adv += (a - mean_adv) * (a - mean_adv);
    }
    std_adv = sqrt(std_adv / G) + 1e-8;
    
    // 计算f-散度加权的优势
    vector<double> weighted_adv(G);
    for (int i = 0; i < G; i++) {
        double normalized_adv = (advantages[i] - mean_adv) / std_adv;
        
        if (normalized_adv > 0) {
            // 使用选定的f-散度
            double r_theta = exp(responses[i].log_prob - old_log_prob[i]);
            weighted_adv[i] = normalized_adv * compute_f_divergence({r_theta}, div_type);
        } else {
            // 低于均值的响应使用不同的处理
            double r_theta = exp(responses[i].log_prob - old_log_prob[i]);
            weighted_adv[i] = normalized_adv * compute_conjugate_divergence({r_theta}, div_type);
        }
    }
    
    // 计算策略损失
    for (int i = 0; i < G; i++) {
        double ratio = exp(responses[i].log_prob - old_log_prob[i]);
        double clipped = clamp(ratio, 1 - epsilon, 1 + epsilon);
        double loss = -min(ratio * weighted_adv[i], clipped * weighted_adv[i]);
        total_loss += loss;
    }
    
    // KL惩罚
    double kl_penalty = beta * kl_divergence(policy, ref_policy);
    
    return total_loss / G + kl_penalty;
}

7.2 GIFT算法流程

struct GIFTConfig {
    double lr = 3e-6;
    int group_size = 16;
    double beta = 1.0;  // 对beta不敏感
};
 
double gift_loss(
    Policy& policy,
    Policy& ref_policy,
    RewardModel& reward_model,
    const Prompt& prompt,
    const GIFTConfig& config
) {
    // 1. 生成N个响应
    vector<Response> responses;
    for (int i = 0; i < config.group_size; i++) {
        responses.push_back(policy.sample(prompt));
    }
    
    // 2. 计算显式奖励
    vector<double> explicit_rewards;
    for (const auto& r : responses) {
        explicit_rewards.push_back(reward_model.evaluate(prompt, r));
    }
    
    // 3. 计算隐式奖励(DPO风格)
    vector<double> implicit_rewards;
    for (const auto& r : responses) {
        double log_ratio = r.log_prob - ref_policy.log_prob(r);
        implicit_rewards.push_back(log_ratio);
    }
    
    // 4. 组归一化(消除β和Z(x))
    auto normalize = [](vector<double>& x) {
        double mean = accumulate(x.begin(), x.end(), 0.0) / x.size();
        double std = 0.0;
        for (double v : x) std += (v - mean) * (v - mean);
        std = sqrt(std / x.size()) + 1e-8;
        for (double& v : x) v = (v - mean) / std;
    };
    
    normalize(explicit_rewards);
    normalize(implicit_rewards);
    
    // 5. MSE损失(奖励匹配)
    double mse_loss = 0.0;
    for (int i = 0; i < config.group_size; i++) {
        double diff = explicit_rewards[i] - implicit_rewards[i];
        mse_loss += diff * diff;
    }
    
    return mse_loss / config.group_size;
}

8. 关键洞察

8.1 散度框架的统一性

  1. 偏好对齐:FDO目标直接估计chosen/rejected分布间的散度
  2. RLVR:通过奖励诱导对齐/不对齐分布,f-GRPO估计其散度
  3. 混合对齐-HAL同时利用偏好和奖励信号

8.2 从奖励最大化到奖励匹配

  • GRPO/DPO:直接最大化奖励
  • GIFT:匹配隐式和显式奖励函数
  • 优势:更稳定的优化,更少的超参数调优

8.3 组归一化的威力

组归一化可以消除:

  1. DPO中不可计算的partition函数
  2. β系数的影响
  3. 序列长度偏差

9. 参考