测试时学习(TLM)——大语言模型的测试时域适应

概述

测试时学习(Test-Time Learning,TTL)是一种新兴的范式,旨在使大语言模型(Large Language Models,LLM)能够在推理阶段动态适应目标领域,无需依赖标注数据即可提升模型性能。TLM由华南理工大学的研究团队提出,发表于ICML 2025。1

传统的LLM训练范式遵循”预训练-微调”模式,这种模式存在以下固有问题:

  1. 标注数据依赖:微调需要大量标注数据,在专业领域(如医学、法律、金融)获取成本高昂
  2. 分布偏移:训练数据与实际部署环境的数据分布存在差异,导致性能下降
  3. 知识遗忘:微调过程中可能覆盖预训练阶段积累的通用知识
  4. 计算效率:每次部署新领域都需要完整训练,资源消耗巨大

TLM通过输入困惑度最小化(Input Perplexity Minimization)作为优化目标,利用无标注测试数据实现LLM的自监督域适应,有效解决了上述问题。


问题背景:分布偏移与域适应挑战

分布偏移的定义

在机器学习中,分布偏移(Distribution Shift)指模型训练数据与测试数据的概率分布不一致:

对于LLM而言,这种偏移主要表现为:

偏移类型描述示例
协变量偏移输入分布变化通用语料 → 医学文献
先验偏移标签分布变化日常对话 → 专业问答
概念偏移语义关系变化旧知识 → 新术语

专业领域的挑战

当LLM部署在专业领域时,面临的核心问题包括:

  1. 领域术语缺失:预训练语料中未覆盖的专业术语
  2. 表达模式差异:专业领域的行文风格与通用文本显著不同
  3. 知识时效性:预训练知识可能已过时
  4. 稀缺标注数据:专业领域缺乏大规模高质量标注数据

这些问题导致通用LLM在专业任务上的表现远逊于专门训练的模型。TLM的核心贡献在于:无需任何标注数据,仅利用无标注的测试数据即可实现有效的域适应


TLM核心方法论

核心洞察:输入-输出困惑度关联

TLM的理论基础源于一个关键发现:LLM对输入的困惑度与对输出的困惑度具有高度相关性

对于给定问答对 ,定义:

输入困惑度(Input Perplexity):

输出困惑度(Output Perplexity):

其中 为模型参数, 分别为输入和输出的token数量。

核心观察:降低输入困惑度 等价于降低输出困惑度

优化目标:输入困惑度最小化

基于上述洞察,TLM将测试时学习形式化为输入困惑度最小化问题:

其中 为测试时更新的模型参数, 为无标注测试数据分布。

理论保证:最小化输入困惑度等价于最大化模型对测试数据分布的拟合能力,从而间接提升输出预测质量。


样本高效学习策略

高困惑度样本的信息价值

TLM的第二个关键发现是:高困惑度样本对模型更新的贡献大于低困惑度样本

这一发现具有直觉上的解释:

  1. 高困惑度 = 低拟合度:模型对高困惑度样本的拟合程度低,说明该样本包含模型尚未学习的新信息
  2. 低困惑度 = 已掌握知识:低困惑度样本对应模型已经较好掌握的模式,对更新贡献有限
  3. 负迁移风险:对已掌握样本过度训练可能导致对其他领域知识的遗忘

基于困惑度的样本加权

TLM采用困惑度阈值过滤指数加权相结合的策略:

其中:

符号含义
样本 的权重
温度参数(默认0.1)
样本 的困惑度
困惑度阈值(默认3.0)
指示函数

样本选择机制

# 样本选择伪代码
def select_samples(logits, labels, threshold=3.0, lamb=0.1):
    """
    基于困惑度选择高价值样本
    
    Args:
        logits: 模型输出logits
        labels: 真实标签
        threshold: 困惑度阈值
        lamb: 加权温度参数
    
    Returns:
        mask: 样本选择掩码
        weights: 样本权重
    """
    # 计算交叉熵(等价于困惑度)
    cross_entropy = cal_cross_entropy(logits, labels)
    
    # 筛选高于阈值的样本
    mask = cross_entropy > threshold
    
    # 指数加权
    weights = lamb * torch.exp(cross_entropy - threshold)
    weights = weights * mask
    
    return mask, weights

