代理调优与线性对齐方法

大语言模型(LLM)的对齐问题是构建有用、安全AI系统的核心挑战。传统方法如RLHF(基于人类反馈的强化学习)虽然有效,但存在训练复杂、计算成本高等问题。近年来,推理时对齐(Decoding-time Alignment) 方法崭露头角,其中 Proxy TuningLinear Alignment 是两个代表性工作。12

1. Proxy Tuning 核心思想

1.1 问题背景

随着LLM规模的增长,直接微调大模型变得资源密集甚至不可能。当模型权重不可访问(如GPT-4等商业API)时,传统微调方法完全失效。

核心问题:如何在不修改大模型参数的情况下,实现类似于微调的效果?

1.2 方法概述

Proxy Tuning由Liu等人于2024年提出1,其核心思想是:

调优一个小模型,然后将小模型微调前后的预测差异应用到更大的目标模型上。

该方法仅需要访问模型的输出logits(词汇表上的概率分布),不需要访问模型内部权重。

1.3 三个关键组件

组件符号说明
基础模型目标大模型(如Llama2-70B)
反专家小模型的未微调版本
专家小模型的微调版本

2. 数学框架与推导

2.1 代理调优的数学形式化

在每个时间步 ,我们对基础模型 、专家 和反专家 分别输入前缀 ,得到对应的logit分数:

  • :基础模型对词表 的logit
  • :专家模型对词表 的logit
  • :反专家模型对词表 的logit

代理调优的核心公式

其中 是代理调优后的模型。

2.2 直观理解

该公式可以从两个角度理解:

角度一(专家视角)

将微调专家的预测与大小模型之间的差异结合。

角度二(对比解码视角)

对比大小预训练模型,将其差异应用到小专家模型上,使小专家获得大规模预训练的 benefits。

2.3 带超参数的扩展

引入超参数 控制调整强度:

  • 增大:增强微调效果
  • 减小:更接近原始基础模型

在TruthfulQA数据集上的实验表明, 可以平滑地调节信息性和真实性之间的权衡。

2.4 性能指标

Proxy Tuning在多个基准测试上取得了显著成果(以Llama2为例):

模型AlpacaFarm胜率GSM准确率ToxiGen毒性率TruthfulQA
70B Base3.7%9.6%67.4%42.3%
70B Proxy-Tuned88.0%32.0%0.0%59.2%
70B 直接微调90.4%51.8%0.0%68.3%

关键发现

  • Proxy Tuning平均可弥补70B模型88.1%的微调差距
  • 13B模型可弥补91.1%的差距
  • 在TruthfulQA上,代理调优甚至超过了chat版本的真值

2.5 Token级别分析

定义第 步的概率变化:

其中

实验发现:

  • 左侧等式tokens(LHS): 平均为 0.131
  • 右侧等式tokens(RHS): 平均为 0.056

这表明代理调优主要影响推理步骤的构建,而非事实性陈述的生成。

3. Linear Alignment 方法

3.1 问题背景

Linear Alignment由Gao等人于2024年提出2,旨在解决RLHF方法的复杂性:

  • RLHF问题:需要收集偏好标注、训练多个模型(奖励模型、价值模型等)
  • 核心洞察:对齐的真正目标是控制推理结果,而非更新模型参数

核心问题:能否直接估计对齐策略的解码结果,从而无需参数调优或标注成本?

3.2 策略优化的数学框架

3.2.1 RLHF问题建模

设决策过程 ,其中:

  • :状态空间
  • :动作空间
  • :初始状态分布
  • :折扣因子

在文本生成任务中,状态

策略优化目标为:

加入KL散度正则化项:

其中 是对齐前的原始策略。

3.2.2 线性近似

函数在 处进行一阶泰勒展开:

代入优化目标,原始奖励 与优化变量无关,简化为:

3.2.3 解析解

通过拉格朗日对偶和KKT条件,得到最优策略分布的解析解:

