概述

KTO(Kullback-Leibler divergence with Targeted Odds,也称 Kahneman-Tversky Optimization)是一种新型的大语言模型对齐方法,由 Ethayarajh 等人在 2024 年提出。1 与传统的基于偏好对比的 DPOGRPO 不同,KTO 源自行为经济学的前景理论(Prospect Theory),旨在直接优化”好”与”坏”响应之间的相对偏好,而不依赖于显式的偏好对标注。

人类决策心理学:Kahneman-Tversky 展望理论

1979 年,Kahneman 和 Tversky 在经典论文《Prospect Theory: An Analysis of Decision under Risk》中提出了前景理论,彻底改变了我们对人类决策行为的理解。2 该理论的核心发现包括:

  1. 损失厌恶(Loss Aversion):人们对损失的敏感度高于等量收益——损失的心理影响大约是同等收益的 2-2.5 倍

  2. 敏感性递减(Diminishing Sensitivity):无论是收益还是损失,边际效用递减

  3. 反射效应(Reflection Effect):在收益区间风险厌恶,在损失区间风险寻求

这些发现直接挑战了传统期望效用理论中”人是非理性”的假设,证明人类决策遵循一套系统性的心理物理学规律。

为什么传统的 KL 散度目标可能不够

传统对齐方法如 RLHF 和 DPO 使用 KL 散度来约束模型不要偏离参考模型太远。然而,这种方法存在以下问题:

  • 偏好建模不完整:KL 散度假设所有”非偏好”响应同等糟糕,忽略了响应之间的质量差异
  • 损失厌恶缺失:没有建模人类对错误的非对称敏感性
  • 训练不稳定:依赖Bradley-Terry模型假设,在数据不平衡时可能失效

KTO 通过引入前景理论的价值函数,直接在优化目标中编码了人类决策的心理特征。


理论框架

前景理论基础

Kahneman-Tversky 前景理论定义的价值函数 如下:

其中:

  • 控制敏感性递减程度
  • 是损失厌恶系数,通常
价值函数示意图:

    V(x)
     │
  1  │          /  (收益曲线,凹)
     │        /
     │      /
     │    /
     │  /
  0  ├───────────────────────→ x
     │  \
     │    \
     │      \  (损失曲线,凸且更陡)
     │        \
  -λ │          \
     │

从前景理论到偏好优化的映射

KTO 将前景理论的价值函数与偏好优化联系起来。考虑一个偏好对 ,其中 是偏好的(winning)响应, 是不偏好的(losing)响应。定义:

这表示在给定输入 下,模型对偏好响应的”优势对数几率”(log-odds)。

时,模型更倾向于偏好响应;当 时,更倾向于非偏好响应。

KTO 损失函数推导

KTO 的损失函数直接源自前景理论的价值函数形式:

其中:

  • 是 sigmoid 函数
  • 偏好比例(fraction of positives),控制正负样本的相对权重
  • 最优点(optimal point),控制目标的对数几率

这个损失函数的几何解释如下:

参数含义作用
偏好比例控制正负样本的不对称权重, 表示更关注减少漏判
最优点目标的对数几率偏移, 表示希望模型保持一定的保守性

odds 比特的定义

在实现中, 通过以下方式计算:

即在 token 级别上平均的响应优势对数几率。这种”截断 odds”(Truncated Odds)设计使得:

  1. 计算更加稳定,避免极端值
  2. 与人类评估的相关性更好
  3. 对响应长度具有鲁棒性

算法细节

目标函数详解

KTO 的目标函数可以分解为两部分理解:

def kto_loss(logodds_pos, logodds_neg, alpha, beta):
    """
    KTO 损失函数实现
    
    参数:
        logodds_pos: 正样本的对数几率
        logodds_neg: 负样本的对数几率
        alpha: 偏好比例 (fraction of positives)
        beta: 最优点 (optimal point)
    
    返回:
        KTO 损失值
    """
    import torch
    import torch.nn.functional as F
    
    delta_xy = logodds_pos - logodds_neg
    
    # 正样本损失:希望 delta_xy 超过 beta
    loss_pos = (1 - alpha) * torch.log(1 - torch.sigmoid(beta - delta_xy) + 1e-8)
    
    # 负样本损失:希望 delta_xy 低于 beta
    loss_neg = alpha * torch.log(torch.sigmoid(delta_xy - beta) + 1e-8)
    
    return -(loss_pos + loss_neg).mean()