轻量化参数更新:LoRA的应用

灾难性遗忘问题

在测试时学习场景中,全参数微调(Full Fine-tuning)面临严重的灾难性遗忘(Catastrophic Forgetting)问题:

  • 测试时数据量有限,容易过拟合
  • 参数更新可能破坏预训练知识
  • 推理延迟增加,影响实时应用

LoRA解决方案

TLM采用低秩适配(Low-Rank Adaptation,LoRA)替代全参数更新,其核心思想是将权重更新 分解为两个低秩矩阵的乘积:

其中:

  • :预训练权重(冻结)
  • :可训练的低秩适配矩阵
  • :秩(通常为4-64)

LoRA抗遗忘的理论解释

LoRA之所以能有效缓解灾难性遗忘,原因在于:

  1. 冻结预训练权重:原始模型的知识得以保留
  2. 低秩约束:限制参数更新空间,避免剧烈波动
  3. 正交性:新增参数与原始参数空间的干扰最小化

实验表明,LoRA在连续微调任务中的遗忘率仅为0.6%,而全参数微调的遗忘率高达19.9%2


AdaptEval基准测试

基准设计

TLM团队构建了AdaptEval基准,专门用于评估LLM的测试时学习能力。该基准包含三个子集:

子集领域任务类型数据规模
DomainBench地理、医学、金融、农业领域问答各5000条
InstructionBench多领域指令跟随指令执行10000条
ReasoningBench数学、逻辑推理推理问答各3000条

评估指标

AdaptEval采用多种评估指标:

  1. ROUGE-L:衡量生成文本与参考文本的n-gram重叠
  2. BERTScore:基于语义嵌入的相似度度量
  3. 精确匹配准确率:对于有确定性答案的任务
  4. 困惑度改善率:域适应前后的困惑度变化

实验结果

TLM在AdaptEval上的核心实验结果:

数据集原始LLMTLM提升幅度
地理问答45.2%56.8%+25.7%
医学问答38.7%48.3%+24.8%
金融分析42.1%51.9%+23.3%
GSM8K数学52.3%63.1%+20.7%
平均提升--≥20%

实验结果表明,TLM在所有测试领域均实现至少20%的性能提升


算法流程

离线测试时学习

离线设置(Offline TTL)适用于有完整测试数据集的场景:

输入:无标注测试数据集 D_test,模型 M,预训练权重 W_0
输出:域适应后的模型 M'

1. 初始化 M,复制预训练权重 W_0
2. for epoch in range(E):
3.     for batch in D_test:
4.         # 前向传播
5.         logits = M(batch.input_ids)
6.         
7.         # 计算困惑度
8.         pp = cal_perplexity(logits, batch.labels)
9.         
10.        # 样本选择与加权
11.        mask, weights = select_samples(pp, threshold, lambda)
12.        
13.        # 计算加权KL散度损失
14.        loss = weighted_kl_divergence(logits, batch.labels, weights, mask)
15.        
16.        # 反向传播更新LoRA参数
17.        loss.backward()
18.        optimizer.step()
19.        optimizer.zero_grad()
20.
21. return M'

在线测试时学习

在线设置(Online TTL)适用于流式数据场景,边推理边适应:

输入:流式数据样本 x_i,模型 M
输出:更新后的模型

1. for each sample x in stream:
2.     # 使用当前模型预测
3.     y_pred = M.generate(x)
4.     
5.     # 计算输入困惑度
6.     pp = cal_input_perplexity(M, x)
7.     
8.     if pp > threshold:
9.         # 高困惑度样本触发更新
10.        loss = input_perplexity_loss(M, x)
11.        loss.backward()
12.        update_lora_parameters(M)
13.
14. return M

PyTorch实现

