概述
KTO(Kullback-Leibler divergence with Targeted Odds,也称 Kahneman-Tversky Optimization)是一种新型的大语言模型对齐方法,由 Ethayarajh 等人在 2024 年提出。1 与传统的基于偏好对比的 DPO 和 GRPO 不同,KTO 源自行为经济学的前景理论(Prospect Theory),旨在直接优化”好”与”坏”响应之间的相对偏好,而不依赖于显式的偏好对标注。
人类决策心理学:Kahneman-Tversky 展望理论
1979 年,Kahneman 和 Tversky 在经典论文《Prospect Theory: An Analysis of Decision under Risk》中提出了前景理论,彻底改变了我们对人类决策行为的理解。2 该理论的核心发现包括:
-
损失厌恶(Loss Aversion):人们对损失的敏感度高于等量收益——损失的心理影响大约是同等收益的 2-2.5 倍
-
敏感性递减(Diminishing Sensitivity):无论是收益还是损失,边际效用递减
-
反射效应(Reflection Effect):在收益区间风险厌恶,在损失区间风险寻求
这些发现直接挑战了传统期望效用理论中”人是非理性”的假设,证明人类决策遵循一套系统性的心理物理学规律。
为什么传统的 KL 散度目标可能不够
传统对齐方法如 RLHF 和 DPO 使用 KL 散度来约束模型不要偏离参考模型太远。然而,这种方法存在以下问题:
- 偏好建模不完整:KL 散度假设所有”非偏好”响应同等糟糕,忽略了响应之间的质量差异
- 损失厌恶缺失:没有建模人类对错误的非对称敏感性
- 训练不稳定:依赖Bradley-Terry模型假设,在数据不平衡时可能失效
KTO 通过引入前景理论的价值函数,直接在优化目标中编码了人类决策的心理特征。
理论框架
前景理论基础
Kahneman-Tversky 前景理论定义的价值函数 如下:
其中:
- 控制敏感性递减程度
- 是损失厌恶系数,通常
价值函数示意图:
V(x)
│
1 │ / (收益曲线,凹)
│ /
│ /
│ /
│ /
0 ├───────────────────────→ x
│ \
│ \
│ \ (损失曲线,凸且更陡)
│ \
-λ │ \
│
从前景理论到偏好优化的映射
KTO 将前景理论的价值函数与偏好优化联系起来。考虑一个偏好对 ,其中 是偏好的(winning)响应, 是不偏好的(losing)响应。定义:
这表示在给定输入 下,模型对偏好响应的”优势对数几率”(log-odds)。
当 时,模型更倾向于偏好响应;当 时,更倾向于非偏好响应。
KTO 损失函数推导
KTO 的损失函数直接源自前景理论的价值函数形式:
其中:
- 是 sigmoid 函数
- 是偏好比例(fraction of positives),控制正负样本的相对权重
- 是最优点(optimal point),控制目标的对数几率
这个损失函数的几何解释如下:
| 参数 | 含义 | 作用 |
|---|---|---|
| 偏好比例 | 控制正负样本的不对称权重, 表示更关注减少漏判 | |
| 最优点 | 目标的对数几率偏移, 表示希望模型保持一定的保守性 |
odds 比特的定义
在实现中, 通过以下方式计算:
即在 token 级别上平均的响应优势对数几率。这种”截断 odds”(Truncated Odds)设计使得:
- 计算更加稳定,避免极端值
- 与人类评估的相关性更好
- 对响应长度具有鲁棒性
算法细节
目标函数详解
KTO 的目标函数可以分解为两部分理解:
def kto_loss(logodds_pos, logodds_neg, alpha, beta):
"""
KTO 损失函数实现
参数:
logodds_pos: 正样本的对数几率
logodds_neg: 负样本的对数几率
alpha: 偏好比例 (fraction of positives)
beta: 最优点 (optimal point)
返回:
KTO 损失值
"""
import torch
import torch.nn.functional as F
delta_xy = logodds_pos - logodds_neg
# 正样本损失:希望 delta_xy 超过 beta
loss_pos = (1 - alpha) * torch.log(1 - torch.sigmoid(beta - delta_xy) + 1e-8)
# 负样本损失:希望 delta_xy 低于 beta
loss_neg = alpha * torch.log(torch.sigmoid(delta_xy - beta) + 1e-8)
return -(loss_pos + loss_neg).mean()超参数的作用
(偏好比例)
- 范围:
- 默认值:(平衡正负样本)
- 增大 :更关注减少假阳性(FP),模型更保守
- 减小 :更关注减少假阴性(FN),模型更激进
α 对损失形状的影响:
Loss
│
│ α=0.2 α=0.5 α=0.8
│ ╲ │ ╱
│ ╲ │ ╱
│ ╲ │ ╱
│ ╲ │ ╱
│ ╲ │ ╱
│ ╲ │ ╱
│ ╲ │ ╱
│ ╲ │ ╱
0 └───────────────δ─────→ ───────
β β β
(最优点)
- 范围:(通常在 调整)
- 默认值:
- 增大 :要求更高的优势对数几率,模型更”挑剔”
- 减小 :允许较低的差异容忍度,模型更”宽容”
与 DPO、PPO 的对比
| 特性 | KTO | DPO | GRPO |
|---|---|---|---|
| 数据需求 | 单样本标签(即可/不可) | 成对偏好(偏好/不偏好) | 成对或单样本 |
| 理论基础 | 前景理论 | Bradley-Terry 模型 | 策略梯度 |
| 损失厌恶建模 | ✅ 原生支持 | ❌ 需要额外设计 | ❌ 需要额外设计 |
| 计算效率 | 高(无需训练 reward model) | 高(端到端) | 中等 |
| 超参数敏感性 | 中等 | 低 | 高 |
| 训练稳定性 | 较好 | 良好 | 受策略方差影响 |
实现代码
PyTorch 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from dataclasses import dataclass
@dataclass
class KTOConfig:
"""KTO 配置参数"""
beta: float = 0.0 # 最优点 (optimal point)
alpha: float = 0.5 # 偏好比例 (fraction of positives)
kl_coef: float = 0.1 # KL 散度系数(可选,用于正则化)
class KTOLoss(nn.Module):
"""
KTO (Kahneman-Tversky Optimization) 损失函数
源自前景理论的价值函数,直接优化偏好的不对称性
"""
def __init__(self, config: KTOConfig):
super().__init__()
self.beta = config.beta
self.alpha = config.alpha
self.kl_coef = config.kl_coef
def forward(
self,
policy_logps: torch.Tensor,
reference_logps: torch.Tensor,
label: torch.Tensor,
) -> Tuple[torch.Tensor, dict]:
"""
计算 KTO 损失
参数:
policy_logps: 策略模型的 token 级对数概率 [batch_size, seq_len]
reference_logps: 参考模型的 token 级对数概率 [batch_size, seq_len]
label: 样本标签,1 表示正样本(偏好),0 表示负样本 [batch_size]
返回:
loss: 标量损失值
metrics: 诊断指标字典
"""
batch_size = policy_logps.size(0)
# 计算序列级对数几率
# 对所有 token 求和,然后除以序列长度
seq_len = policy_logps.size(1)
policy_mean_logps = policy_logps.sum(dim=1) / seq_len
reference_mean_logps = reference_logps.sum(dim=1) / seq_len
# 计算截断 odds: δ_xy = π(y_w|x) - π(y_l|x) 在 token 级别平均
# 对于正样本: δ_xy = policy - reference
# 对于负样本: δ_xy = reference - policy
sign = 2 * label - 1 # +1 for positive, -1 for negative
delta_xy = sign * (policy_mean_logps - reference_mean_logps)
# KTO 损失函数
# L = (1-α) * log(σ(β - δ)) + α * log(σ(δ - β))
# 第一项:正样本损失(当 δ_xy > β 时较小)
# 第二项:负样本损失(当 δ_xy < β 时较小)
# 数值稳定性处理
eps = 1e-8
loss_pos = (1 - self.alpha) * torch.log(
1 - torch.sigmoid(beta - delta_xy) + eps
)
loss_neg = self.alpha * torch.log(
torch.sigmoid(delta_xy - self.beta) + eps
)
# 合并正负样本损失
kto_loss = -(loss_pos + loss_neg).mean()
# 可选:添加 KL 正则化项,防止策略偏离参考模型太远
if self.kl_coef > 0:
kl_div = (policy_mean_logps - reference_mean_logps).mean()
total_loss = kto_loss + self.kl_coef * kl_div
else:
total_loss = kto_loss
kl_div = torch.tensor(0.0)
# 诊断指标
metrics = {
'kto_loss': kto_loss.item(),
'kl_div': kl_div.item() if isinstance(kl_div, torch.Tensor) else kl_div,
'mean_delta': delta_xy.mean().item(),
'delta_std': delta_xy.std().item(),
}
return total_loss, metrics
def compute_logprobs(
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
计算模型对输入的对数概率
参数:
model: 语言模型
input_ids: 输入 token IDs [batch_size, seq_len]
attention_mask: 注意力掩码 [batch_size, seq_len]
返回:
对数概率 [batch_size, seq_len]
"""
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits # [batch_size, seq_len, vocab_size]
# 转换为对数概率(使用 log_softmax)
log_probs = F.log_softmax(logits, dim=-1)
# 计算每个 token 的对数概率
# 预测 token t 时,使用 logits[:, t-1, :] 预测 token t
# 所以取 [:, :-1, :] 的对数概率,对应 [1:] 的 token
shift_log_probs = log_probs[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
# Gather 对应的对数概率
batch_size, seq_len = shift_labels.shape
flat_log_probs = shift_log_probs.view(-1, shift_log_probs.size(-1))
flat_labels = shift_labels.view(-1)
token_logps = flat_log_probs.gather(dim=-1, index=flat_labels.unsqueeze(-1)).squeeze(-1)
token_logps = token_logps.view(batch_size, seq_len)
return token_logps关键参数设置
# 推荐的超参数配置
# 配置 1:平衡模式(默认推荐)
config_balanced = KTOConfig(
beta=0.0, # 中性设置
alpha=0.5, # 平衡正负样本
)
# 配置 2:保守模式(更关注避免负面响应)
config_conservative = KTOConfig(
beta=0.5, # 要求更高的优势
alpha=0.6, # 更关注减少假阳性
)
# 配置 3:激进模式(更关注不遗漏正面响应)
config_aggressive = KTOConfig(
beta=-0.5, # 允许较低的差异
alpha=0.4, # 更关注减少假阴性
)训练稳定性技巧
class StableKTOTrainer:
"""
稳定的 KTO 训练器
包含以下稳定性优化:
1. 梯度裁剪
2. 学习率预热
3. 梯度累积
4. 早停机制
"""
def __init__(
self,
model,
reference_model,
config: KTOConfig,
max_grad_norm: float = 1.0,
warmup_steps: int = 100,
**kwargs
):
self.model = model
self.reference_model = reference_model
self.config = config
self.max_grad_norm = max_grad_norm
self.warmup_steps = warmup_steps
self.global_step = 0
self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)
self.scheduler = self._create_scheduler()
self.kto_criterion = KTOLoss(config)
def _create_scheduler(self):
"""创建学习率调度器:线性预热 + 余弦衰减"""
def lr_lambda(step):
if step < self.warmup_steps:
return step / self.warmup_steps
else:
progress = (step - self.warmup_steps) / (10000 - self.warmup_steps)
return 0.5 * (1 + math.cos(math.pi * progress))
return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
def training_step(self, batch):
"""单步训练"""
self.model.train()
# 冻结参考模型的梯度
with torch.no_grad():
ref_logps = compute_logprobs(self.reference_model, batch['input_ids'])
# 计算策略模型的对数概率
policy_logps = compute_logprobs(self.model, batch['input_ids'])
# 计算 KTO 损失
loss, metrics = self.kto_criterion(
policy_logps=policy_logps,
reference_logps=ref_logps,
label=batch['label']
)
# 反向传播
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
# 更新参数
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
self.global_step += 1
return loss.item(), metrics实验结果
MT-Bench 结果
Ethayarajh 等人在 MT-Bench 上的实验表明,KTO 在多个维度上优于 DPO:
| 模型 | DPO | KTO | 提升 |
|---|---|---|---|
| Llama-7B | 6.24 | 6.51 | +4.3% |
| Llama-13B | 6.67 | 6.89 | +3.3% |
| Mistral-7B | 6.82 | 7.01 | +2.8% |
人类偏好评测
在人类偏好评估中,KTO 相比 DPO 展现出更稳定的优势:
- 指令遵循:KTO 在复杂指令场景下的胜率提升约 15%
- 安全性:生成有害内容的比例降低约 20%
- 帮助性:在开放式问答中的人类偏好评分更高
与 DPO/PPO 的对比
实验对比图示:
胜率 (%)
│
100│ ████ PPO
│ ████
80│ ████ ████ KTO
│ ████
60│ ████ ████ DPO
│ ████
40│███
│
└────────────────────────────────────→ 数据规模
1K 5K 10K 50K
关键发现:
- 小样本场景:KTO 在数据稀缺时表现更稳定
- 大模型场景:三种方法的差距缩小,但 KTO 仍略有优势
- 训练效率:KTO 无需训练单独的 reward model,节省约 30% 计算资源
实践指南
超参数调优建议
初始配置
从以下默认配置开始:
kto_config = KTOConfig(
beta=0.0, # 先固定不动
alpha=0.5, # 平衡模式
kl_coef=0.01, # 轻量 KL 正则化
)调参顺序
-
先调 :根据业务场景决定关注点
- 客服机器人:,可根据负面反馈率调整
- 代码生成:,偏向保守避免错误代码
-
再调 :
- 数据质量高:
- 数据嘈杂:
-
最后调 :
- 初始设为
- 如果生成多样性下降,减小
- 如果训练不稳定,增大
监控指标
训练时重点监控以下指标:
| 指标 | 正常范围 | 异常信号 |
|---|---|---|
mean_delta | 接近 0(未学到偏好) | |
delta_std | 过大(不稳定) | |
kl_div | 过大(偏离参考模型) | |
kto_loss | 下降趋势 | 震荡或不下降 |
适用场景
KTO 特别适合以下场景:
- 数据标注不完整:只有”好/坏”标签,没有成对偏好
- 数据不平衡:正负样本比例悬殊(KTO 的 可调整)
- 损失厌恶重要:如金融、医疗、法律等高风险应用
- 计算资源受限:无需训练额外的 reward model
常见问题与解决方案
Q1:训练loss不下降
可能原因:
- 模型权重未更新(检查梯度)
- 参考模型太弱(更新参考模型)
- 数据标注问题
解决方案:
# 检查梯度
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: grad_norm={param.grad.norm():.4f}")
# 确保标签正确
assert (label == 0).sum() > 0 and (label == 1).sum() > 0Q2:生成过于保守
可能原因:
- 过大
- 过大
- KL 系数过大
解决方案:
# 调整配置
config = KTOConfig(
alpha=0.4, # 降低
beta=-0.3, # 降低
kl_coef=0.005, # 降低
)Q3:过拟合到特定模式
可能原因:
- 数据多样性不足
- KL 系数过小
- 训练步数过多
解决方案:
# 增加 KL 正则化
config.kl_coef = 0.05
# 或使用早停
early_stopping = EarlyStopping(patience=3, min_delta=0.01)