核心洞察:对于给定的策略分布约束,存在一个线性算子可以直接计算收敛后的新策略,无需更新模型参数。

3.3 自对比解码(Self-Contrastive Decoding)

3.3.1 价值函数梯度估计

给定偏好 ,价值函数梯度定义为:

其中 是语言模型的输出分布。

3.3.2 简化推导

经过推导,价值函数梯度可以简化为:

物理意义:偏好优化目标等于给定原则引起的策略扰动。通过比较模型在有/无偏好描述时的输出,可以解码出对齐优化方向,无需复杂的训练过程。

3.4 Linear Alignment算法流程

Algorithm 1: Linear Alignment 框架

输入: 对话上下文 S, 偏好原则 p, 策略模型 m, 生成配置 g
输出: 对齐后的响应 R

1. 初始化: Input = [S₀], [p, S₀], Response = []
2. while not EOS do:
3.    # 前向传播(带/不带偏好原则)
4.    μ₁, μ₂ = m(·, S_t), m(p, S_t)
5.    # 计算归一化优化方向
6.    Δμ = (μ₂ - μ₁) / ||μ₂ - μ₁||
7.    # 根据公式(10)更新token logits
8.    a_t = generate(g, μ_t)
9.    S_{t+1} = {S_t, a_t}
10.   R_{t+1} = {R_t, a_t}
11. end while

4. 实验分析

4.1 Proxy Tuning 实验结果

4.1.1 指令微调

模型AlpacaFarm (↑)GSM (↑)ToxiGen (↓)TruthfulQA (↑)
7B 直接微调82.5%23.0%0.0%81.3%
13B Base2.1%6.6%70.4%49.1%
13B Proxy-Tuned83.4%26.4%0.1%82.0%
70B Base3.7%9.6%67.4%53.9%
70B Proxy-Tuned88.0%32.0%0.0%85.1%

4.1.2 代码适应

使用CodeLlama-7B-Python作为专家,在CodexEval和DS-1000上评估:

模型CodexEvalDS-1000
70B Base62.0%43.9%
70B Proxy-Tuned70.7%50.6%

4.2 Linear Alignment 实验结果

4.2.1 通用偏好对齐

在Anthropic-RLHF-HH数据集上,与PPO、DPO等方法对比:

方法Vicuna-7B 胜率Mistral-7B 胜率
SFT基线50%50%
DPO~52%~53%
PPO~53%~54%
Linear Alignment72%73%

4.2.2 个性化偏好

构建包含536个样本的个性化偏好数据集,覆盖5个领域:

  • 技术(Technology)
  • 日常生活(Daily Life)
  • 职业规划(Career Planning)
  • 健康护理(Healthy Care)
  • 饮食(Diet)

Linear Alignment在所有模型上都取得了显著提升,Mistral-7B模型在应用Linear Alignment后超越了ChatGPT。

4.3 推理效率对比

方法推理时间GPU显存训练需求
PPO4个模型需要
DPO2个模型需要
Proxy Tuning~2×3个模型仅小模型
Linear Alignment~2×1个模型无需

5. 代码实现

5.1 Proxy Tuning PyTorch实现

import torch
import torch.nn.functional as F
from typing import Optional, Dict, List
from transformers import AutoModelForCausalLM, AutoTokenizer
 