超参数的作用

(偏好比例)

  • 范围
  • 默认值(平衡正负样本)
  • 增大 :更关注减少假阳性(FP),模型更保守
  • 减小 :更关注减少假阴性(FN),模型更激进
α 对损失形状的影响:

    Loss
     │
     │     α=0.2        α=0.5        α=0.8
     │       ╲           │           ╱
     │        ╲          │          ╱
     │         ╲         │         ╱
     │          ╲        │        ╱
     │           ╲       │       ╱
     │            ╲      │      ╱
     │             ╲     │     ╱
     │              ╲    │    ╱
  0  └───────────────δ─────→ ───────
              β      β      β

(最优点)

  • 范围(通常在 调整)
  • 默认值
  • 增大 :要求更高的优势对数几率,模型更”挑剔”
  • 减小 :允许较低的差异容忍度,模型更”宽容”

与 DPO、PPO 的对比

特性KTODPOGRPO
数据需求单样本标签(即可/不可)成对偏好(偏好/不偏好)成对或单样本
理论基础前景理论Bradley-Terry 模型策略梯度
损失厌恶建模✅ 原生支持❌ 需要额外设计❌ 需要额外设计
计算效率高(无需训练 reward model)高(端到端)中等
超参数敏感性中等
训练稳定性较好良好受策略方差影响

实现代码

PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from dataclasses import dataclass
 
 
@dataclass
class KTOConfig:
    """KTO 配置参数"""
    beta: float = 0.0          # 最优点 (optimal point)
    alpha: float = 0.5         # 偏好比例 (fraction of positives)
    kl_coef: float = 0.1       # KL 散度系数(可选,用于正则化)
    
 
class KTOLoss(nn.Module):
    """
    KTO (Kahneman-Tversky Optimization) 损失函数
    
    源自前景理论的价值函数,直接优化偏好的不对称性
    """
    
    def __init__(self, config: KTOConfig):
        super().__init__()
        self.beta = config.beta
        self.alpha = config.alpha
        self.kl_coef = config.kl_coef
    
    def forward(
        self,
        policy_logps: torch.Tensor,
        reference_logps: torch.Tensor,
        label: torch.Tensor,
    ) -> Tuple[torch.Tensor, dict]:
        """
        计算 KTO 损失
        
        参数:
            policy_logps: 策略模型的 token 级对数概率 [batch_size, seq_len]
            reference_logps: 参考模型的 token 级对数概率 [batch_size, seq_len]
            label: 样本标签,1 表示正样本(偏好),0 表示负样本 [batch_size]
        
        返回:
            loss: 标量损失值
            metrics: 诊断指标字典
        """
        batch_size = policy_logps.size(0)
        
        # 计算序列级对数几率
        # 对所有 token 求和,然后除以序列长度
        seq_len = policy_logps.size(1)
        policy_mean_logps = policy_logps.sum(dim=1) / seq_len
        reference_mean_logps = reference_logps.sum(dim=1) / seq_len
        
        # 计算截断 odds: δ_xy = π(y_w|x) - π(y_l|x) 在 token 级别平均
        # 对于正样本: δ_xy = policy - reference
        # 对于负样本: δ_xy = reference - policy
        sign = 2 * label - 1  # +1 for positive, -1 for negative
        delta_xy = sign * (policy_mean_logps - reference_mean_logps)
        
        # KTO 损失函数
        # L = (1-α) * log(σ(β - δ)) + α * log(σ(δ - β))
        # 第一项:正样本损失(当 δ_xy > β 时较小)
        # 第二项:负样本损失(当 δ_xy < β 时较小)
        
        # 数值稳定性处理
        eps = 1e-8
        
        loss_pos = (1 - self.alpha) * torch.log(
            1 - torch.sigmoid(beta - delta_xy) + eps
        )
        loss_neg = self.alpha * torch.log(
            torch.sigmoid(delta_xy - self.beta) + eps
        )
        
        # 合并正负样本损失
        kto_loss = -(loss_pos + loss_neg).mean()
        
        # 可选:添加 KL 正则化项,防止策略偏离参考模型太远
        if self.kl_coef > 0:
            kl_div = (policy_mean_logps - reference_mean_logps).mean()
            total_loss = kto_loss + self.kl_coef * kl_div
        else:
            total_loss = kto_loss
            kl_div = torch.tensor(0.0)
        
        # 诊断指标
        metrics = {
            'kto_loss': kto_loss.item(),
            'kl_div': kl_div.item() if isinstance(kl_div, torch.Tensor) else kl_div,
            'mean_delta': delta_xy.mean().item(),
            'delta_std': delta_xy.std().item(),
        }
        
        return total_loss, metrics
 
 
