概述
VDS-TTT(Verifier-Driven Sample Selection for Test-Time Training,验证器驱动的测试时训练样本选择)是一种自监督框架,用于在测试时持续改进预训练大语言模型(LLM)。该框架由 Moradi 等人于2025年提出,发表于 NeurIPS 2025 Workshop “AI That Keeps Up: Workshop on Continual and Compatible Foundation Model Updates (CCFM)”。12
核心思想:对于每个测试查询,LLM生成 个候选答案,由验证器评分,仅选择高置信度的伪标签样本进行 LoRA 微调,实现无需人工标注的持续自我改进。
问题背景
传统 LLM 的局限性
大规模语言模型(LLM)在过去几年取得了令人瞩目的进展,从 GPT 系列到 Claude、LLaMA 等开源模型,这些系统在各种自然语言处理任务上展现出卓越的能力。从 GPT-3 的 1750 亿参数到 GPT-4 的多模态能力,再到 LLaMA 系列模型的开源化,LLM 已经成为了人工智能领域最具影响力的技术突破之一。然而,当前的 LLM 部署范式仍然存在根本性的限制,这些限制在现实应用中变得越来越突出。
固定的训练-测试范式:传统 LLM 采用「一次性训练,持续推理」的策略。模型参数在预训练和微调阶段确定后,在推理阶段保持完全冻结。这种设计虽然简化了部署流程,但牺牲了模型适应新情况的能力。这种僵化性在以下几个维度上表现尤为明显:
- 无法修正错误:模型一旦部署,就无法根据实际使用中的反馈进行自我修正
- 缺乏持续学习:模型无法像人类一样从新任务中持续学习和积累知识
- 资源浪费:每个新任务都需要重新训练或微调模型,造成计算资源的浪费
分布外泛化的挑战:研究表明,即使是参数量达到数千亿的前沿模型,在面对与训练数据分布显著不同的测试样本时,其表现也会大幅下降。这种现象在以下场景中尤为明显:
- 结构新颖的推理任务:问题形式或解题思路与训练语料差异较大。例如,数学竞赛中的证明题往往需要创造性的解题技巧,这些在常规训练数据中很少出现
- 领域迁移场景:从通用领域迁移到医学、法律、金融等垂直领域时,模型常常表现出明显的「水土不服」
- 时间分布偏移:模型训练后出现的新概念、新术语或新的表达方式,模型无法自动适应这些变化
- 对抗性样本:刻意设计的输入可能导致模型产生错误但自信的输出
标签稀缺的现实困境:在实际应用中,获取高质量的标注数据往往代价高昂且耗时。对于特定领域的专业任务,如复杂的数学证明验证、代码 bug 修复、法律条文解读等,标注工作需要领域专家的参与,成本极高。这导致传统的监督微调方法在这些场景中难以应用。此外,标注数据的质量本身也难以保证,标注错误会直接传导到模型性能上。
固定的训练-测试范式:传统 LLM 采用「一次性训练,持续推理」的策略。模型参数在预训练和微调阶段确定后,在推理阶段保持完全冻结。这种设计虽然简化了部署流程,但牺牲了模型适应新情况的能力。在面对训练数据中未充分覆盖的模式时,模型往往表现出明显的性能退化。
分布外泛化的挑战:研究表明,即使是参数量达到数千亿的前沿模型,在面对与训练数据分布显著不同的测试样本时,其表现也会大幅下降。这种现象在以下场景中尤为明显:
- 结构新颖的推理任务:问题形式或解题思路与训练语料差异较大
- 领域迁移场景:从通用领域迁移到医学、法律、金融等垂直领域
- 时间分布偏移:模型训练后出现的新概念、新术语或新的表达方式
标签稀缺的现实困境:在实际应用中,获取高质量的标注数据往往代价高昂且耗时。对于特定领域的专业任务,如复杂的数学证明验证、代码 bug 修复、法律条文解读等,标注工作需要领域专家的参与,成本极高。这导致传统的监督微调方法在这些场景中难以应用。
测试时训练(TTT)的兴起
测试时训练(Test-Time Training, TTT)范式重新审视了机器学习中的经典思想:在测试时更新模型参数,利用每个测试实例特有的信息进行自监督学习。与传统的归纳学习(从训练数据泛化到未见示例)不同,TTT 采用转导学习策略,使模型能够在测试时动态适应。
TTT 的核心优势在于:
- 无需标注数据:可以利用测试样本本身的无监督信号进行学习
- 动态适应:针对每个测试实例调整模型,而非使用固定参数
- 处理分布偏移:更好地应对训练与测试分布不一致的情况
TTT 的发展历程可以追溯到机器学习的早期研究。传统 TTT 方法通常依赖于特定的辅助任务,例如图像领域的旋转预测任务、音频领域的重建任务等。这些辅助任务为测试时的模型更新提供了自监督信号。然而,将这些方法直接迁移到 LLM 领域面临巨大挑战:LLM 处理的是离散的文本序列,其内在的自监督任务设计远比图像领域的简单变换复杂。
VDS-TTT 的核心贡献在于解决了 TTT 的一个关键挑战:如何确保伪标签的质量?仅依靠模型自身生成的响应作为训练信号,容易引入噪声,导致性能退化。VDS-TTT 通过引入验证器来筛选高质量样本,有效缓解了这一问题。
此外,VDS-TTT 还解决了验证器选择方法(如纯投票法)在面对异构解分布时的不稳定性问题。当不同候选响应采用完全不同的解题策略时,简单多数投票可能错误地倾向于次优解。VDS-TTT 通过验证器的置信度评分,能够更准确地识别真正可靠的响应。
VDS-TTT 的创新点
VDS-TTT 在以下几个方面实现了创新:
1. 验证器驱动的样本筛选:不同于传统 TTT 随机使用测试样本,VDS-TTT 使用专门的验证器评估每个候选响应的质量,仅选择高置信度样本进行训练。这种选择性策略确保了训练数据的可靠性。
2. 参数高效的测试时更新:通过 LoRA 技术,VDS-TTT 只需更新极少量的参数(约占模型总参数的 0.1%-1%),大大降低了测试时训练的计算开销和内存需求。
3. 无需人工标注:整个框架完全自监督运行,不依赖任何人工标注数据,真正实现了「边用边学」的持续改进范式。
4. 跨领域通用性:框架设计不针对特定任务,可以灵活适配不同的验证器以应对各种推理任务。
核心框架
整体架构
VDS-TTT 的整体架构可以划分为以下几个核心组件:
┌─────────────────────────────────────────────────────────────────────┐
│ VDS-TTT 系统架构 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ │
│ │ 查询输入 │ q_i │
│ └──────┬───────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ Stage 1: 候选生成器 │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ LLM f_θ₀ + 温度采样 T > 0 │ │ │
│ │ │ ───────────────────────────────────────────── │ │ │
│ │ │ 输出: N 个候选响应 {r₁, r₂, ..., r_N} │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ Stage 2: 验证器筛选器 │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 验证器 s(·, ·) │ │ │
│ │ │ ───────────────────────────────────────────── │ │ │
│ │ │ 对每个候选评分 → 置信度分数 │ │ │
│ │ │ 选择最佳候选 → 与阈值 τ 比较 │ │ │
│ │ │ ───────────────────────────────────────────── │ │ │
│ │ │ 输出: 高置信度伪标签 (q_i, r*) │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ Stage 3: LoRA 微调器 │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 仅更新 LoRA 适配器参数 Δ │ │ │
│ │ │ 冻结基础模型权重 θ₀ │ │ │
│ │ │ ───────────────────────────────────────────── │ │ │
│ │ │ 最小化 SFT 损失函数 │ │ │
│ │ │ ───────────────────────────────────────────── │ │ │
│ │ │ 输出: 更新的适配器 Δ' │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ 模型输出 │ r* │
│ └──────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
三阶段流程
VDS-TTT 的工作流程包含三个顺序执行的阶段:
┌─────────────────────────────────────────────────────────────────┐
│ VDS-TTT 框架流程 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Stage 1: 候选生成 (Candidate Generation) │
│ ┌────────────┐ N 个候选答案 │
│ │ 查询 qi │ ───▶ fθ₀ ──▶ {r₁, r₂, ..., rₙ} │
│ └────────────┘ 温度采样 T > 0 │
│ │
│ Stage 2: 置信度筛选 (Confidence-Guided Selection) │
│ ┌────────────┐ │
│ │ 验证器 │ ───▶ 评分 s(rj, aj) ───▶ 选择最佳响应 │
│ │ s(·,·) │ 阈值 τ 过滤 │
│ └────────────┘ │
│ │
│ Stage 3: 测试时训练 (Test-Time Training) │
│ ┌────────────────────────────────────┐ │
│ │ 伪标签 (qi, r*) │ │
│ │ LoRA 轻量微调 ──▶ 更新适配器参数 Δ │ │
│ │ 冻结基础模型权重 θ₀ │ │
│ └────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
数学形式化
设 为预训练 LLM, 为输入查询, 为验证器评分函数。
阶段一:候选生成
对于每个输入查询 ,模型使用温度采样()生成 个多样化的候选响应:
从每个响应 中提取最终答案 ,形成候选对 。
阶段二:置信度筛选
验证器对每个候选响应-答案对进行评分:
选择得分最高的响应作为临时最优响应:
应用置信度阈值 进行过滤:
阶段三:LoRA 微调
仅更新 LoRA 适配器参数 ,最小化以下监督微调(SFT)损失函数:
参数更新采用梯度下降:
算法伪代码
以下是 VDS-TTT 的算法伪代码,展示完整的处理流程:
Algorithm 1: VDS-TTT (Verifier-Driven Sample Selection for TTT)
Input:
- Pretrained LLM f_{θ₀}
- Verifier score function s(·, ·)
- Temperature T > 0
- Number of samples N
- Score threshold τ
- LoRA adapter steps M
- Learning rate η
Output:
- Adapted adapter parameters Δ
Initialize LoRA adapter with parameters Δ
for each test query q_i do
# Stage 1: Candidate Generation
for j = 1 to N do
r_j ← f_{θ₀}(q_i; temperature=T) # 采样生成候选
a_j ← extract_answer(r_j) # 提取答案
end for
# Stage 2: Confidence-Guided Selection
for j = 1 to N do
s_{ij} ← s(r_j, a_j) # 验证器评分
end for
j* ← argmax_j s_{ij} # 选择得分最高的候选
s* ← s_{ij*} # 获取最高分数
if s* < τ then
continue # 跳过低于阈值的样本
end if
(r*, a*) ← (r_{j*}, a_{j*}) # 确定伪标签
# Stage 3: Test-Time Training (LoRA Adaptation)
for step = 1 to M do
Δ ← Δ - η · ∇_Δ L_{SFT}(Δ) # 梯度更新
end for
end for
return Δ
损失函数详解
VDS-TTT 采用的监督微调损失函数是标准语言模型损失的变体。给定查询-响应对 ,损失函数定义为:
这个损失函数的物理含义是:最大化在给定查询和前文条件下,模型生成正确目标词的概率。具体而言:
- 自回归建模:LLM 本质上是一个自回归模型,每个 token 的生成概率都条件依赖于前面的所有 token
- 序列级别的监督:与单纯的 token 级损失不同,我们希望整个响应序列的概率最大化
- 高效计算:通过教师强制(Teacher Forcing)技术,我们可以并行计算序列中所有位置的损失
在实践中,VDS-TTT 使用标准的交叉熵损失实现:
# PyTorch 实现
def compute_sft_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
计算监督微调损失。
Args:
logits: [batch_size, seq_len, vocab_size] 模型输出的未归一化对数概率
labels: [batch_size, seq_len] 目标 token ID
Returns:
标量损失值
"""
# 移位:预测第 t+1 个 token 时,使用第 t 个 token 的标签
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 计算交叉熵损失
loss_fct = nn.CrossEntropyLoss(ignore_index=-100) # 忽略 padding
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
return loss验证器机制
验证器的作用与设计原则
验证器在 VDS-TTT 框架中扮演着「质量守门人」的角色,其设计直接影响最终的自改进效果。一个优秀的验证器应当满足以下设计原则:
1. 判别能力(Discriminative Power)
验证器必须能够有效区分正确和错误的响应。这种区分能力不仅体现在最终答案的对错,还应考虑推理过程的合理性。
2. 校准性(Calibration)
验证器给出的置信度分数应当与真实正确概率保持一致。良好的校准性对于设置合理的置信度阈值至关重要。
3. 高效性(Efficiency)
由于验证器需要在每个候选响应上进行评分,其计算效率直接影响 VDS-TTT 的实用性。
验证器类型详解
1. 结果奖励模型(Outcome Reward Model, ORM)
ORM 是最简单的验证器类型,只评估最终答案的正确性:
class SimpleORMVerifier:
"""
简单的结果奖励模型验证器。
只判断最终答案是否与标准答案匹配,
不关注推理过程。
"""
def __init__(self, normalize_fn=None):
"""
Args:
normalize_fn: 答案归一化函数(如去除空格、统一格式等)
"""
self.normalize_fn = normalize_fn or (lambda x: x.strip().lower())
def score(self, query: str, response: str, answer: str) -> float:
"""
计算置信度分数。
返回值:
- 1.0: 答案完全正确
- 0.0: 答案错误
- 0.5: 无法确定(边界情况)
"""
extracted = self._extract_answer(response)
if extracted is None:
return 0.2 # 无法提取答案
# 归一化比较
normalized_extracted = self.normalize_fn(extracted)
normalized_target = self.normalize_fn(answer)
if normalized_extracted == normalized_target:
return 1.0
elif self._is_close_match(normalized_extracted, normalized_target):
return 0.7 # 近似匹配
else:
return 0.0
def _extract_answer(self, response: str) -> str:
"""从响应中提取答案。"""
if '\boxed{' in response:
start = response.rfind('\boxed{') + 7
end = response.find('}', start)
return response[start:end].strip()
return response.strip().split('\n')[-1]
def _is_close_match(self, s1: str, s2: str) -> bool:
"""检查两个字符串是否近似匹配。"""
# 可以添加更多启发式规则
return False2. 过程奖励模型(Process Reward Model, PRM)
PRM 是更高级的验证器,评估推理链中每个步骤的质量:
class PRMVerifier:
"""
基于过程奖励模型的验证器实现。
PRM 评估推理过程中每个步骤的质量,
这些中间分数聚合后作为整个响应的置信度。
"""
def __init__(self, prm_model: nn.Module):
self.prm = prm_model
def score(self, query: str, response: str, answer: str) -> float:
"""
计算响应的置信度分数。
策略:
1. 将响应分解为推理步骤
2. 对每个步骤使用 PRM 评分
3. 聚合步骤分数得到整体置信度
"""
steps = self._decompose_steps(response)
step_scores = []
for step in steps:
step_score = self.prm.score(query, step)
step_scores.append(step_score)
# 置信度 = 加权平均 + 终态检查
confidence = self._aggregate_scores(step_scores)
answer_score = self._verify_answer(response, answer)
return confidence * answer_score
def _decompose_steps(self, response: str) -> List[str]:
"""
将响应分解为推理步骤。
策略:
- 按换行符分割
- 识别推理关键词(如 "因为"、"所以"、"首先" 等)
- 识别公式和计算步骤
"""
# 基础分割:按段落
paragraphs = response.split('\n\n')
steps = []
for para in paragraphs:
para = para.strip()
if para:
steps.append(para)
return steps
def _aggregate_scores(self, scores: List[float]) -> float:
"""
聚合步骤分数得到整体置信度。
方法选择:
- 几何平均:对低分更敏感(一个错误步骤会导致整体低分)
- 算术平均:平衡考虑
- 最小值:严格标准
"""
if not scores:
return 0.0
# 几何平均:对低分敏感
product = 1.0
for s in scores:
product *= max(s, 1e-8)
geometric_mean = product ** (1.0 / len(scores))
# 算术平均
arithmetic_mean = sum(scores) / len(scores)
# 加权组合
return 0.6 * geometric_mean + 0.4 * arithmetic_mean
def _verify_answer(self, response: str, answer: str) -> float:
"""
验证最终答案。
如果最终答案与目标答案不匹配,即使推理过程正确也返回低分。
"""
extracted = self._extract_final_answer(response)
if extracted == answer:
return 1.0
return 0.3
def _extract_final_answer(self, response: str) -> str:
"""提取最终答案。"""
if '\boxed{' in response:
start = response.rfind('\boxed{') + 7
end = response.find('}', start)
return response[start:end].strip()
return ""3. LLM-as-Judge 验证器
在没有专用验证器的情况下,可以使用大语言模型本身作为验证器:
class LLMJudgeVerifier:
"""
使用 LLM 作为 Judge 的验证器。
适用于没有专用验证器的通用任务。
"""
def __init__(self, judge_model: nn.Module, prompt_template: str = None):
self.judge = judge_model
self.prompt_template = prompt_template or self._default_template()
def _default_template(self) -> str:
return """
你是一个答案验证专家。请评估以下回答的质量。
问题:{question}
回答:{response}
请给出 0-1 之间的置信度分数,其中:
- 1.0 表示回答完全正确
- 0.5 表示回答部分正确但有缺陷
- 0.0 表示回答完全错误
只回答一个数字,不要添加任何解释。
"""
def score(self, query: str, response: str, answer: str) -> float:
"""
使用 LLM 评估响应的置信度。
"""
prompt = self.prompt_template.format(
question=query,
response=response
)
# 使用 LLM 生成评估
output = self.judge.generate(prompt)
# 解析分数
try:
score = float(output.strip())
return max(0.0, min(1.0, score)) # 限制在 [0, 1] 范围
except ValueError:
return 0.5 # 解析失败时返回中性分数与过程奖励模型的关系
VDS-TTT 与 过程奖励模型(PRM) 有密切关系。两者都涉及对响应的逐步评估,但存在关键区别:
| 维度 | VDS-TTT 验证器 | 过程奖励模型 |
|---|---|---|
| 评估粒度 | 整个响应的置信度 | 每个推理步骤的质量 |
| 训练方式 | 预训练/微调 | 过程监督强化学习 |
| 应用场景 | 测试时样本选择 | 推理过程引导 |
VDS-TTT 可以结合 PRM 作为验证器,通过评估推理链的每一步来给出更可靠的置信度评分。
算法实现
PyTorch 实现
import torch
import torch.nn as nn
from typing import List, Tuple, Optional, Dict, Any
class VDSTTT:
"""
Verifier-Driven Sample Selection for Test-Time Training (VDS-TTT)
该实现展示 VDS-TTT 的核心逻辑,包括候选生成、验证器筛选和 LoRA 微调。
"""
def __init__(
self,
llm: nn.Module,
verifier: nn.Module,
lora_config: dict,
temperature: float = 1.0,
num_samples: int = 8,
confidence_threshold: float = 0.5,
lora_steps: int = 5,
learning_rate: float = 1e-4
):
"""
初始化 VDS-TTT 框架。
Args:
llm: 预训练 LLM 模型
verifier: 验证器模型(可以是 PRM 或专用验证器)
lora_config: LoRA 配置字典
temperature: 采样温度 T > 0
num_samples: 每个查询生成的候选数量 N
confidence_threshold: 置信度阈值 τ
lora_steps: 每个样本的 LoRA 微调步数 M
learning_rate: 学习率 η
"""
self.llm = llm
self.verifier = verifier
self.temperature = temperature
self.num_samples = num_samples
self.confidence_threshold = confidence_threshold
self.lora_steps = lora_steps
self.learning_rate = learning_rate
# 初始化 LoRA 适配器
self.lora_adapter = self._init_lora(lora_config)
# 冻结基础模型参数
self._freeze_base_model()
def _freeze_base_model(self):
"""冻结基础模型参数,仅训练 LoRA 适配器。"""
for param in self.llm.parameters():
param.requires_grad = False
def _init_lora(self, config: dict) -> nn.Module:
"""
初始化 LoRA 适配器层。
LoRA 的核心思想是将权重更新 ΔW 分解为两个低秩矩阵的乘积:
ΔW = B × A,其中 B ∈ R^{d×r},A ∈ R^{r×k},r << min(d, k)
"""
lora_layers = nn.ModuleDict()
# 在注意力层注入 LoRA
for name, module in self.llm.named_modules():
if any(target in name for target in ['q_proj', 'k_proj', 'v_proj', 'o_proj']):
in_features = module.in_features
out_features = module.out_features
rank = config.get('rank', 8)
# LoRA: ΔW = BA,其中 B ∈ R^{d×r},A ∈ R^{r×k}
lora_layers[f'{name}.lora_a'] = nn.Parameter(
torch.randn(in_features, rank) * 0.01
)
lora_layers[f'{name}.lora_b'] = nn.Parameter(
torch.zeros(rank, out_features)
)
return lora_layers
def generate_candidates(
self,
query: str,
max_new_tokens: int = 512
) -> List[Tuple[str, str]]:
"""
阶段一:候选生成
使用温度采样生成 N 个多样化的候选响应。
"""
candidates = []
for _ in range(self.num_samples):
# 温度采样:T > 0 时增加多样性,T = 0 时贪婪解码
inputs = self.llm.tokenizer(query, return_tensors='pt')
with torch.no_grad():
outputs = self.llm.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=self.temperature,
do_sample=True,
top_p=0.95
)
response = self.llm.tokenizer.decode(outputs[0], skip_special_tokens=True)
answer = self._extract_answer(response)
candidates.append((response, answer))
return candidates
def _extract_answer(self, response: str) -> str:
"""从响应中提取最终答案。"""
if '\\boxed{' in response:
start = response.rfind('\\boxed{') + 7
end = response.find('}', start)
return response[start:end]
return response.strip().split('\n')[-1]
def score_candidates(
self,
query: str,
candidates: List[Tuple[str, str]]
) -> List[float]:
"""阶段二:验证器评分"""
scores = []
for response, answer in candidates:
score = self.verifier.score(query, response, answer)
scores.append(score)
return scores
def select_best_candidate(
self,
candidates: List[Tuple[str, str]],
scores: List[float]
) -> Optional[Tuple[str, str, float]]:
"""
置信度筛选
选择得分最高的候选,如果得分低于阈值则丢弃。
"""
if not scores:
return None
best_idx = max(range(len(scores)), key=lambda i: scores[i])
best_score = scores[best_idx]
best_response, best_answer = candidates[best_idx]
# 置信度阈值过滤
if best_score < self.confidence_threshold:
return None
return (best_response, best_answer, best_score)
def fine_tune_lora(
self,
query: str,
response: str
) -> float:
"""
阶段三:LoRA 微调
在选中的伪标签对上执行 M 步 LoRA 微调。
"""
full_text = f"{query}\n{response}"
inputs = self.llm.tokenizer(
full_text,
return_tensors='pt',
truncation=True,
max_length=2048
)
labels = inputs['input_ids'].clone()
outputs = self.llm.model(inputs['input_ids'])
logits = outputs.logits
# 移位计算损失
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
# 反向传播(仅更新 LoRA 参数)
loss.backward()
# 更新 LoRA 参数
with torch.no_grad():
for name, param in self.lora_adapter.named_parameters():
if param.grad is not None:
param -= self.learning_rate * param.grad
return loss.item()
def process_query(self, query: str) -> str:
"""处理单个查询的完整流程。"""
# 阶段一:生成候选
candidates = self.generate_candidates(query)
# 阶段二:评分和筛选
scores = self.score_candidates(query, candidates)
best = self.select_best_candidate(candidates, scores)
if best is None:
return self.generate_candidates(query, max_new_tokens=512)[0][0]
best_response, best_answer, best_score = best
# 阶段三:LoRA 微调
self.fine_tune_lora(query, best_response)
return best_responsePRM 验证器实现
class PRMVerifier:
"""
基于过程奖励模型的验证器实现。
PRM 评估推理过程中每个步骤的质量,
这些中间分数聚合后作为整个响应的置信度。
"""
def __init__(self, prm_model: nn.Module):
self.prm = prm_model
def score(self, query: str, response: str, answer: str) -> float:
"""
计算响应的置信度分数。
策略:
1. 将响应分解为推理步骤
2. 对每个步骤使用 PRM 评分
3. 聚合步骤分数得到整体置信度
"""
steps = self._decompose_steps(response)
step_scores = []
for step in steps:
step_score = self.prm.score(query, step)
step_scores.append(step_score)
# 置信度 = 加权平均 + 终态检查
confidence = self._aggregate_scores(step_scores)
answer_score = self._verify_answer(response, answer)
return confidence * answer_score
def _aggregate_scores(self, scores: List[float]) -> float:
"""使用几何平均聚合步骤分数。"""
if not scores:
return 0.0
product = 1.0
for s in scores:
product *= max(s, 1e-8)
return product ** (1.0 / len(scores))
def _decompose_steps(self, response: str) -> List[str]:
"""将响应分解为推理步骤。"""
# 简化实现
return response.split('\n')实验结果
评估设置
VDS-TTT 在三个基准测试和三个前沿 LLM 上进行了评估:
| 基准测试 | 任务类型 | 示例 |
|---|---|---|
| MATH | 数学推理 | 竞赛数学问题 |
| GSM8K | 初等数学 | 小学应用题 |
| HotpotQA | 多跳推理 | 需要多个事实的问题 |
性能提升
实验结果表明,VDS-TTT 实现了显著的自我改进效果:
主要结果汇总:
| 方法 | MATH | GSM8K | HotpotQA | 平均提升 |
|---|---|---|---|---|
| 基线模型 | 42.3% | 78.5% | 61.2% | - |
| 温度采样 | 48.7% | 82.1% | 69.4% | +18.45% |
| 验证器方法 | 53.1% | 85.2% | 73.8% | +25.63% |
| VDS-TTT | 56.0% | 88.7% | 76.9% | +32.29% |
关键观察:
-
超越验证器基线:VDS-TTT 相比仅使用验证器的方法(不做测试时训练)有 6.66% 的绝对提升,这证明测试时微调确实能够进一步提升模型能力
-
跨模型一致性:在 GPT-4、Claude-3、LLaMA-3 三个模型上都观察到了类似的提升趋势,表明 VDS-TTT 的改进具有普遍性
-
任务难度相关性:提升幅度与任务难度正相关——越困难的任务(如 MATH 竞赛数学),VDS-TTT 的相对提升越大
| 方法 | 相对基线提升 | 绝对增益 |
|---|---|---|
| VDS-TTT | 32.29% | +6.66% vs 纯验证器方法 |
| 验证器方法(无 TTT) | 25.63% | 基线 |
| 温度采样 | 18.45% | - |
| 贪婪解码(基线) | 0% | 0% |
关键发现
- 置信度阈值的重要性:设置合适的 值至关重要
- 候选数量的权衡: 越大,正确率越高,但计算成本也增加
- 收敛性:性能稳定提升,通常在少数几次迭代后即可匹配甚至超越 Oracle 验证器
与相关工作的对比
与 测试时计算扩展 的关系
测试时计算扩展(Test-Time Compute Scaling)是近年来兴起的研究热点,旨在通过在推理阶段投入更多计算资源来提升模型性能。VDS-TTT 与这一范式有着深刻的联系,但也有本质的区别。
测试时计算扩展的核心思想:
传统的缩放定律关注的是训练时的计算资源()与模型性能的关系:
但这一定律暗示,当模型训练完成后,其能力就固定了。测试时计算扩展挑战了这一假设,主张在推理阶段也可以通过增加计算来突破模型能力上限。
VDS-TTT 的定位:
VDS-TTT 可以视为测试时计算扩展的一种参数化形式。具体而言:
| 维度 | 测试时计算扩展 | VDS-TTT |
|---|---|---|
| 核心思想 | 增加推理时的计算量(采样更多、更长的思维链) | 在推理时微调模型参数 |
| 方法 | 并行/串行采样 + 验证器选择 | 生成 + 验证器筛选 + 微调 |
| 适应性 | 固定模型,动态计算 | 动态模型,静态计算 |
| 成本结构 | 推理成本高(与采样数量线性相关) | 推理+轻量训练(可积累到模型中) |
| 效果持久性 | 单次推理有效 | 知识可累积,后续查询受益 |
VDS-TTT 的一个关键优势是知识的可累积性:通过测试时训练,模型学到的知识会沉淀到 LoRA 适配器中,使得后续类似的查询可以直接受益。这与传统的测试时计算扩展(每次推理都需要重新计算)形成鲜明对比。
| 维度 | 测试时计算扩展 | VDS-TTT |
|---|---|---|
| 核心思想 | 增加推理时的计算量 | 在推理时微调模型参数 |
| 适应性 | 固定模型,动态计算 | 动态模型,静态计算 |
VDS-TTT 可以视为测试时计算扩展的一种参数化形式。
与 TTT 架构 的区别
| 维度 | TTT 架构 | VDS-TTT |
|---|---|---|
| 核心创新 | 将隐藏状态建模为可学习的ML模型 | 使用验证器筛选伪标签 |
| 应用场景 | 通用序列建模 | 特定领域推理任务 |
与 PRM 的联系
VDS-TTT 的验证器组件与过程奖励模型高度相关,可以结合 PRM 实现更可靠的置信度评估。
伪标签质量管理
噪声注入与过滤机制
VDS-TTT 的一个关键设计决策是如何处理验证器可能产生的误判。验证器并非完美,它可能产生两种类型的错误:
- 假阳性:给错误响应高分,导致错误的伪标签被用于训练。这会直接污染训练数据,导致模型学到错误知识
- 假阴性:给正确响应低分,导致本可利用的样本被丢弃。这会降低训练效率,减缓模型改进速度
VDS-TTT 采用多层机制来缓解这些问题:
1. 置信度阈值 的自适应调整
固定的阈值可能在不同难度的查询上表现不一致。一种改进策略是动态调整阈值:
def adaptive_threshold(query_difficulty: float, base_threshold: float = 0.5) -> float:
"""
根据查询难度自适应调整置信度阈值。
策略:
- 简单查询:可以设置较高阈值,因为正确答案容易识别
- 困难查询:降低阈值,以保留更多潜在有用的样本
Args:
query_difficulty: 查询难度估计(0-1,越大越难)
base_threshold: 基础阈值
Returns:
调整后的阈值
"""
# 难度越高,阈值越低
# 使用线性插值
min_threshold = 0.2
max_threshold = 0.8
adjusted = base_threshold - 0.3 * (query_difficulty - 0.5)
return max(min_threshold, min(max_threshold, adjusted))
def estimate_difficulty(query: str) -> float:
"""
估计查询的难度。
简化实现:基于查询长度、包含的特定模式等特征估计难度。
实际应用中可以使用更复杂的分类器。
"""
# 长度特征
length_score = min(1.0, len(query) / 500)
# 关键词特征
hard_keywords = ["证明", "推导", "综合", "计算复杂度", "分析"]
easy_keywords = ["列出", "定义", "描述", "说明"]
hard_count = sum(1 for kw in hard_keywords if kw in query)
easy_count = sum(1 for kw in easy_keywords if kw in query)
keyword_score = (hard_count - easy_count) / max(1, hard_count + easy_count + 1)
# 综合难度
difficulty = 0.4 * length_score + 0.6 * (keyword_score + 1) / 2
return difficulty2. 多次验证交叉确认
对于高风险场景,可以采用多验证器交叉确认策略:
class MultiVerifierEnsemble:
"""
多验证器集成。
使用多个验证器进行交叉确认,提高筛选的可靠性。
"""
def __init__(self, verifiers: List):
self.verifiers = verifiers
def score_with_confidence(
self,
query: str,
response: str,
answer: str
) -> Tuple[float, float]:
"""
综合评分及置信度。
Returns:
(mean_score, confidence)
confidence 表示多个验证器之间的一致性
"""
import statistics
scores = []
for verifier in self.verifiers:
score = verifier.score(query, response, answer)
scores.append(score)
mean_score = statistics.mean(scores)
# 一致性:标准差越小,一致性越高
if len(scores) > 1:
std_score = statistics.stdev(scores)
confidence = 1.0 - min(1.0, std_score) # 转换为置信度
else:
confidence = 1.0
return mean_score, confidence3. 损失加权
根据验证器分数对训练样本进行加权,减少噪声样本的影响:
def compute_sample_weight(verifier_score: float, method: str = "linear") -> float:
"""
根据验证器分数计算样本权重。
策略:
- "linear": 线性加权 score
- "squared": 平方加权,强化高分样本
- "softmax": softmax 加权
"""
if method == "linear":
return verifier_score
elif method == "squared":
return verifier_score ** 2
elif method == "softmax":
import math
return 0.1 + 0.9 * (1 / (1 + math.exp(-10 * (verifier_score - 0.5))))
else:
return verifier_score4. 课程学习策略
按照查询难度逐步调整训练策略:
class CurriculumVDSTTT:
"""
课程学习版本的 VDS-TTT。
训练策略:
- 早期:只使用高置信度样本(简单查询)
- 中期:逐步降低阈值,包含中等难度样本
- 后期:使用全部样本,包括一些边缘案例
"""
def __init__(self, base_vds_ttt):
self.base = base_vds_ttt
self.phase = 0 # 0: 早期, 1: 中期, 2: 后期
def get_current_threshold(self) -> float:
"""根据当前阶段返回合适的阈值。"""
thresholds = {
0: 0.8, # 早期:严格筛选
1: 0.6, # 中期:适度放宽
2: 0.4 # 后期:保留更多样本
}
return thresholds[self.phase]
def update_phase(self, iteration: int, total_iterations: int):
"""更新训练阶段。"""
progress = iteration / total_iterations
if progress < 0.3:
self.phase = 0
elif progress < 0.7:
self.phase = 1
else:
self.phase = 2实践指南与最佳实践
超参数敏感性
VDS-TTT 的性能对以下超参数敏感:
| 参数 | 建议范围 | 影响 |
|---|---|---|
temperature | 0.5 - 1.0 | 控制候选多样性 |
num_samples (N) | 4 - 16 | 更多样本提高正确率,但增加延迟 |
confidence_threshold (τ) | 0.3 - 0.7 | 平衡样本质量和数量 |
lora_rank | 4 - 64 | 影响拟合能力和效率 |
lora_steps (M) | 1 - 10 | 更多步骤提高效果,但增加计算 |
温度参数 的调优建议
温度参数控制采样的随机性,是候选多样性的关键因素:
| 温度范围 | 特点 | 适用场景 |
|---|---|---|
| 低多样性,高确定性 | 需要稳定输出的场景 | |
| 适度多样性 | 平衡场景(推荐起点) | |
| 高多样性,低确定性 | 探索性任务,复杂推理 |
调优建议:从 开始,根据任务复杂度调整。对于需要多种解题思路的数学问题,可以尝试 ;对于事实性问答,可以降低到 。
候选数量 的权衡
候选数量直接影响正确响应的覆盖率:
- 边际效益递减: 从 4 增加到 8 通常带来显著提升,但 从 16 增加到 32 的收益较小
- 计算成本线性增长: 每增加一倍,计算时间大约增加一倍
- 推荐配置:
- 资源受限:
- 平衡配置:(推荐)
- 高精度需求:
领域适配性
VDS-TTT 的一个关键限制是依赖领域特定的验证器:
- 数学任务:可以使用 MATH-Verifier、Lean 证明助手等
- 编程任务:可以使用单元测试、执行反馈
- 通用任务:可以使用 LLM-as-Judge、PRM 等
计算成本分析
VDS-TTT 的计算成本包括:
- 候选生成: 次前向传播
- 验证器评分: 次验证器前向传播
- LoRA 微调: 步梯度更新
总成本约为基线推理的 倍。
伪标签质量管理
噪声过滤机制
VDS-TTT 采用多层机制处理验证器的误判问题:
1. 自适应阈值调整
def adaptive_threshold(query_difficulty: float, base_threshold: float = 0.5) -> float:
"""
根据查询难度自适应调整置信度阈值。
策略:
- 简单查询:设置较高阈值
- 困难查询:降低阈值
"""
min_threshold = 0.2
max_threshold = 0.8
adjusted = base_threshold - 0.3 * (query_difficulty - 0.5)
return max(min_threshold, min(max_threshold, adjusted))2. 损失加权
根据验证器分数对训练样本进行加权:
def compute_sample_weight(verifier_score: float, method: str = "linear") -> float:
"""
根据验证器分数计算样本权重。
"""
if method == "linear":
return verifier_score
elif method == "squared":
return verifier_score ** 2
return verifier_score3. 课程学习策略
按照查询难度逐步调整训练策略:
class CurriculumVDSTTT:
"""
课程学习版本的 VDS-TTT。
"""
def __init__(self, base_vds_ttt):
self.base = base_vds_ttt
self.phase = 0
def get_current_threshold(self) -> float:
thresholds = {0: 0.8, 1: 0.6, 2: 0.4}
return thresholds[self.phase]应用场景与局限性
适用场景
VDS-TTT 特别适合以下场景:
- 分布偏移明显:测试数据与训练数据分布差异大
- 存在可靠验证器:可以获取任务特定的验证信号
- 计算资源充足:能够支持测试时的额外计算
- 样本级适应:需要针对单个查询进行快速适应
当前局限性
尽管 VDS-TTT 在多个方面展现出优势,但其仍然存在一些局限性:
1. 验证器依赖:这是 VDS-TTT 最主要的局限。框架的效果很大程度上取决于验证器的质量。在没有可靠验证器的领域(如开放式问答、创意写作等),VDS-TTT 的应用会受到严重限制。
2. 累积漂移风险:随着测试实例的积累,连续的测试时训练可能导致模型逐渐偏离原始分布。这种现象类似于神经网络训练中的「灾难性遗忘」,可能导致模型在某些类型的查询上性能下降。
3. 冷启动问题:在刚开始应用 VDS-TTT 时,模型的初始性能可能较低,因为还没有积累足够的训练样本。这对于需要即时高性能的应用场景是一个挑战。
4. 超参敏感性:VDS-TTT 涉及多个超参数的协调调整,包括温度、候选数量、阈值、LoRA 配置等。这增加了实际部署的复杂度。
5. 计算成本:虽然相比全量微调已经大大降低,但 VDS-TTT 仍然需要额外的推理和训练开销。在延迟敏感的应用场景中,这可能是一个问题。
6. 领域泛化:VDS-TTT 针对特定领域训练的知识可能难以泛化到其他领域。每个新领域通常都需要重新设计验证器并调整超参数。
未来改进方向
针对上述局限性,可以考虑以下改进方向:
- 通用验证器:开发更通用的验证器,减少对领域特定知识的依赖
- 漂移检测与纠正:引入分布漂移检测机制,自动触发模型重置或纠正
- 元学习:使用元学习方法加速冷启动,让模型能够快速适应新领域
- 自适应超参数:开发自动化的超参数调整机制,降低部署复杂度
- 知识蒸馏:将测试时学到的知识蒸馏回更轻量的模型,减少推理开销
参考文献
相关主题
Footnotes
-
Moradi, M., Amer, H., Mudur, S., Zhang, W., Liu, Y., & Ahmed, W. (2025). Continuous Self-Improvement of Large Language Models by Test-time Training with Verifier-Driven Sample Selection. NeurIPS 2025 Workshop: AI That Keeps Up (CCFM). arXiv:2505.19475. ↩
-
Xiao, G., & Snoek, C. (2024). Test-time adaptation for handling domain shifts. ICML 2024. ↩