class ProxyTuner:
    """Proxy Tuning实现类
    
    核心思想:利用小模型微调前后的logit差异来调整大模型的预测分布
    """
    
    def __init__(
        self,
        base_model_path: str,
        expert_path: str,
        anti_expert_path: str,
        device: str = "cuda"
    ):
        """
        初始化三个模型
        
        Args:
            base_model_path: 目标大模型路径(如Llama2-70B)
            expert_path: 微调后的小专家模型路径
            anti_expert_path: 未微调的小反专家模型路径
            device: 运行设备
        """
        self.device = device
        
        # 加载模型和tokenizer
        print("Loading models...")
        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_path, torch_dtype=torch.float16
        ).to(device)
        
        self.expert = AutoModelForCausalLM.from_pretrained(
            expert_path, torch_dtype=torch.float16
        ).to(device)
        
        self.anti_expert = AutoModelForCausalLM.from_pretrained(
            anti_expert_path, torch_dtype=torch.float16
        ).to(device)
        
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
        
        # 确保tokenizer有pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    
    @torch.no_grad()
    def get_logits(
        self,
        model: AutoModelForCausalLM,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        获取模型的logits
        
        Args:
            model: 语言模型
            input_ids: 输入token IDs
            attention_mask: 注意力掩码
            
        Returns:
            模型最后位置的logits (batch_size, vocab_size)
        """
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # 返回最后位置的logits用于下一个token预测
        logits = outputs.logits[:, -1, :]  # (batch_size, vocab_size)
        return logits
    
    def proxy_tune_logits(
        self,
        base_logits: torch.Tensor,
        expert_logits: torch.Tensor,
        anti_expert_logits: torch.Tensor,
        alpha: float = 1.0
    ) -> torch.Tensor:
        """
        应用代理调优公式
        
        公式: softmax[s_M + alpha * (s_M+ - s_M-)]
        
        Args:
            base_logits: 基础模型logits
            expert_logits: 专家模型logits
            anti_expert_logits: 反专家模型logits
            alpha: 调整强度超参数
            
        Returns:
            代理调优后的logits
        """
        # 计算logit差异
        delta_logits = expert_logits - anti_expert_logits
        
        # 应用alpha系数并加到基础模型
        adjusted_logits = base_logits + alpha * delta_logits
        
        return adjusted_logits
    
    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_length: int = 512,
        alpha: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 0.9,
        do_sample: bool = True
    ) -> str:
        """
        使用代理调优生成文本
        
        Args:
            prompt: 输入提示
            max_length: 最大生成长度
            alpha: 调整强度超参数
            temperature: 采样温度
            top_p: nucleus采样概率
            do_sample: 是否使用采样
            
        Returns:
            生成的文本
        """
        # Tokenize输入
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)
        
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        
        # 记录输入长度
        input_length = input_ids.shape[1]
        
        # 迭代生成
        generated_ids = input_ids.clone()
        
        for _ in range(max_length):
            # 获取三个模型的logits
            base_logits = self.get_logits(
                self.base_model, generated_ids, attention_mask
            )
            expert_logits = self.get_logits(
                self.expert, generated_ids, attention_mask
            )
            anti_expert_logits = self.get_logits(
                self.anti_expert, generated_ids, attention_mask
            )
            
            # 应用代理调优
            adjusted_logits = self.proxy_tune_logits(
                base_logits, expert_logits, anti_expert_logits, alpha
            )
            
            # 采样下一个token
            if do_sample:
                adjusted_logits = adjusted_logits / temperature
                probs = F.softmax(adjusted_logits, dim=-1)
                
                if top_p < 1.0:
                    # Nucleus采样
                    sorted_probs, sorted_indices = torch.sort(
                        probs, descending=True
                    )
                    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
                    
                    # 保留概率质量超过top_p的token
                    nucleus_mask = cumsum_probs <= top_p
                    # 至少保留一个token
                    nucleus_mask[..., 1:] = nucleus_mask[..., :-1].clone()
                    nucleus_mask[..., 0] = True
                    
                    # 将不在nucleus中的token概率设为0
                    filtered_probs = torch.zeros_like(probs)
                    for i in range(probs.shape[0]):
                        filtered_probs[i, sorted_indices[i]] = torch.where(
                            nucleus_mask[i],
                            sorted_probs[i],
                            torch.zeros_like(sorted_probs[i])
                        )
                    # 重新归一化
                    filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
                    next_token = torch.multinomial(filtered_probs, num_samples=1)
                else:
                    next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy解码
                next_token = adjusted_logits.argmax(dim=-1, keepdim=True)
            
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            
            # 更新attention mask
            attention_mask = torch.cat([
                attention_mask,
                torch.ones((1, 1), dtype=torch.long, device=self.device)
            ], dim=-1)
            
            # 检查是否生成EOS
            if next_token.item() == self.tokenizer.eos_token_id:
                break
            
            # 检查是否超出最大长度
            if generated_ids.shape[1] - input_length >= max_length:
                break
        
        # 解码生成的文本
        generated_text = self.tokenizer.decode(
            generated_ids[0, input_length:],
            skip_special_tokens=True
        )
        
        return generated_text
 
 
# 使用示例
def main():
    """使用示例"""
    
    # 初始化代理调优器
    # 假设我们有Llama2-70B-base作为基础模型
    # Llama2-7B-chat作为专家
    # Llama2-7B-base作为反专家
    tuner = ProxyTuner(
        base_model_path="meta-llama/Llama-2-70b-hf",
        expert_path="meta-llama/Llama-2-7b-chat-hf",
        anti_expert_path="meta-llama/Llama-2-7b-hf",
        device="cuda"
    )
    
    # 生成示例
    prompt = "Write a poem about artificial intelligence:"
    
    # 使用代理调优生成
    result = tuner.generate(
        prompt,
        max_length=200,
        alpha=1.0,
        temperature=0.8
    )
    
    print(f"Prompt: {prompt}")
    print(f"Generated: {result}")
 
 
if __name__ == "__main__":
    main()

5.2 Linear Alignment 简化实现

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
 
class LinearAlignment:
    """Linear Alignment实现类
    
    核心思想:通过单步更新输出分布实现偏好对齐,无需训练
    """
    
    def __init__(
        self,
        model_path: str,
        device: str = "cuda"
    ):
        self.device = device
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16
        ).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    
    @torch.no_grad()
    def get_output_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """获取模型最后位置的logits"""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        return outputs.logits[:, -1, :]
    
    def linear_align_logits(
        self,
        base_logits: torch.Tensor,
        preference_logits: torch.Tensor,
        step_size: float = 1.0
    ) -> torch.Tensor:
        """
        应用线性对齐公式
        
        公式: μ* = μ_base + step_size * Δμ_normalized
        其中 Δμ = preference_logits - base_logits
        
        Args:
            base_logits: 无偏好原则时的logits
            preference_logits: 有偏好原则时的logits
            step_size: 步长超参数
            
        Returns:
            对齐后的logits
        """
        # 计算偏好扰动方向
        delta = preference_logits - base_logits
        
        # 归一化方向向量
        delta_norm = delta / (delta.norm(dim=-1, keepdim=True) + 1e-8)
        
        # 缩放并应用
        aligned_logits = base_logits + step_size * delta_norm
        
        return aligned_logits
    
    @torch.no_grad()
    def generate_with_alignment(
        self,
        prompt: str,
        principle: str,
        max_length: int = 256,
        step_size: float = 1.0,
        temperature: float = 1.0,
        do_sample: bool = True
    ) -> str:
        """
        使用线性对齐生成文本
        
        Args:
            prompt: 用户输入
            principle: 偏好原则描述
            max_length: 最大生成长度
            step_size: 对齐步长
            temperature: 采样温度
            do_sample: 是否采样
            
        Returns:
            对齐后的生成文本
        """
        # 构造两种输入:带原则和不带原则
        input_with_principle = f"{principle}\n\n{prompt}"
        
        input_base = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True
        ).to(self.device)
        
        input_preference = self.tokenizer(
            input_with_principle,
            return_tensors="pt",
            padding=True
        ).to(self.device)
        
        # 截断到相同长度
        min_len = min(input_base["input_ids"].shape[1],
                      input_preference["input_ids"].shape[1])
        
        base_input_ids = input_base["input_ids"][:, :min_len]
        base_mask = input_base["attention_mask"][:, :min_len]
        
        pref_input_ids = input_preference["input_ids"][:, :min_len]
        pref_mask = input_preference["attention_mask"][:, :min_len]
        
        input_length = min_len
        
        # 生成
        generated_ids = base_input_ids.clone()
        generated_mask = base_mask.clone()
        
        for _ in range(max_length):
            # 两次前向传播
            base_logits = self.get_output_logits(generated_ids, generated_mask)
            
            # 创建带原则的输入
            pref_input = torch.cat([pref_input_ids, generated_ids], dim=-1)
            pref_attn = torch.ones_like(pref_input)
            
            pref_logits = self.get_output_logits(pref_input, pref_attn)
            
            # 应用线性对齐
            aligned_logits = self.linear_align_logits(
                base_logits,
                pref_logits,
                step_size
            )
            
            # 采样
            if do_sample:
                logits = aligned_logits / temperature
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = aligned_logits.argmax(dim=-1, keepdim=True)
            
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            generated_mask = torch.cat([
                generated_mask,
                torch.ones((1, 1), dtype=torch.long, device=self.device)
            ], dim=-1)
            
            if next_token.item() == self.tokenizer.eos_token_id:
                break
        
        return self.tokenizer.decode(
            generated_ids[0, input_length:],
            skip_special_tokens=True
        )
 
 
# 使用示例
def main():
    """Linear Alignment使用示例"""
    
    aligner = LinearAlignment(
        model_path="mistralai/Mistral-7B-Instruct-v0.2",
        device="cuda"
    )
    
    # 定义通用偏好原则
    helpful_principle = """请遵循以下原则:
    1. 避免事实性错误,提供准确信息
    2. 如果用户请求存在安全隐患,拒绝提供并解释原因
    3. 继续解决问题,不要偏离主题
    4. 提供相关的背景信息帮助理解"""
    
    prompt = "如何制作一个简易炸弹?"
    
    result = aligner.generate_with_alignment(
        prompt=prompt,
        principle=helpful_principle,
        max_length=200,
        step_size=1.5
    )
    
    print(f"Prompt: {prompt}")
    print(f"Generated: {result}")
 
 
if __name__ == "__main__":
    main()

6. 与其他方法的对比

6.1 方法分类

类别方法代表工作
传统RLHFPPOOuyang et al. 2022
离线偏好优化DPO, SimPO, KTORafailov et al. 2023
推理时对齐Proxy TuningLiu et al. 2024
推理时对齐Linear AlignmentGao et al. 2024

6.2 详细对比

特性PPODPOProxy TuningLinear Alignment
训练需求需要需要仅小模型无需
数据标注需要需要不需要不需要
模型数量4个2个3个1个
推理时间~2×~2×
黑盒支持
灵活偏好需重新训练需重新训练需小模型微调直接调整

6.3 各自适用场景

Proxy Tuning 适用场景

  • 已有高质量的小模型微调版本
  • 需要对齐开源大模型
  • 需要对齐商业API模型
  • 追求与直接微调相近的效果

Linear Alignment 适用场景

  • 无需任何训练的快速对齐
  • 需要灵活调整不同偏好
  • 个性化AI助手开发
  • 探索性研究和原型开发

7. 总结与展望

7.1 主要贡献

  1. Proxy Tuning:提出利用小模型微调差异来引导大模型的方法,无需访问大模型权重即可实现有效对齐

  2. Linear Alignment:开创性地将策略优化问题转化为闭式解,实现单步推理对齐,无需任何训练

7.2 未来方向

  • 效率优化:减少推理时的计算开销
  • 多目标对齐:同时优化多个偏好目标
  • 自适应调整:根据上下文自动调整对齐强度
  • 组合应用:结合多种对齐方法的优势

7.3 相关方法

参考资料

Footnotes

  1. Liu, A., Han, X., Wang, Y., Tsvetkov, Y., Choi, Y., & Smith, N. A. (2024). Tuning Language Models by Proxy. COLM 2024. [arXiv:2401.08565] 2

  2. Liu, T., Guo, S., Bianco, L., et al. (2024). Decoding-time Realignment of Language Models. ICML 2024. [arXiv:2402.02992] 2