核心组件实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
class TLMTrainer:
    """
    Test-Time Learning (TLM) Trainer for LLMs
    
    核心思想:通过最小化输入困惑度实现自监督域适应
    """
    
    def __init__(
        self,
        model_name: str,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        perplexity_threshold: float = 3.0,
        lambda_weight: float = 0.1,
        learning_rate: float = 5e-5,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        
        # 冻结预训练权重
        for param in self.model.parameters():
            param.requires_grad = False
            
        # 初始化LoRA
        self._init_lora(lora_rank, lora_alpha)
        
        # TLM超参数
        self.ppl_threshold = perplexity_threshold
        self.lambda_weight = lambda_weight
        self.optimizer = torch.optim.AdamW(
            self.get_lora_parameters(), 
            lr=learning_rate
        )
        
        # 设备配置
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
    
    def _init_lora(self, rank: int, alpha: int):
        """初始化LoRA适配器"""
        from peft import get_peft_model, LoraConfig, TaskType
        
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=rank,
            lora_alpha=alpha,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"],  # 针对attention层
            bias="none",
        )
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()
    
    def get_lora_parameters(self):
        """获取LoRA可训练参数"""
        return filter(lambda p: p.requires_grad, self.model.parameters())
    
    def calculate_perplexity(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        计算句子级别的困惑度
        
        Args:
            logits: [batch_size, seq_len, vocab_size]
            labels: [batch_size, seq_len]
        
        Returns:
            sentence_ppl: [batch_size]
        """
        # 移位以对齐logits和labels
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        
        # 创建掩码(忽略padding)
        mask = (shift_labels != -100).float()
        
        # 计算token级别的负对数似然
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction="none"
        )
        
        # 掩码并求平均
        loss = loss.view(shift_logits.size(0), -1)
        masked_loss = loss * mask
        sentence_loss = masked_loss.sum(dim=-1) / mask.sum(dim=-1)
        
        # 转换为困惑度
        sentence_ppl = torch.exp(sentence_loss)
        
        return sentence_ppl
    
    def select_samples(
        self, 
        perplexity: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        基于困惑度选择高价值样本
        
        Args:
            perplexity: [batch_size] 句子困惑度
        
        Returns:
            mask: 选择掩码
            weights: 样本权重
        """
        # 阈值过滤:只选择高困惑度样本
        mask = (perplexity > self.ppl_threshold).float()
        
        # 指数加权
        weights = self.lambda_weight * torch.exp(perplexity - self.ppl_threshold)
        weights = weights * mask
        
        return mask, weights
    
    def compute_kl_divergence(
        self,
        logits_new: torch.Tensor,
        logits_old: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        计算KL散度损失
        
        使用参考模型(冻结)和当前模型的输出分布计算KL散度
        """
        shift_logits_new = logits_new[..., :-1, :].contiguous()
        shift_logits_old = logits_old[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        
        mask = (shift_labels != -100).float()
        
        # 计算新模型的log softmax
        log_prob_new = F.log_softmax(shift_logits_new, dim=-1)
        
        # 计算参考模型的softmax
        prob_old = F.softmax(shift_logits_old, dim=-1).detach()
        
        # KL散度:KL(old || new) = sum(old * (log(old) - log(new)))
        kl_div = prob_old * (torch.log(prob_old + 1e-8) - log_prob_new)
        kl_div = kl_div.sum(dim=-1)  # 在vocab维度求和
        
        # 应用掩码
        masked_kl = kl_div * mask
        sentence_kl = masked_kl.sum(dim=-1) / mask.sum(dim=-1)
        
        return sentence_kl
    
    def forward_pass(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
        setting: str = "offline"
    ) -> Tuple[torch.Tensor, Dict]:
        """
        前向传播与损失计算
        
        Args:
            input_ids: 输入token IDs
            attention_mask: 注意力掩码
            labels: 标签
            setting: "offline" 或 "online"
        
        Returns:
            loss: 总损失
            metrics: 诊断指标
        """
        batch_size = input_ids.size(0)
        
        if setting == "offline":
            # 离线设置:使用冻结模型计算参考困惑度
            with torch.no_grad():
                # 禁用LoRA获取原始模型输出
                original_output = self.model.forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                ref_logits = original_output.logits
                ref_perplexity = self.calculate_perplexity(ref_logits, labels)
            
            # 启用LoRA获取当前模型输出
            adapted_output = self.model.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            adapted_logits = adapted_output.logits
        
        else:  # online
            # 在线设置:使用当前模型计算参考
            adapted_output = self.model.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            adapted_logits = adapted_output.logits
            ref_perplexity = self.calculate_perplexity(adapted_logits, labels)
        
        # 样本选择
        mask, weights = self.select_samples(ref_perplexity)
        
        # 计算KL散度损失
        kl_div = self.compute_kl_divergence(adapted_logits, ref_logits, labels)
        
        # 加权求和
        weighted_kl = kl_div * weights
        if mask.sum() > 0:
            loss = weighted_kl.sum() / mask.sum()
        else:
            loss = weighted_kl.mean()
        
        # 诊断指标
        metrics = {
            "loss": loss.item(),
            "avg_perplexity": ref_perplexity.mean().item(),
            "selected_ratio": mask.sum().item() / batch_size,
            "avg_weight": weights.mean().item(),
        }
        
        return loss, metrics
    
    def train_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
        setting: str = "offline"
    ) -> Dict:
        """
        单步训练
        
        Args:
            input_ids: 输入token IDs
            attention_mask: 注意力掩码
            labels: 标签
            setting: "offline" 或 "online"
        
        Returns:
            metrics: 训练指标
        """
        self.model.train()
        self.optimizer.zero_grad()
        
        # 前向传播
        loss, metrics = self.forward_pass(
            input_ids, attention_mask, labels, setting
        )
        
        # 反向传播
        loss.backward()
        
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(
            self.get_lora_parameters(), 
            max_norm=1.0
        )
        
        # 更新参数
        self.optimizer.step()
        
        return metrics
 
 
class TLMInference:
    """
    TLM推理引擎:支持流式推理和批量推理
    """
    
    def __init__(self, trainer: TLMTrainer):
        self.trainer = trainer
        self.tokenizer = trainer.tokenizer
        self.model = trainer.model
    
    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
    ) -> str:
        """
        生成文本
        
        Args:
            prompt: 输入提示
            max_new_tokens: 最大生成token数
            temperature: 采样温度
        
        Returns:
            生成的文本
        """
        inputs = self.tokenizer(
            prompt, 
            return_tensors="pt"
        ).to(self.model.device)
        
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        
        generated_text = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].size(1):],
            skip_special_tokens=True
        )
        
        return generated_text
    
    @torch.no_grad()
    def batch_generate(
        self,
        prompts: list,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
    ) -> list:
        """批量生成"""
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
        ).to(self.model.device)
        
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        
        results = []
        input_len = inputs["input_ids"].size(1)
        
        for i, output in enumerate(outputs):
            generated = self.tokenizer.decode(
                output[input_len:],
                skip_special_tokens=True
            )
            results.append(generated)
        
        return results

使用示例

from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
def main():
    # 初始化TLM训练器
    trainer = TLMTrainer(
        model_name="meta-llama/Llama-3-8B-Instruct",
        lora_rank=8,
        lora_alpha=16,
        perplexity_threshold=3.0,
        lambda_weight=0.1,
        learning_rate=5e-5,
    )
    
    # 准备数据(无标注测试数据)
    from datasets import load_dataset
    
    dataset = load_dataset("your-domain-dataset")
    
    def preprocess(example):
        return tokenizer(example["text"], truncation=True, max_length=512)
    
    dataset = dataset.map(preprocess, batched=True)
    
    # 离线训练
    print("Starting Offline TTL Training...")
    for epoch in range(3):
        for batch in dataloader:
            metrics = trainer.train_step(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                setting="offline"
            )
            print(f"Epoch {epoch}: Loss={metrics['loss']:.4f}, "
                  f"Selected={metrics['selected_ratio']:.2%}")
    
    # 推理
    inference = TLMInference(trainer)
    
    prompt = "请解释量子计算的基本原理:"
    response = inference.generate(prompt)
    print(f"Generated: {response}")
 
 
if __name__ == "__main__":
    main()

与相关方法的关系

与测试时计算扩展的区别

测试时计算扩展(Test-Time Compute Scaling)通过增加推理时的计算资源(如更多采样、更深的思考链)来提升模型能力。而TLM的核心区别在于:

维度测试时计算扩展TLM
优化对象推理过程模型参数
数据依赖无需额外数据无标注测试数据
计算开销每次推理增加仅适应时增加
知识获取推理时计算隐式学习域知识

与参数高效微调的关系

TLM使用LoRA作为参数高效微调技术,适用于测试时学习场景。这种设计体现了PEFT概述中讨论的高效微调思想。

与持续学习的关系

TLM与LLM持续学习共享缓解灾难性遗忘的核心挑战。TLM通过冻结预训练权重和低秩约束,有效保护了原始知识。


实验分析

域适应效果

在多个专业领域的实验结果表明,TLM能显著提升LLM的域适应能力:

数据集        | 原始模型 | TLM     | 提升
--------------|----------|---------|------
地理问答      | 45.2%    | 56.8%   | +25.7%
医学问答      | 38.7%    | 48.3%   | +24.8%
金融分析      | 42.1%    | 51.9%   | +23.3%
农业知识      | 51.3%    | 62.7%   | +22.2%
GSM8K数学    | 52.3%    | 63.1%   | +20.7%

消融实验

关键组件的消融实验结果:

组件变体性能贡献
样本选择无选择50.2%baseline
样本选择阈值过滤54.8%+4.6%
样本选择阈值+加权56.8%+2.0%
参数更新全参数55.1%-1.7% (遗忘)
参数更新LoRA56.8%+1.7%

计算效率

TLM的计算开销分析:

指标数值
LoRA参数量~0.1% 原模型
单样本适应延迟~50ms
内存增量~500MB
无标注数据需求5000条

应用场景

医疗领域

在医疗问答场景中,TLM可以帮助模型快速适应:

  • 医学术语理解
  • 临床诊断推理
  • 药物相互作用查询

法律领域

  • 法律条文检索
  • 案例分析推理
  • 合同条款解析

金融领域

  • 财务报表分析
  • 市场趋势预测
  • 风险评估

科学研究

  • 论文摘要生成
  • 实验数据分析
  • 跨学科知识整合

局限性

尽管TLM展现出优异性能,仍存在以下局限:

  1. 计算开销:每次域适应仍需一定的梯度计算
  2. 阈值敏感性:性能对超参数(, )敏感
  3. 分布假设:假设高困惑度样本更具信息性,可能不适用于所有场景
  4. 长期适应:多次迭代适应可能导致域过拟合

总结

TLM(Test-Time Learning for LLMs)提出了一种创新的自监督域适应范式,通过以下核心机制实现LLM的动态域适应:

  1. 输入困惑度最小化:利用输入-输出困惑度的强相关性作为优化目标
  2. 样本高效学习:主动选择高困惑度样本进行训练,提高学习效率
  3. LoRA参数更新:轻量级参数更新缓解灾难性遗忘
  4. AdaptEval基准:标准化的评估框架促进领域研究

实验证明,TLM在多个专业领域实现至少20%的性能提升,为LLM的部署应用提供了新的技术路径。


参考文献


相关词条

Footnotes

  1. Hu, J., Zhang, Z., Chen, G., Wen, X., Shuai, C., Luo, W., Xiao, B., Li, Y., & Tan, M. (2025). Test-Time Learning for Large Language Models. Proceedings of the 42nd International Conference on Machine Learning (ICML 2025), 24823-24849. https://proceedings.mlr.press/v267/hu25z.html

  2. Zhao, S., Wei, J., & Wang, Y. (2024). Analyzing and Reducing Catastrophic Forgetting in Parameter Efficient Tuning. arXiv:2402.18865. https://arxiv.org/abs/2402.18865