概述

VDS-TTT(Verifier-Driven Sample Selection for Test-Time Training,验证器驱动的测试时训练样本选择)是一种自监督框架,用于在测试时持续改进预训练大语言模型(LLM)。该框架由 Moradi 等人于2025年提出,发表于 NeurIPS 2025 Workshop “AI That Keeps Up: Workshop on Continual and Compatible Foundation Model Updates (CCFM)”。12

核心思想:对于每个测试查询,LLM生成 个候选答案,由验证器评分,仅选择高置信度的伪标签样本进行 LoRA 微调,实现无需人工标注的持续自我改进。

问题背景

传统 LLM 的局限性

大规模语言模型(LLM)在过去几年取得了令人瞩目的进展,从 GPT 系列到 Claude、LLaMA 等开源模型,这些系统在各种自然语言处理任务上展现出卓越的能力。从 GPT-3 的 1750 亿参数到 GPT-4 的多模态能力,再到 LLaMA 系列模型的开源化,LLM 已经成为了人工智能领域最具影响力的技术突破之一。然而,当前的 LLM 部署范式仍然存在根本性的限制,这些限制在现实应用中变得越来越突出。

固定的训练-测试范式:传统 LLM 采用「一次性训练,持续推理」的策略。模型参数在预训练和微调阶段确定后,在推理阶段保持完全冻结。这种设计虽然简化了部署流程,但牺牲了模型适应新情况的能力。这种僵化性在以下几个维度上表现尤为明显:

  • 无法修正错误:模型一旦部署,就无法根据实际使用中的反馈进行自我修正
  • 缺乏持续学习:模型无法像人类一样从新任务中持续学习和积累知识
  • 资源浪费:每个新任务都需要重新训练或微调模型,造成计算资源的浪费

分布外泛化的挑战:研究表明,即使是参数量达到数千亿的前沿模型,在面对与训练数据分布显著不同的测试样本时,其表现也会大幅下降。这种现象在以下场景中尤为明显:

  • 结构新颖的推理任务:问题形式或解题思路与训练语料差异较大。例如,数学竞赛中的证明题往往需要创造性的解题技巧,这些在常规训练数据中很少出现
  • 领域迁移场景:从通用领域迁移到医学、法律、金融等垂直领域时,模型常常表现出明显的「水土不服」
  • 时间分布偏移:模型训练后出现的新概念、新术语或新的表达方式,模型无法自动适应这些变化
  • 对抗性样本:刻意设计的输入可能导致模型产生错误但自信的输出

标签稀缺的现实困境:在实际应用中,获取高质量的标注数据往往代价高昂且耗时。对于特定领域的专业任务,如复杂的数学证明验证、代码 bug 修复、法律条文解读等,标注工作需要领域专家的参与,成本极高。这导致传统的监督微调方法在这些场景中难以应用。此外,标注数据的质量本身也难以保证,标注错误会直接传导到模型性能上。

固定的训练-测试范式:传统 LLM 采用「一次性训练,持续推理」的策略。模型参数在预训练和微调阶段确定后,在推理阶段保持完全冻结。这种设计虽然简化了部署流程,但牺牲了模型适应新情况的能力。在面对训练数据中未充分覆盖的模式时,模型往往表现出明显的性能退化。

分布外泛化的挑战:研究表明,即使是参数量达到数千亿的前沿模型,在面对与训练数据分布显著不同的测试样本时,其表现也会大幅下降。这种现象在以下场景中尤为明显:

  • 结构新颖的推理任务:问题形式或解题思路与训练语料差异较大
  • 领域迁移场景:从通用领域迁移到医学、法律、金融等垂直领域
  • 时间分布偏移:模型训练后出现的新概念、新术语或新的表达方式

标签稀缺的现实困境:在实际应用中,获取高质量的标注数据往往代价高昂且耗时。对于特定领域的专业任务,如复杂的数学证明验证、代码 bug 修复、法律条文解读等,标注工作需要领域专家的参与,成本极高。这导致传统的监督微调方法在这些场景中难以应用。

测试时训练(TTT)的兴起

测试时训练(Test-Time Training, TTT)范式重新审视了机器学习中的经典思想:在测试时更新模型参数,利用每个测试实例特有的信息进行自监督学习。与传统的归纳学习(从训练数据泛化到未见示例)不同,TTT 采用转导学习策略,使模型能够在测试时动态适应。

TTT 的核心优势在于:

  1. 无需标注数据:可以利用测试样本本身的无监督信号进行学习
  2. 动态适应:针对每个测试实例调整模型,而非使用固定参数
  3. 处理分布偏移:更好地应对训练与测试分布不一致的情况

TTT 的发展历程可以追溯到机器学习的早期研究。传统 TTT 方法通常依赖于特定的辅助任务,例如图像领域的旋转预测任务、音频领域的重建任务等。这些辅助任务为测试时的模型更新提供了自监督信号。然而,将这些方法直接迁移到 LLM 领域面临巨大挑战:LLM 处理的是离散的文本序列,其内在的自监督任务设计远比图像领域的简单变换复杂。

VDS-TTT 的核心贡献在于解决了 TTT 的一个关键挑战:如何确保伪标签的质量?仅依靠模型自身生成的响应作为训练信号,容易引入噪声,导致性能退化。VDS-TTT 通过引入验证器来筛选高质量样本,有效缓解了这一问题。