def compute_logprobs(
    model: nn.Module,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    计算模型对输入的对数概率
    
    参数:
        model: 语言模型
        input_ids: 输入 token IDs [batch_size, seq_len]
        attention_mask: 注意力掩码 [batch_size, seq_len]
    
    返回:
        对数概率 [batch_size, seq_len]
    """
    outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # [batch_size, seq_len, vocab_size]
    
    # 转换为对数概率(使用 log_softmax)
    log_probs = F.log_softmax(logits, dim=-1)
    
    # 计算每个 token 的对数概率
    # 预测 token t 时,使用 logits[:, t-1, :] 预测 token t
    # 所以取 [:, :-1, :] 的对数概率,对应 [1:] 的 token
    shift_log_probs = log_probs[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    
    # Gather 对应的对数概率
    batch_size, seq_len = shift_labels.shape
    flat_log_probs = shift_log_probs.view(-1, shift_log_probs.size(-1))
    flat_labels = shift_labels.view(-1)
    
    token_logps = flat_log_probs.gather(dim=-1, index=flat_labels.unsqueeze(-1)).squeeze(-1)
    token_logps = token_logps.view(batch_size, seq_len)
    
    return token_logps

关键参数设置

# 推荐的超参数配置
 
# 配置 1:平衡模式(默认推荐)
config_balanced = KTOConfig(
    beta=0.0,     # 中性设置
    alpha=0.5,    # 平衡正负样本
)
 
# 配置 2:保守模式(更关注避免负面响应)
config_conservative = KTOConfig(
    beta=0.5,     # 要求更高的优势
    alpha=0.6,    # 更关注减少假阳性
)
 
# 配置 3:激进模式(更关注不遗漏正面响应)
config_aggressive = KTOConfig(
    beta=-0.5,    # 允许较低的差异
    alpha=0.4,    # 更关注减少假阴性
)

训练稳定性技巧

class StableKTOTrainer:
    """
    稳定的 KTO 训练器
    
    包含以下稳定性优化:
    1. 梯度裁剪
    2. 学习率预热
    3. 梯度累积
    4. 早停机制
    """
    
    def __init__(
        self,
        model,
        reference_model,
        config: KTOConfig,
        max_grad_norm: float = 1.0,
        warmup_steps: int = 100,
        **kwargs
    ):
        self.model = model
        self.reference_model = reference_model
        self.config = config
        self.max_grad_norm = max_grad_norm
        self.warmup_steps = warmup_steps
        self.global_step = 0
        
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)
        self.scheduler = self._create_scheduler()
        self.kto_criterion = KTOLoss(config)
    
    def _create_scheduler(self):
        """创建学习率调度器:线性预热 + 余弦衰减"""
        def lr_lambda(step):
            if step < self.warmup_steps:
                return step / self.warmup_steps
            else:
                progress = (step - self.warmup_steps) / (10000 - self.warmup_steps)
                return 0.5 * (1 + math.cos(math.pi * progress))
        
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
    
    def training_step(self, batch):
        """单步训练"""
        self.model.train()
        
        # 冻结参考模型的梯度
        with torch.no_grad():
            ref_logps = compute_logprobs(self.reference_model, batch['input_ids'])
        
        # 计算策略模型的对数概率
        policy_logps = compute_logprobs(self.model, batch['input_ids'])
        
        # 计算 KTO 损失
        loss, metrics = self.kto_criterion(
            policy_logps=policy_logps,
            reference_logps=ref_logps,
            label=batch['label']
        )
        
        # 反向传播
        loss.backward()
        
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        
        # 更新参数
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        
        self.global_step += 1
        
        return loss.item(), metrics

实验结果

MT-Bench 结果

Ethayarajh 等人在 MT-Bench 上的实验表明,KTO 在多个维度上优于 DPO

模型DPOKTO提升
Llama-7B6.246.51+4.3%
Llama-13B6.676.89+3.3%
Mistral-7B6.827.01+2.8%

人类偏好评测

在人类偏好评估中,KTO 相比 DPO 展现出更稳定的优势:

  • 指令遵循:KTO 在复杂指令场景下的胜率提升约 15%
  • 安全性:生成有害内容的比例降低约 20%
  • 帮助性:在开放式问答中的人类偏好评分更高

与 DPO/PPO 的对比

实验对比图示:

胜率 (%)
   │
100│                           ████ PPO
   │                      ████
 80│                 ████       ████ KTO
   │            ████
 60│       ████                 ████ DPO
   │   ████
 40│███
   │
   └────────────────────────────────────→ 数据规模
      1K      5K      10K     50K

关键发现

  1. 小样本场景:KTO 在数据稀缺时表现更稳定
  2. 大模型场景:三种方法的差距缩小,但 KTO 仍略有优势
  3. 训练效率:KTO 无需训练单独的 reward model,节省约 30% 计算资源

实践指南

超参数调优建议

初始配置

从以下默认配置开始:

kto_config = KTOConfig(
    beta=0.0,      # 先固定不动
    alpha=0.5,     # 平衡模式
    kl_coef=0.01,  # 轻量 KL 正则化
)

调参顺序

  1. 先调 :根据业务场景决定关注点

    • 客服机器人:,可根据负面反馈率调整
    • 代码生成:,偏向保守避免错误代码
  2. 再调

    • 数据质量高:
    • 数据嘈杂:
  3. 最后调

    • 初始设为
    • 如果生成多样性下降,减小
    • 如果训练不稳定,增大

监控指标

训练时重点监控以下指标:

指标正常范围异常信号
mean_delta接近 0(未学到偏好)
delta_std过大(不稳定)
kl_div过大(偏离参考模型)
kto_loss下降趋势震荡或不下降

适用场景

KTO 特别适合以下场景:

  1. 数据标注不完整:只有”好/坏”标签,没有成对偏好
  2. 数据不平衡:正负样本比例悬殊(KTO 的 可调整)
  3. 损失厌恶重要:如金融、医疗、法律等高风险应用
  4. 计算资源受限:无需训练额外的 reward model

常见问题与解决方案

Q1:训练loss不下降

可能原因

  • 模型权重未更新(检查梯度)
  • 参考模型太弱(更新参考模型)
  • 数据标注问题

解决方案

# 检查梯度
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: grad_norm={param.grad.norm():.4f}")
 
# 确保标签正确
assert (label == 0).sum() > 0 and (label == 1).sum() > 0

Q2:生成过于保守

可能原因

  • 过大
  • 过大
  • KL 系数过大

解决方案

# 调整配置
config = KTOConfig(
    alpha=0.4,    # 降低
    beta=-0.3,   # 降低
    kl_coef=0.005,  # 降低
)

Q3:过拟合到特定模式

可能原因

  • 数据多样性不足
  • KL 系数过小
  • 训练步数过多

解决方案

# 增加 KL 正则化
config.kl_coef = 0.05
 
# 或使用早停
early_stopping = EarlyStopping(patience=3, min_delta=0.01)

相关主题


References

Footnotes

  1. Ethayarajh, K., et al. (2024). “KTO: Model Alignment as Echoopic Optimization.” arXiv preprint arXiv:2402.01306.

  2. Kahneman, D., & Tversky, A. (1979). “Prospect Theory: An Analysis of Decision under Risk.” Econometrica, 47(2), 263-291.