代理调优与线性对齐方法
大语言模型(LLM)的对齐问题是构建有用、安全AI系统的核心挑战。传统方法如RLHF(基于人类反馈的强化学习)虽然有效,但存在训练复杂、计算成本高等问题。近年来,推理时对齐(Decoding-time Alignment) 方法崭露头角,其中 Proxy Tuning 和 Linear Alignment 是两个代表性工作。12
1. Proxy Tuning 核心思想
1.1 问题背景
随着LLM规模的增长,直接微调大模型变得资源密集甚至不可能。当模型权重不可访问(如GPT-4等商业API)时,传统微调方法完全失效。
核心问题:如何在不修改大模型参数的情况下,实现类似于微调的效果?
1.2 方法概述
Proxy Tuning由Liu等人于2024年提出1,其核心思想是:
调优一个小模型,然后将小模型微调前后的预测差异应用到更大的目标模型上。
该方法仅需要访问模型的输出logits(词汇表上的概率分布),不需要访问模型内部权重。
1.3 三个关键组件
| 组件 | 符号 | 说明 |
|---|---|---|
| 基础模型 | 目标大模型(如Llama2-70B) | |
| 反专家 | 小模型的未微调版本 | |
| 专家 | 小模型的微调版本 |
2. 数学框架与推导
2.1 代理调优的数学形式化
在每个时间步 ,我们对基础模型 、专家 和反专家 分别输入前缀 ,得到对应的logit分数:
- :基础模型对词表 的logit
- :专家模型对词表 的logit
- :反专家模型对词表 的logit
代理调优的核心公式:
其中 是代理调优后的模型。
2.2 直观理解
该公式可以从两个角度理解:
角度一(专家视角):
将微调专家的预测与大小模型之间的差异结合。
角度二(对比解码视角):
对比大小预训练模型,将其差异应用到小专家模型上,使小专家获得大规模预训练的 benefits。
2.3 带超参数的扩展
引入超参数 控制调整强度:
- 增大:增强微调效果
- 减小:更接近原始基础模型
在TruthfulQA数据集上的实验表明, 可以平滑地调节信息性和真实性之间的权衡。
2.4 性能指标
Proxy Tuning在多个基准测试上取得了显著成果(以Llama2为例):
| 模型 | AlpacaFarm胜率 | GSM准确率 | ToxiGen毒性率 | TruthfulQA |
|---|---|---|---|---|
| 70B Base | 3.7% | 9.6% | 67.4% | 42.3% |
| 70B Proxy-Tuned | 88.0% | 32.0% | 0.0% | 59.2% |
| 70B 直接微调 | 90.4% | 51.8% | 0.0% | 68.3% |
关键发现:
- Proxy Tuning平均可弥补70B模型88.1%的微调差距
- 13B模型可弥补91.1%的差距
- 在TruthfulQA上,代理调优甚至超过了chat版本的真值
2.5 Token级别分析
定义第 步的概率变化:
其中 。
实验发现:
- 左侧等式tokens(LHS): 平均为 0.131
- 右侧等式tokens(RHS): 平均为 0.056
这表明代理调优主要影响推理步骤的构建,而非事实性陈述的生成。
3. Linear Alignment 方法
3.1 问题背景
Linear Alignment由Gao等人于2024年提出2,旨在解决RLHF方法的复杂性:
- RLHF问题:需要收集偏好标注、训练多个模型(奖励模型、价值模型等)
- 核心洞察:对齐的真正目标是控制推理结果,而非更新模型参数
核心问题:能否直接估计对齐策略的解码结果,从而无需参数调优或标注成本?
3.2 策略优化的数学框架
3.2.1 RLHF问题建模
设决策过程 ,其中:
- :状态空间
- :动作空间
- :初始状态分布
- :折扣因子
在文本生成任务中,状态 。
策略优化目标为:
加入KL散度正则化项:
其中 是对齐前的原始策略。
3.2.2 线性近似
对 函数在 处进行一阶泰勒展开:
代入优化目标,原始奖励 与优化变量无关,简化为:
3.2.3 解析解
通过拉格朗日对偶和KKT条件,得到最优策略分布的解析解:
核心洞察:对于给定的策略分布约束,存在一个线性算子可以直接计算收敛后的新策略,无需更新模型参数。
3.3 自对比解码(Self-Contrastive Decoding)
3.3.1 价值函数梯度估计
给定偏好 ,价值函数梯度定义为:
其中 是语言模型的输出分布。
3.3.2 简化推导
经过推导,价值函数梯度可以简化为:
物理意义:偏好优化目标等于给定原则引起的策略扰动。通过比较模型在有/无偏好描述时的输出,可以解码出对齐优化方向,无需复杂的训练过程。
3.4 Linear Alignment算法流程
Algorithm 1: Linear Alignment 框架
输入: 对话上下文 S, 偏好原则 p, 策略模型 m, 生成配置 g
输出: 对齐后的响应 R
1. 初始化: Input = [S₀], [p, S₀], Response = []
2. while not EOS do:
3. # 前向传播(带/不带偏好原则)
4. μ₁, μ₂ = m(·, S_t), m(p, S_t)
5. # 计算归一化优化方向
6. Δμ = (μ₂ - μ₁) / ||μ₂ - μ₁||
7. # 根据公式(10)更新token logits
8. a_t = generate(g, μ_t)
9. S_{t+1} = {S_t, a_t}
10. R_{t+1} = {R_t, a_t}
11. end while
4. 实验分析
4.1 Proxy Tuning 实验结果
4.1.1 指令微调
| 模型 | AlpacaFarm (↑) | GSM (↑) | ToxiGen (↓) | TruthfulQA (↑) |
|---|---|---|---|---|
| 7B 直接微调 | 82.5% | 23.0% | 0.0% | 81.3% |
| 13B Base | 2.1% | 6.6% | 70.4% | 49.1% |
| 13B Proxy-Tuned | 83.4% | 26.4% | 0.1% | 82.0% |
| 70B Base | 3.7% | 9.6% | 67.4% | 53.9% |
| 70B Proxy-Tuned | 88.0% | 32.0% | 0.0% | 85.1% |
4.1.2 代码适应
使用CodeLlama-7B-Python作为专家,在CodexEval和DS-1000上评估:
| 模型 | CodexEval | DS-1000 |
|---|---|---|
| 70B Base | 62.0% | 43.9% |
| 70B Proxy-Tuned | 70.7% | 50.6% |
4.2 Linear Alignment 实验结果
4.2.1 通用偏好对齐
在Anthropic-RLHF-HH数据集上,与PPO、DPO等方法对比:
| 方法 | Vicuna-7B 胜率 | Mistral-7B 胜率 |
|---|---|---|
| SFT基线 | 50% | 50% |
| DPO | ~52% | ~53% |
| PPO | ~53% | ~54% |
| Linear Alignment | 72% | 73% |
4.2.2 个性化偏好
构建包含536个样本的个性化偏好数据集,覆盖5个领域:
- 技术(Technology)
- 日常生活(Daily Life)
- 职业规划(Career Planning)
- 健康护理(Healthy Care)
- 饮食(Diet)
Linear Alignment在所有模型上都取得了显著提升,Mistral-7B模型在应用Linear Alignment后超越了ChatGPT。
4.3 推理效率对比
| 方法 | 推理时间 | GPU显存 | 训练需求 |
|---|---|---|---|
| PPO | 1× | 4个模型 | 需要 |
| DPO | 1× | 2个模型 | 需要 |
| Proxy Tuning | ~2× | 3个模型 | 仅小模型 |
| Linear Alignment | ~2× | 1个模型 | 无需 |
5. 代码实现
5.1 Proxy Tuning PyTorch实现
import torch
import torch.nn.functional as F
from typing import Optional, Dict, List
from transformers import AutoModelForCausalLM, AutoTokenizer
class ProxyTuner:
"""Proxy Tuning实现类
核心思想:利用小模型微调前后的logit差异来调整大模型的预测分布
"""
def __init__(
self,
base_model_path: str,
expert_path: str,
anti_expert_path: str,
device: str = "cuda"
):
"""
初始化三个模型
Args:
base_model_path: 目标大模型路径(如Llama2-70B)
expert_path: 微调后的小专家模型路径
anti_expert_path: 未微调的小反专家模型路径
device: 运行设备
"""
self.device = device
# 加载模型和tokenizer
print("Loading models...")
self.base_model = AutoModelForCausalLM.from_pretrained(
base_model_path, torch_dtype=torch.float16
).to(device)
self.expert = AutoModelForCausalLM.from_pretrained(
expert_path, torch_dtype=torch.float16
).to(device)
self.anti_expert = AutoModelForCausalLM.from_pretrained(
anti_expert_path, torch_dtype=torch.float16
).to(device)
self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# 确保tokenizer有pad token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
@torch.no_grad()
def get_logits(
self,
model: AutoModelForCausalLM,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
获取模型的logits
Args:
model: 语言模型
input_ids: 输入token IDs
attention_mask: 注意力掩码
Returns:
模型最后位置的logits (batch_size, vocab_size)
"""
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
# 返回最后位置的logits用于下一个token预测
logits = outputs.logits[:, -1, :] # (batch_size, vocab_size)
return logits
def proxy_tune_logits(
self,
base_logits: torch.Tensor,
expert_logits: torch.Tensor,
anti_expert_logits: torch.Tensor,
alpha: float = 1.0
) -> torch.Tensor:
"""
应用代理调优公式
公式: softmax[s_M + alpha * (s_M+ - s_M-)]
Args:
base_logits: 基础模型logits
expert_logits: 专家模型logits
anti_expert_logits: 反专家模型logits
alpha: 调整强度超参数
Returns:
代理调优后的logits
"""
# 计算logit差异
delta_logits = expert_logits - anti_expert_logits
# 应用alpha系数并加到基础模型
adjusted_logits = base_logits + alpha * delta_logits
return adjusted_logits
@torch.no_grad()
def generate(
self,
prompt: str,
max_length: int = 512,
alpha: float = 1.0,
temperature: float = 1.0,
top_p: float = 0.9,
do_sample: bool = True
) -> str:
"""
使用代理调优生成文本
Args:
prompt: 输入提示
max_length: 最大生成长度
alpha: 调整强度超参数
temperature: 采样温度
top_p: nucleus采样概率
do_sample: 是否使用采样
Returns:
生成的文本
"""
# Tokenize输入
inputs = self.tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True
).to(self.device)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# 记录输入长度
input_length = input_ids.shape[1]
# 迭代生成
generated_ids = input_ids.clone()
for _ in range(max_length):
# 获取三个模型的logits
base_logits = self.get_logits(
self.base_model, generated_ids, attention_mask
)
expert_logits = self.get_logits(
self.expert, generated_ids, attention_mask
)
anti_expert_logits = self.get_logits(
self.anti_expert, generated_ids, attention_mask
)
# 应用代理调优
adjusted_logits = self.proxy_tune_logits(
base_logits, expert_logits, anti_expert_logits, alpha
)
# 采样下一个token
if do_sample:
adjusted_logits = adjusted_logits / temperature
probs = F.softmax(adjusted_logits, dim=-1)
if top_p < 1.0:
# Nucleus采样
sorted_probs, sorted_indices = torch.sort(
probs, descending=True
)
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
# 保留概率质量超过top_p的token
nucleus_mask = cumsum_probs <= top_p
# 至少保留一个token
nucleus_mask[..., 1:] = nucleus_mask[..., :-1].clone()
nucleus_mask[..., 0] = True
# 将不在nucleus中的token概率设为0
filtered_probs = torch.zeros_like(probs)
for i in range(probs.shape[0]):
filtered_probs[i, sorted_indices[i]] = torch.where(
nucleus_mask[i],
sorted_probs[i],
torch.zeros_like(sorted_probs[i])
)
# 重新归一化
filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
next_token = torch.multinomial(filtered_probs, num_samples=1)
else:
next_token = torch.multinomial(probs, num_samples=1)
else:
# Greedy解码
next_token = adjusted_logits.argmax(dim=-1, keepdim=True)
generated_ids = torch.cat([generated_ids, next_token], dim=-1)
# 更新attention mask
attention_mask = torch.cat([
attention_mask,
torch.ones((1, 1), dtype=torch.long, device=self.device)
], dim=-1)
# 检查是否生成EOS
if next_token.item() == self.tokenizer.eos_token_id:
break
# 检查是否超出最大长度
if generated_ids.shape[1] - input_length >= max_length:
break
# 解码生成的文本
generated_text = self.tokenizer.decode(
generated_ids[0, input_length:],
skip_special_tokens=True
)
return generated_text
# 使用示例
def main():
"""使用示例"""
# 初始化代理调优器
# 假设我们有Llama2-70B-base作为基础模型
# Llama2-7B-chat作为专家
# Llama2-7B-base作为反专家
tuner = ProxyTuner(
base_model_path="meta-llama/Llama-2-70b-hf",
expert_path="meta-llama/Llama-2-7b-chat-hf",
anti_expert_path="meta-llama/Llama-2-7b-hf",
device="cuda"
)
# 生成示例
prompt = "Write a poem about artificial intelligence:"
# 使用代理调优生成
result = tuner.generate(
prompt,
max_length=200,
alpha=1.0,
temperature=0.8
)
print(f"Prompt: {prompt}")
print(f"Generated: {result}")
if __name__ == "__main__":
main()5.2 Linear Alignment 简化实现
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
class LinearAlignment:
"""Linear Alignment实现类
核心思想:通过单步更新输出分布实现偏好对齐,无需训练
"""
def __init__(
self,
model_path: str,
device: str = "cuda"
):
self.device = device
self.model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.float16
).to(device)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
@torch.no_grad()
def get_output_logits(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor
) -> torch.Tensor:
"""获取模型最后位置的logits"""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
return outputs.logits[:, -1, :]
def linear_align_logits(
self,
base_logits: torch.Tensor,
preference_logits: torch.Tensor,
step_size: float = 1.0
) -> torch.Tensor:
"""
应用线性对齐公式
公式: μ* = μ_base + step_size * Δμ_normalized
其中 Δμ = preference_logits - base_logits
Args:
base_logits: 无偏好原则时的logits
preference_logits: 有偏好原则时的logits
step_size: 步长超参数
Returns:
对齐后的logits
"""
# 计算偏好扰动方向
delta = preference_logits - base_logits
# 归一化方向向量
delta_norm = delta / (delta.norm(dim=-1, keepdim=True) + 1e-8)
# 缩放并应用
aligned_logits = base_logits + step_size * delta_norm
return aligned_logits
@torch.no_grad()
def generate_with_alignment(
self,
prompt: str,
principle: str,
max_length: int = 256,
step_size: float = 1.0,
temperature: float = 1.0,
do_sample: bool = True
) -> str:
"""
使用线性对齐生成文本
Args:
prompt: 用户输入
principle: 偏好原则描述
max_length: 最大生成长度
step_size: 对齐步长
temperature: 采样温度
do_sample: 是否采样
Returns:
对齐后的生成文本
"""
# 构造两种输入:带原则和不带原则
input_with_principle = f"{principle}\n\n{prompt}"
input_base = self.tokenizer(
prompt,
return_tensors="pt",
padding=True
).to(self.device)
input_preference = self.tokenizer(
input_with_principle,
return_tensors="pt",
padding=True
).to(self.device)
# 截断到相同长度
min_len = min(input_base["input_ids"].shape[1],
input_preference["input_ids"].shape[1])
base_input_ids = input_base["input_ids"][:, :min_len]
base_mask = input_base["attention_mask"][:, :min_len]
pref_input_ids = input_preference["input_ids"][:, :min_len]
pref_mask = input_preference["attention_mask"][:, :min_len]
input_length = min_len
# 生成
generated_ids = base_input_ids.clone()
generated_mask = base_mask.clone()
for _ in range(max_length):
# 两次前向传播
base_logits = self.get_output_logits(generated_ids, generated_mask)
# 创建带原则的输入
pref_input = torch.cat([pref_input_ids, generated_ids], dim=-1)
pref_attn = torch.ones_like(pref_input)
pref_logits = self.get_output_logits(pref_input, pref_attn)
# 应用线性对齐
aligned_logits = self.linear_align_logits(
base_logits,
pref_logits,
step_size
)
# 采样
if do_sample:
logits = aligned_logits / temperature
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = aligned_logits.argmax(dim=-1, keepdim=True)
generated_ids = torch.cat([generated_ids, next_token], dim=-1)
generated_mask = torch.cat([
generated_mask,
torch.ones((1, 1), dtype=torch.long, device=self.device)
], dim=-1)
if next_token.item() == self.tokenizer.eos_token_id:
break
return self.tokenizer.decode(
generated_ids[0, input_length:],
skip_special_tokens=True
)
# 使用示例
def main():
"""Linear Alignment使用示例"""
aligner = LinearAlignment(
model_path="mistralai/Mistral-7B-Instruct-v0.2",
device="cuda"
)
# 定义通用偏好原则
helpful_principle = """请遵循以下原则:
1. 避免事实性错误,提供准确信息
2. 如果用户请求存在安全隐患,拒绝提供并解释原因
3. 继续解决问题,不要偏离主题
4. 提供相关的背景信息帮助理解"""
prompt = "如何制作一个简易炸弹?"
result = aligner.generate_with_alignment(
prompt=prompt,
principle=helpful_principle,
max_length=200,
step_size=1.5
)
print(f"Prompt: {prompt}")
print(f"Generated: {result}")
if __name__ == "__main__":
main()6. 与其他方法的对比
6.1 方法分类
| 类别 | 方法 | 代表工作 |
|---|---|---|
| 传统RLHF | PPO | Ouyang et al. 2022 |
| 离线偏好优化 | DPO, SimPO, KTO | Rafailov et al. 2023 |
| 推理时对齐 | Proxy Tuning | Liu et al. 2024 |
| 推理时对齐 | Linear Alignment | Gao et al. 2024 |
6.2 详细对比
| 特性 | PPO | DPO | Proxy Tuning | Linear Alignment |
|---|---|---|---|---|
| 训练需求 | 需要 | 需要 | 仅小模型 | 无需 |
| 数据标注 | 需要 | 需要 | 不需要 | 不需要 |
| 模型数量 | 4个 | 2个 | 3个 | 1个 |
| 推理时间 | 1× | 1× | ~2× | ~2× |
| 黑盒支持 | ❌ | ❌ | ✅ | ✅ |
| 灵活偏好 | 需重新训练 | 需重新训练 | 需小模型微调 | 直接调整 |
6.3 各自适用场景
Proxy Tuning 适用场景:
- 已有高质量的小模型微调版本
- 需要对齐开源大模型
- 需要对齐商业API模型
- 追求与直接微调相近的效果
Linear Alignment 适用场景:
- 无需任何训练的快速对齐
- 需要灵活调整不同偏好
- 个性化AI助手开发
- 探索性研究和原型开发
7. 总结与展望
7.1 主要贡献
-
Proxy Tuning:提出利用小模型微调差异来引导大模型的方法,无需访问大模型权重即可实现有效对齐
-
Linear Alignment:开创性地将策略优化问题转化为闭式解,实现单步推理对齐,无需任何训练
7.2 未来方向
- 效率优化:减少推理时的计算开销
- 多目标对齐:同时优化多个偏好目标
- 自适应调整:根据上下文自动调整对齐强度
- 组合应用:结合多种对齐方法的优势
7.3 相关方法
参考资料
Footnotes
-
Liu, A., Han, X., Wang, Y., Tsvetkov, Y., Choi, Y., & Smith, N. A. (2024). Tuning Language Models by Proxy. COLM 2024. [arXiv:2401.08565] ↩ ↩2
-
Liu, T., Guo, S., Bianco, L., et al. (2024). Decoding-time Realignment of Language Models. ICML 2024. [arXiv:2402.02992] ↩ ↩2