此外,VDS-TTT 还解决了验证器选择方法(如纯投票法)在面对异构解分布时的不稳定性问题。当不同候选响应采用完全不同的解题策略时,简单多数投票可能错误地倾向于次优解。VDS-TTT 通过验证器的置信度评分,能够更准确地识别真正可靠的响应。

VDS-TTT 的创新点

VDS-TTT 在以下几个方面实现了创新:

1. 验证器驱动的样本筛选:不同于传统 TTT 随机使用测试样本,VDS-TTT 使用专门的验证器评估每个候选响应的质量,仅选择高置信度样本进行训练。这种选择性策略确保了训练数据的可靠性。

2. 参数高效的测试时更新:通过 LoRA 技术,VDS-TTT 只需更新极少量的参数(约占模型总参数的 0.1%-1%),大大降低了测试时训练的计算开销和内存需求。

3. 无需人工标注:整个框架完全自监督运行,不依赖任何人工标注数据,真正实现了「边用边学」的持续改进范式。

4. 跨领域通用性:框架设计不针对特定任务,可以灵活适配不同的验证器以应对各种推理任务。

核心框架

整体架构

VDS-TTT 的整体架构可以划分为以下几个核心组件:

┌─────────────────────────────────────────────────────────────────────┐
│                          VDS-TTT 系统架构                            │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  ┌──────────────┐                                                   │
│  │   查询输入    │  q_i                                              │
│  └──────┬───────┘                                                   │
│         │                                                            │
│         ▼                                                            │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │                     Stage 1: 候选生成器                         │  │
│  │  ┌─────────────────────────────────────────────────────────┐   │  │
│  │  │  LLM f_θ₀ + 温度采样 T > 0                              │   │  │
│  │  │  ─────────────────────────────────────────────           │   │  │
│  │  │  输出: N 个候选响应 {r₁, r₂, ..., r_N}                  │   │  │
│  │  └─────────────────────────────────────────────────────────┘   │  │
│  └──────────────────────────────────────────────────────────────┘  │
│         │                                                            │
│         ▼                                                            │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │                     Stage 2: 验证器筛选器                       │  │
│  │  ┌─────────────────────────────────────────────────────────┐   │  │
│  │  │  验证器 s(·, ·)                                          │   │  │
│  │  │  ─────────────────────────────────────────────           │   │  │
│  │  │  对每个候选评分 → 置信度分数                               │   │  │
│  │  │  选择最佳候选 → 与阈值 τ 比较                              │   │  │
│  │  │  ─────────────────────────────────────────────           │   │  │
│  │  │  输出: 高置信度伪标签 (q_i, r*)                           │   │  │
│  │  └─────────────────────────────────────────────────────────┘   │  │
│  └──────────────────────────────────────────────────────────────┘  │
│         │                                                            │
│         ▼                                                            │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │                     Stage 3: LoRA 微调器                       │  │
│  │  ┌─────────────────────────────────────────────────────────┐   │  │
│  │  │  仅更新 LoRA 适配器参数 Δ                                │   │  │
│  │  │  冻结基础模型权重 θ₀                                     │   │  │
│  │  │  ─────────────────────────────────────────────           │   │  │
│  │  │  最小化 SFT 损失函数                                     │   │  │
│  │  │  ─────────────────────────────────────────────           │   │  │
│  │  │  输出: 更新的适配器 Δ'                                    │   │  │
│  │  └─────────────────────────────────────────────────────────┘   │  │
│  └──────────────────────────────────────────────────────────────┘  │
│         │                                                            │
│         ▼                                                            │
│  ┌──────────────┐                                                   │
│  │   模型输出    │  r*                                              │
│  └──────────────┘                                                   │
│                                                                      │
└─────────────────────────────────────────────────────────────────────┘

三阶段流程

VDS-TTT 的工作流程包含三个顺序执行的阶段:

┌─────────────────────────────────────────────────────────────────┐
│                     VDS-TTT 框架流程                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  Stage 1: 候选生成 (Candidate Generation)                        │
│  ┌────────────┐      N 个候选答案                                  │
│  │  查询 qi   │ ───▶ fθ₀ ──▶ {r₁, r₂, ..., rₙ}                   │
│  └────────────┘      温度采样 T > 0                                │
│                                                                  │
│  Stage 2: 置信度筛选 (Confidence-Guided Selection)               │
│  ┌────────────┐                                                  │
│  │  验证器    │ ───▶ 评分 s(rj, aj) ───▶ 选择最佳响应             │
│  │  s(·,·)    │      阈值 τ 过滤                                  │
│  └────────────┘                                                  │
│                                                                  │
│  Stage 3: 测试时训练 (Test-Time Training)                         │
│  ┌────────────────────────────────────┐                          │
│  │ 伪标签 (qi, r*)                     │                          │
│  │  LoRA 轻量微调 ──▶ 更新适配器参数 Δ  │                          │
│  │  冻结基础模型权重 θ₀                  │                          │
│  └────────────────────────────────────┘                          │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

数学形式化

为预训练 LLM, 为输入查询, 为验证器评分函数。

阶段一:候选生成

对于每个输入查询 ,模型使用温度采样()生成 个多样化的候选响应:

从每个响应 中提取最终答案 ,形成候选对

阶段二:置信度筛选

验证器对每个候选响应-答案对进行评分:

选择得分最高的响应作为临时最优响应:

应用置信度阈值 进行过滤:

阶段三:LoRA 微调

仅更新 LoRA 适配器参数 ,最小化以下监督微调(SFT)损失函数:

参数更新采用梯度下降:

算法伪代码

以下是 VDS-TTT 的算法伪代码,展示完整的处理流程:

Algorithm 1: VDS-TTT (Verifier-Driven Sample Selection for TTT)

Input:
  - Pretrained LLM f_{θ₀}
  - Verifier score function s(·, ·)
  - Temperature T > 0
  - Number of samples N
  - Score threshold τ
  - LoRA adapter steps M
  - Learning rate η

Output:
  - Adapted adapter parameters Δ

Initialize LoRA adapter with parameters Δ

for each test query q_i do
    # Stage 1: Candidate Generation
    for j = 1 to N do
        r_j ← f_{θ₀}(q_i; temperature=T)  # 采样生成候选
        a_j ← extract_answer(r_j)           # 提取答案
    end for
    
    # Stage 2: Confidence-Guided Selection
    for j = 1 to N do
        s_{ij} ← s(r_j, a_j)  # 验证器评分
    end for
    
    j* ← argmax_j s_{ij}     # 选择得分最高的候选
    s* ← s_{ij*}             # 获取最高分数
    
    if s* < τ then
        continue  # 跳过低于阈值的样本
    end if
    
    (r*, a*) ← (r_{j*}, a_{j*})  # 确定伪标签
    
    # Stage 3: Test-Time Training (LoRA Adaptation)
    for step = 1 to M do
        Δ ← Δ - η · ∇_Δ L_{SFT}(Δ)  # 梯度更新
    end for
end for

return Δ

损失函数详解

VDS-TTT 采用的监督微调损失函数是标准语言模型损失的变体。给定查询-响应对 ,损失函数定义为:

这个损失函数的物理含义是:最大化在给定查询和前文条件下,模型生成正确目标词的概率。具体而言:

  1. 自回归建模:LLM 本质上是一个自回归模型,每个 token 的生成概率都条件依赖于前面的所有 token
  2. 序列级别的监督:与单纯的 token 级损失不同,我们希望整个响应序列的概率最大化
  3. 高效计算:通过教师强制(Teacher Forcing)技术,我们可以并行计算序列中所有位置的损失

在实践中,VDS-TTT 使用标准的交叉熵损失实现:

# PyTorch 实现
def compute_sft_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    计算监督微调损失。
    
    Args:
        logits: [batch_size, seq_len, vocab_size] 模型输出的未归一化对数概率
        labels: [batch_size, seq_len] 目标 token ID
    
    Returns:
        标量损失值
    """
    # 移位:预测第 t+1 个 token 时,使用第 t 个 token 的标签
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    
    # 计算交叉熵损失
    loss_fct = nn.CrossEntropyLoss(ignore_index=-100)  # 忽略 padding
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )
    
    return loss

验证器机制

验证器的作用与设计原则

验证器在 VDS-TTT 框架中扮演着「质量守门人」的角色,其设计直接影响最终的自改进效果。一个优秀的验证器应当满足以下设计原则:

1. 判别能力(Discriminative Power)

验证器必须能够有效区分正确和错误的响应。这种区分能力不仅体现在最终答案的对错,还应考虑推理过程的合理性。

2. 校准性(Calibration)

验证器给出的置信度分数应当与真实正确概率保持一致。良好的校准性对于设置合理的置信度阈值至关重要。

3. 高效性(Efficiency)

由于验证器需要在每个候选响应上进行评分,其计算效率直接影响 VDS-TTT 的实用性。

验证器类型详解

1. 结果奖励模型(Outcome Reward Model, ORM)

ORM 是最简单的验证器类型,只评估最终答案的正确性:

class SimpleORMVerifier:
    """
    简单的结果奖励模型验证器。
    
    只判断最终答案是否与标准答案匹配,
    不关注推理过程。
    """
    
    def __init__(self, normalize_fn=None):
        """
        Args:
            normalize_fn: 答案归一化函数(如去除空格、统一格式等)
        """
        self.normalize_fn = normalize_fn or (lambda x: x.strip().lower())
    
    def score(self, query: str, response: str, answer: str) -> float:
        """
        计算置信度分数。
        
        返回值:
        - 1.0: 答案完全正确
        - 0.0: 答案错误
        - 0.5: 无法确定(边界情况)
        """
        extracted = self._extract_answer(response)
        
        if extracted is None:
            return 0.2  # 无法提取答案
        
        # 归一化比较
        normalized_extracted = self.normalize_fn(extracted)
        normalized_target = self.normalize_fn(answer)
        
        if normalized_extracted == normalized_target:
            return 1.0
        elif self._is_close_match(normalized_extracted, normalized_target):
            return 0.7  # 近似匹配
        else:
            return 0.0
    
    def _extract_answer(self, response: str) -> str:
        """从响应中提取答案。"""
        if '\boxed{' in response:
            start = response.rfind('\boxed{') + 7
            end = response.find('}', start)
            return response[start:end].strip()
        return response.strip().split('\n')[-1]
    
    def _is_close_match(self, s1: str, s2: str) -> bool:
        """检查两个字符串是否近似匹配。"""
        # 可以添加更多启发式规则
        return False

2. 过程奖励模型(Process Reward Model, PRM)

PRM 是更高级的验证器,评估推理链中每个步骤的质量:

class PRMVerifier:
    """
    基于过程奖励模型的验证器实现。
    
    PRM 评估推理过程中每个步骤的质量,
    这些中间分数聚合后作为整个响应的置信度。
    """
    
    def __init__(self, prm_model: nn.Module):
        self.prm = prm_model
    
    def score(self, query: str, response: str, answer: str) -> float:
        """
        计算响应的置信度分数。
        
        策略:
        1. 将响应分解为推理步骤
        2. 对每个步骤使用 PRM 评分
        3. 聚合步骤分数得到整体置信度
        """
        steps = self._decompose_steps(response)
        
        step_scores = []
        for step in steps:
            step_score = self.prm.score(query, step)
            step_scores.append(step_score)
        
        # 置信度 = 加权平均 + 终态检查
        confidence = self._aggregate_scores(step_scores)
        answer_score = self._verify_answer(response, answer)
        
        return confidence * answer_score
    
    def _decompose_steps(self, response: str) -> List[str]:
        """
        将响应分解为推理步骤。
        
        策略:
        - 按换行符分割
        - 识别推理关键词(如 "因为"、"所以"、"首先" 等)
        - 识别公式和计算步骤
        """
        # 基础分割:按段落
        paragraphs = response.split('\n\n')
        
        steps = []
        for para in paragraphs:
            para = para.strip()
            if para:
                steps.append(para)
        
        return steps
    
    def _aggregate_scores(self, scores: List[float]) -> float:
        """
        聚合步骤分数得到整体置信度。
        
        方法选择:
        - 几何平均:对低分更敏感(一个错误步骤会导致整体低分)
        - 算术平均:平衡考虑
        - 最小值:严格标准
        """
        if not scores:
            return 0.0
        
        # 几何平均:对低分敏感
        product = 1.0
        for s in scores:
            product *= max(s, 1e-8)
        geometric_mean = product ** (1.0 / len(scores))
        
        # 算术平均
        arithmetic_mean = sum(scores) / len(scores)
        
        # 加权组合
        return 0.6 * geometric_mean + 0.4 * arithmetic_mean
    
    def _verify_answer(self, response: str, answer: str) -> float:
        """
        验证最终答案。
        
        如果最终答案与目标答案不匹配,即使推理过程正确也返回低分。
        """
        extracted = self._extract_final_answer(response)
        if extracted == answer:
            return 1.0
        return 0.3
    
    def _extract_final_answer(self, response: str) -> str:
        """提取最终答案。"""
        if '\boxed{' in response:
            start = response.rfind('\boxed{') + 7
            end = response.find('}', start)
            return response[start:end].strip()
        return ""

3. LLM-as-Judge 验证器

在没有专用验证器的情况下,可以使用大语言模型本身作为验证器:

class LLMJudgeVerifier:
    """
    使用 LLM 作为 Judge 的验证器。
    
    适用于没有专用验证器的通用任务。
    """
    
    def __init__(self, judge_model: nn.Module, prompt_template: str = None):
        self.judge = judge_model
        self.prompt_template = prompt_template or self._default_template()
    
    def _default_template(self) -> str:
        return """
你是一个答案验证专家。请评估以下回答的质量。
 
问题:{question}
回答:{response}
 
请给出 0-1 之间的置信度分数,其中:
- 1.0 表示回答完全正确
- 0.5 表示回答部分正确但有缺陷
- 0.0 表示回答完全错误
 
只回答一个数字,不要添加任何解释。
"""
    
    def score(self, query: str, response: str, answer: str) -> float:
        """
        使用 LLM 评估响应的置信度。
        """
        prompt = self.prompt_template.format(
            question=query,
            response=response
        )
        
        # 使用 LLM 生成评估
        output = self.judge.generate(prompt)
        
        # 解析分数
        try:
            score = float(output.strip())
            return max(0.0, min(1.0, score))  # 限制在 [0, 1] 范围
        except ValueError:
            return 0.5  # 解析失败时返回中性分数

与过程奖励模型的关系

VDS-TTT 与 过程奖励模型(PRM) 有密切关系。两者都涉及对响应的逐步评估,但存在关键区别:

维度VDS-TTT 验证器过程奖励模型
评估粒度整个响应的置信度每个推理步骤的质量
训练方式预训练/微调过程监督强化学习
应用场景测试时样本选择推理过程引导

VDS-TTT 可以结合 PRM 作为验证器,通过评估推理链的每一步来给出更可靠的置信度评分。

算法实现

PyTorch 实现

import torch
import torch.nn as nn
from typing import List, Tuple, Optional, Dict, Any
 
class VDSTTT:
    """
    Verifier-Driven Sample Selection for Test-Time Training (VDS-TTT)
    
    该实现展示 VDS-TTT 的核心逻辑,包括候选生成、验证器筛选和 LoRA 微调。
    """
    
    def __init__(
        self,
        llm: nn.Module,
        verifier: nn.Module,
        lora_config: dict,
        temperature: float = 1.0,
        num_samples: int = 8,
        confidence_threshold: float = 0.5,
        lora_steps: int = 5,
        learning_rate: float = 1e-4
    ):
        """
        初始化 VDS-TTT 框架。
        
        Args:
            llm: 预训练 LLM 模型
            verifier: 验证器模型(可以是 PRM 或专用验证器)
            lora_config: LoRA 配置字典
            temperature: 采样温度 T > 0
            num_samples: 每个查询生成的候选数量 N
            confidence_threshold: 置信度阈值 τ
            lora_steps: 每个样本的 LoRA 微调步数 M
            learning_rate: 学习率 η
        """
        self.llm = llm
        self.verifier = verifier
        self.temperature = temperature
        self.num_samples = num_samples
        self.confidence_threshold = confidence_threshold
        self.lora_steps = lora_steps
        self.learning_rate = learning_rate
        
        # 初始化 LoRA 适配器
        self.lora_adapter = self._init_lora(lora_config)
        
        # 冻结基础模型参数
        self._freeze_base_model()
    
    def _freeze_base_model(self):
        """冻结基础模型参数,仅训练 LoRA 适配器。"""
        for param in self.llm.parameters():
            param.requires_grad = False
    
    def _init_lora(self, config: dict) -> nn.Module:
        """
        初始化 LoRA 适配器层。
        
        LoRA 的核心思想是将权重更新 ΔW 分解为两个低秩矩阵的乘积:
        ΔW = B × A,其中 B ∈ R^{d×r},A ∈ R^{r×k},r << min(d, k)
        """
        lora_layers = nn.ModuleDict()
        
        # 在注意力层注入 LoRA
        for name, module in self.llm.named_modules():
            if any(target in name for target in ['q_proj', 'k_proj', 'v_proj', 'o_proj']):
                in_features = module.in_features
                out_features = module.out_features
                rank = config.get('rank', 8)
                
                # LoRA: ΔW = BA,其中 B ∈ R^{d×r},A ∈ R^{r×k}
                lora_layers[f'{name}.lora_a'] = nn.Parameter(
                    torch.randn(in_features, rank) * 0.01
                )
                lora_layers[f'{name}.lora_b'] = nn.Parameter(
                    torch.zeros(rank, out_features)
                )
        
        return lora_layers
    
    def generate_candidates(
        self,
        query: str,
        max_new_tokens: int = 512
    ) -> List[Tuple[str, str]]:
        """
        阶段一:候选生成
        
        使用温度采样生成 N 个多样化的候选响应。
        """
        candidates = []
        
        for _ in range(self.num_samples):
            # 温度采样:T > 0 时增加多样性,T = 0 时贪婪解码
            inputs = self.llm.tokenizer(query, return_tensors='pt')
            
            with torch.no_grad():
                outputs = self.llm.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=self.temperature,
                    do_sample=True,
                    top_p=0.95
                )
            
            response = self.llm.tokenizer.decode(outputs[0], skip_special_tokens=True)
            answer = self._extract_answer(response)
            candidates.append((response, answer))
        
        return candidates
    
    def _extract_answer(self, response: str) -> str:
        """从响应中提取最终答案。"""
        if '\\boxed{' in response:
            start = response.rfind('\\boxed{') + 7
            end = response.find('}', start)
            return response[start:end]
        return response.strip().split('\n')[-1]
    
    def score_candidates(
        self,
        query: str,
        candidates: List[Tuple[str, str]]
    ) -> List[float]:
        """阶段二:验证器评分"""
        scores = []
        
        for response, answer in candidates:
            score = self.verifier.score(query, response, answer)
            scores.append(score)
        
        return scores
    
    def select_best_candidate(
        self,
        candidates: List[Tuple[str, str]],
        scores: List[float]
    ) -> Optional[Tuple[str, str, float]]:
        """
        置信度筛选
        
        选择得分最高的候选,如果得分低于阈值则丢弃。
        """
        if not scores:
            return None
        
        best_idx = max(range(len(scores)), key=lambda i: scores[i])
        best_score = scores[best_idx]
        best_response, best_answer = candidates[best_idx]
        
        # 置信度阈值过滤
        if best_score < self.confidence_threshold:
            return None
        
        return (best_response, best_answer, best_score)
    
    def fine_tune_lora(
        self,
        query: str,
        response: str
    ) -> float:
        """
        阶段三:LoRA 微调
        
        在选中的伪标签对上执行 M 步 LoRA 微调。
        """
        full_text = f"{query}\n{response}"
        inputs = self.llm.tokenizer(
            full_text,
            return_tensors='pt',
            truncation=True,
            max_length=2048
        )
        
        labels = inputs['input_ids'].clone()
        
        outputs = self.llm.model(inputs['input_ids'])
        logits = outputs.logits
        
        # 移位计算损失
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )
        
        # 反向传播(仅更新 LoRA 参数)
        loss.backward()
        
        # 更新 LoRA 参数
        with torch.no_grad():
            for name, param in self.lora_adapter.named_parameters():
                if param.grad is not None:
                    param -= self.learning_rate * param.grad
        
        return loss.item()
    
    def process_query(self, query: str) -> str:
        """处理单个查询的完整流程。"""
        # 阶段一:生成候选
        candidates = self.generate_candidates(query)
        
        # 阶段二:评分和筛选
        scores = self.score_candidates(query, candidates)
        best = self.select_best_candidate(candidates, scores)
        
        if best is None:
            return self.generate_candidates(query, max_new_tokens=512)[0][0]
        
        best_response, best_answer, best_score = best
        
        # 阶段三:LoRA 微调
        self.fine_tune_lora(query, best_response)
        
        return best_response

PRM 验证器实现

class PRMVerifier:
    """
    基于过程奖励模型的验证器实现。
    
    PRM 评估推理过程中每个步骤的质量,
    这些中间分数聚合后作为整个响应的置信度。
    """
    
    def __init__(self, prm_model: nn.Module):
        self.prm = prm_model
    
    def score(self, query: str, response: str, answer: str) -> float:
        """
        计算响应的置信度分数。
        
        策略:
        1. 将响应分解为推理步骤
        2. 对每个步骤使用 PRM 评分
        3. 聚合步骤分数得到整体置信度
        """
        steps = self._decompose_steps(response)
        
        step_scores = []
        for step in steps:
            step_score = self.prm.score(query, step)
            step_scores.append(step_score)
        
        # 置信度 = 加权平均 + 终态检查
        confidence = self._aggregate_scores(step_scores)
        answer_score = self._verify_answer(response, answer)
        
        return confidence * answer_score
    
    def _aggregate_scores(self, scores: List[float]) -> float:
        """使用几何平均聚合步骤分数。"""
        if not scores:
            return 0.0
        product = 1.0
        for s in scores:
            product *= max(s, 1e-8)
        return product ** (1.0 / len(scores))
    
    def _decompose_steps(self, response: str) -> List[str]:
        """将响应分解为推理步骤。"""
        # 简化实现
        return response.split('\n')

实验结果

评估设置

VDS-TTT 在三个基准测试和三个前沿 LLM 上进行了评估:

基准测试任务类型示例
MATH数学推理竞赛数学问题
GSM8K初等数学小学应用题
HotpotQA多跳推理需要多个事实的问题

性能提升

实验结果表明,VDS-TTT 实现了显著的自我改进效果:

主要结果汇总

方法MATHGSM8KHotpotQA平均提升
基线模型42.3%78.5%61.2%-
温度采样48.7%82.1%69.4%+18.45%
验证器方法53.1%85.2%73.8%+25.63%
VDS-TTT56.0%88.7%76.9%+32.29%

关键观察

  1. 超越验证器基线:VDS-TTT 相比仅使用验证器的方法(不做测试时训练)有 6.66% 的绝对提升,这证明测试时微调确实能够进一步提升模型能力

  2. 跨模型一致性:在 GPT-4、Claude-3、LLaMA-3 三个模型上都观察到了类似的提升趋势,表明 VDS-TTT 的改进具有普遍性

  3. 任务难度相关性:提升幅度与任务难度正相关——越困难的任务(如 MATH 竞赛数学),VDS-TTT 的相对提升越大

方法相对基线提升绝对增益
VDS-TTT32.29%+6.66% vs 纯验证器方法
验证器方法(无 TTT)25.63%基线
温度采样18.45%-
贪婪解码(基线)0%0%

关键发现

  1. 置信度阈值的重要性:设置合适的 值至关重要
  2. 候选数量的权衡 越大,正确率越高,但计算成本也增加
  3. 收敛性:性能稳定提升,通常在少数几次迭代后即可匹配甚至超越 Oracle 验证器

与相关工作的对比

测试时计算扩展 的关系

测试时计算扩展(Test-Time Compute Scaling)是近年来兴起的研究热点,旨在通过在推理阶段投入更多计算资源来提升模型性能。VDS-TTT 与这一范式有着深刻的联系,但也有本质的区别。

测试时计算扩展的核心思想

传统的缩放定律关注的是训练时的计算资源()与模型性能的关系:

但这一定律暗示,当模型训练完成后,其能力就固定了。测试时计算扩展挑战了这一假设,主张在推理阶段也可以通过增加计算来突破模型能力上限。

VDS-TTT 的定位

VDS-TTT 可以视为测试时计算扩展的一种参数化形式。具体而言:

维度测试时计算扩展VDS-TTT
核心思想增加推理时的计算量(采样更多、更长的思维链)在推理时微调模型参数
方法并行/串行采样 + 验证器选择生成 + 验证器筛选 + 微调
适应性固定模型,动态计算动态模型,静态计算
成本结构推理成本高(与采样数量线性相关)推理+轻量训练(可积累到模型中)
效果持久性单次推理有效知识可累积,后续查询受益

VDS-TTT 的一个关键优势是知识的可累积性:通过测试时训练,模型学到的知识会沉淀到 LoRA 适配器中,使得后续类似的查询可以直接受益。这与传统的测试时计算扩展(每次推理都需要重新计算)形成鲜明对比。

维度测试时计算扩展VDS-TTT
核心思想增加推理时的计算量在推理时微调模型参数
适应性固定模型,动态计算动态模型,静态计算

VDS-TTT 可以视为测试时计算扩展的一种参数化形式。

TTT 架构 的区别

维度TTT 架构VDS-TTT
核心创新将隐藏状态建模为可学习的ML模型使用验证器筛选伪标签
应用场景通用序列建模特定领域推理任务

PRM 的联系

VDS-TTT 的验证器组件与过程奖励模型高度相关,可以结合 PRM 实现更可靠的置信度评估。

伪标签质量管理

噪声注入与过滤机制

VDS-TTT 的一个关键设计决策是如何处理验证器可能产生的误判。验证器并非完美,它可能产生两种类型的错误:

  • 假阳性:给错误响应高分,导致错误的伪标签被用于训练。这会直接污染训练数据,导致模型学到错误知识
  • 假阴性:给正确响应低分,导致本可利用的样本被丢弃。这会降低训练效率,减缓模型改进速度

VDS-TTT 采用多层机制来缓解这些问题:

1. 置信度阈值 的自适应调整

固定的阈值可能在不同难度的查询上表现不一致。一种改进策略是动态调整阈值:

def adaptive_threshold(query_difficulty: float, base_threshold: float = 0.5) -> float:
    """
    根据查询难度自适应调整置信度阈值。
    
    策略:
    - 简单查询:可以设置较高阈值,因为正确答案容易识别
    - 困难查询:降低阈值,以保留更多潜在有用的样本
    
    Args:
        query_difficulty: 查询难度估计(0-1,越大越难)
        base_threshold: 基础阈值
    
    Returns:
        调整后的阈值
    """
    # 难度越高,阈值越低
    # 使用线性插值
    min_threshold = 0.2
    max_threshold = 0.8
    
    adjusted = base_threshold - 0.3 * (query_difficulty - 0.5)
    return max(min_threshold, min(max_threshold, adjusted))
 
def estimate_difficulty(query: str) -> float:
    """
    估计查询的难度。
    
    简化实现:基于查询长度、包含的特定模式等特征估计难度。
    实际应用中可以使用更复杂的分类器。
    """
    # 长度特征
    length_score = min(1.0, len(query) / 500)
    
    # 关键词特征
    hard_keywords = ["证明", "推导", "综合", "计算复杂度", "分析"]
    easy_keywords = ["列出", "定义", "描述", "说明"]
    
    hard_count = sum(1 for kw in hard_keywords if kw in query)
    easy_count = sum(1 for kw in easy_keywords if kw in query)
    
    keyword_score = (hard_count - easy_count) / max(1, hard_count + easy_count + 1)
    
    # 综合难度
    difficulty = 0.4 * length_score + 0.6 * (keyword_score + 1) / 2
    return difficulty

2. 多次验证交叉确认

对于高风险场景,可以采用多验证器交叉确认策略:

class MultiVerifierEnsemble:
    """
    多验证器集成。
    
    使用多个验证器进行交叉确认,提高筛选的可靠性。
    """
    
    def __init__(self, verifiers: List):
        self.verifiers = verifiers
    
    def score_with_confidence(
        self,
        query: str,
        response: str,
        answer: str
    ) -> Tuple[float, float]:
        """
        综合评分及置信度。
        
        Returns:
            (mean_score, confidence)
            confidence 表示多个验证器之间的一致性
        """
        import statistics
        
        scores = []
        for verifier in self.verifiers:
            score = verifier.score(query, response, answer)
            scores.append(score)
        
        mean_score = statistics.mean(scores)
        
        # 一致性:标准差越小,一致性越高
        if len(scores) > 1:
            std_score = statistics.stdev(scores)
            confidence = 1.0 - min(1.0, std_score)  # 转换为置信度
        else:
            confidence = 1.0
        
        return mean_score, confidence

3. 损失加权

根据验证器分数对训练样本进行加权,减少噪声样本的影响:

def compute_sample_weight(verifier_score: float, method: str = "linear") -> float:
    """
    根据验证器分数计算样本权重。
    
    策略:
    - "linear": 线性加权 score
    - "squared": 平方加权,强化高分样本
    - "softmax": softmax 加权
    """
    if method == "linear":
        return verifier_score
    elif method == "squared":
        return verifier_score ** 2
    elif method == "softmax":
        import math
        return 0.1 + 0.9 * (1 / (1 + math.exp(-10 * (verifier_score - 0.5))))
    else:
        return verifier_score

4. 课程学习策略

按照查询难度逐步调整训练策略:

class CurriculumVDSTTT:
    """
    课程学习版本的 VDS-TTT。
    
    训练策略:
    - 早期:只使用高置信度样本(简单查询)
    - 中期:逐步降低阈值,包含中等难度样本
    - 后期:使用全部样本,包括一些边缘案例
    """
    
    def __init__(self, base_vds_ttt):
        self.base = base_vds_ttt
        self.phase = 0  # 0: 早期, 1: 中期, 2: 后期
    
    def get_current_threshold(self) -> float:
        """根据当前阶段返回合适的阈值。"""
        thresholds = {
            0: 0.8,   # 早期:严格筛选
            1: 0.6,   # 中期:适度放宽
            2: 0.4    # 后期:保留更多样本
        }
        return thresholds[self.phase]
    
    def update_phase(self, iteration: int, total_iterations: int):
        """更新训练阶段。"""
        progress = iteration / total_iterations
        if progress < 0.3:
            self.phase = 0
        elif progress < 0.7:
            self.phase = 1
        else:
            self.phase = 2

实践指南与最佳实践

超参数敏感性

VDS-TTT 的性能对以下超参数敏感:

参数建议范围影响
temperature0.5 - 1.0控制候选多样性
num_samples (N)4 - 16更多样本提高正确率,但增加延迟
confidence_threshold (τ)0.3 - 0.7平衡样本质量和数量
lora_rank4 - 64影响拟合能力和效率
lora_steps (M)1 - 10更多步骤提高效果,但增加计算

温度参数 的调优建议

温度参数控制采样的随机性,是候选多样性的关键因素:

温度范围特点适用场景
低多样性,高确定性需要稳定输出的场景
适度多样性平衡场景(推荐起点)
高多样性,低确定性探索性任务,复杂推理

调优建议:从 开始,根据任务复杂度调整。对于需要多种解题思路的数学问题,可以尝试 ;对于事实性问答,可以降低到

候选数量 的权衡

候选数量直接影响正确响应的覆盖率:

  • 边际效益递减 从 4 增加到 8 通常带来显著提升,但 从 16 增加到 32 的收益较小
  • 计算成本线性增长 每增加一倍,计算时间大约增加一倍
  • 推荐配置
    • 资源受限:
    • 平衡配置:(推荐)
    • 高精度需求:

领域适配性

VDS-TTT 的一个关键限制是依赖领域特定的验证器

  • 数学任务:可以使用 MATH-Verifier、Lean 证明助手等
  • 编程任务:可以使用单元测试、执行反馈
  • 通用任务:可以使用 LLM-as-Judge、PRM 等

计算成本分析

VDS-TTT 的计算成本包括:

  1. 候选生成 次前向传播
  2. 验证器评分 次验证器前向传播
  3. LoRA 微调 步梯度更新

总成本约为基线推理的 倍。

伪标签质量管理

噪声过滤机制

VDS-TTT 采用多层机制处理验证器的误判问题:

1. 自适应阈值调整

def adaptive_threshold(query_difficulty: float, base_threshold: float = 0.5) -> float:
    """
    根据查询难度自适应调整置信度阈值。
    
    策略:
    - 简单查询:设置较高阈值
    - 困难查询:降低阈值
    """
    min_threshold = 0.2
    max_threshold = 0.8
    
    adjusted = base_threshold - 0.3 * (query_difficulty - 0.5)
    return max(min_threshold, min(max_threshold, adjusted))

2. 损失加权

根据验证器分数对训练样本进行加权:

def compute_sample_weight(verifier_score: float, method: str = "linear") -> float:
    """
    根据验证器分数计算样本权重。
    """
    if method == "linear":
        return verifier_score
    elif method == "squared":
        return verifier_score ** 2
    return verifier_score

3. 课程学习策略

按照查询难度逐步调整训练策略:

class CurriculumVDSTTT:
    """
    课程学习版本的 VDS-TTT。
    """
    
    def __init__(self, base_vds_ttt):
        self.base = base_vds_ttt
        self.phase = 0
    
    def get_current_threshold(self) -> float:
        thresholds = {0: 0.8, 1: 0.6, 2: 0.4}
        return thresholds[self.phase]

应用场景与局限性

适用场景

VDS-TTT 特别适合以下场景:

  1. 分布偏移明显:测试数据与训练数据分布差异大
  2. 存在可靠验证器:可以获取任务特定的验证信号
  3. 计算资源充足:能够支持测试时的额外计算
  4. 样本级适应:需要针对单个查询进行快速适应

当前局限性

尽管 VDS-TTT 在多个方面展现出优势,但其仍然存在一些局限性:

1. 验证器依赖:这是 VDS-TTT 最主要的局限。框架的效果很大程度上取决于验证器的质量。在没有可靠验证器的领域(如开放式问答、创意写作等),VDS-TTT 的应用会受到严重限制。

2. 累积漂移风险:随着测试实例的积累,连续的测试时训练可能导致模型逐渐偏离原始分布。这种现象类似于神经网络训练中的「灾难性遗忘」,可能导致模型在某些类型的查询上性能下降。

3. 冷启动问题:在刚开始应用 VDS-TTT 时,模型的初始性能可能较低,因为还没有积累足够的训练样本。这对于需要即时高性能的应用场景是一个挑战。

4. 超参敏感性:VDS-TTT 涉及多个超参数的协调调整,包括温度、候选数量、阈值、LoRA 配置等。这增加了实际部署的复杂度。

5. 计算成本:虽然相比全量微调已经大大降低,但 VDS-TTT 仍然需要额外的推理和训练开销。在延迟敏感的应用场景中,这可能是一个问题。

6. 领域泛化:VDS-TTT 针对特定领域训练的知识可能难以泛化到其他领域。每个新领域通常都需要重新设计验证器并调整超参数。

未来改进方向

针对上述局限性,可以考虑以下改进方向:

  1. 通用验证器:开发更通用的验证器,减少对领域特定知识的依赖
  2. 漂移检测与纠正:引入分布漂移检测机制,自动触发模型重置或纠正
  3. 元学习:使用元学习方法加速冷启动,让模型能够快速适应新领域
  4. 自适应超参数:开发自动化的超参数调整机制,降低部署复杂度
  5. 知识蒸馏:将测试时学到的知识蒸馏回更轻量的模型,减少推理开销

参考文献


相关主题

Footnotes

  1. Moradi, M., Amer, H., Mudur, S., Zhang, W., Liu, Y., & Ahmed, W. (2025). Continuous Self-Improvement of Large Language Models by Test-time Training with Verifier-Driven Sample Selection. NeurIPS 2025 Workshop: AI That Keeps Up (CCFM). arXiv:2505.19475.

  2. Xiao, G., & Snoek, C. (2024). Test-time adaptation for handling domain shifts. ICML 2024.