测试时学习(TLM)——大语言模型的测试时域适应
概述
测试时学习(Test-Time Learning,TTL)是一种新兴的范式,旨在使大语言模型(Large Language Models,LLM)能够在推理阶段动态适应目标领域,无需依赖标注数据即可提升模型性能。TLM由华南理工大学的研究团队提出,发表于ICML 2025。1
传统的LLM训练范式遵循”预训练-微调”模式,这种模式存在以下固有问题:
- 标注数据依赖:微调需要大量标注数据,在专业领域(如医学、法律、金融)获取成本高昂
- 分布偏移:训练数据与实际部署环境的数据分布存在差异,导致性能下降
- 知识遗忘:微调过程中可能覆盖预训练阶段积累的通用知识
- 计算效率:每次部署新领域都需要完整训练,资源消耗巨大
TLM通过输入困惑度最小化(Input Perplexity Minimization)作为优化目标,利用无标注测试数据实现LLM的自监督域适应,有效解决了上述问题。
问题背景:分布偏移与域适应挑战
分布偏移的定义
在机器学习中,分布偏移(Distribution Shift)指模型训练数据与测试数据的概率分布不一致:
对于LLM而言,这种偏移主要表现为:
| 偏移类型 | 描述 | 示例 |
|---|---|---|
| 协变量偏移 | 输入分布变化 | 通用语料 → 医学文献 |
| 先验偏移 | 标签分布变化 | 日常对话 → 专业问答 |
| 概念偏移 | 语义关系变化 | 旧知识 → 新术语 |
专业领域的挑战
当LLM部署在专业领域时,面临的核心问题包括:
- 领域术语缺失:预训练语料中未覆盖的专业术语
- 表达模式差异:专业领域的行文风格与通用文本显著不同
- 知识时效性:预训练知识可能已过时
- 稀缺标注数据:专业领域缺乏大规模高质量标注数据
这些问题导致通用LLM在专业任务上的表现远逊于专门训练的模型。TLM的核心贡献在于:无需任何标注数据,仅利用无标注的测试数据即可实现有效的域适应。
TLM核心方法论
核心洞察:输入-输出困惑度关联
TLM的理论基础源于一个关键发现:LLM对输入的困惑度与对输出的困惑度具有高度相关性。
对于给定问答对 ,定义:
输入困惑度(Input Perplexity):
输出困惑度(Output Perplexity):
其中 为模型参数, 和 分别为输入和输出的token数量。
核心观察:降低输入困惑度 等价于降低输出困惑度 。
优化目标:输入困惑度最小化
基于上述洞察,TLM将测试时学习形式化为输入困惑度最小化问题:
其中 为测试时更新的模型参数, 为无标注测试数据分布。
理论保证:最小化输入困惑度等价于最大化模型对测试数据分布的拟合能力,从而间接提升输出预测质量。
样本高效学习策略
高困惑度样本的信息价值
TLM的第二个关键发现是:高困惑度样本对模型更新的贡献大于低困惑度样本。
这一发现具有直觉上的解释:
- 高困惑度 = 低拟合度:模型对高困惑度样本的拟合程度低,说明该样本包含模型尚未学习的新信息
- 低困惑度 = 已掌握知识:低困惑度样本对应模型已经较好掌握的模式,对更新贡献有限
- 负迁移风险:对已掌握样本过度训练可能导致对其他领域知识的遗忘
基于困惑度的样本加权
TLM采用困惑度阈值过滤与指数加权相结合的策略:
其中:
| 符号 | 含义 |
|---|---|
| 样本 的权重 | |
| 温度参数(默认0.1) | |
| 样本 的困惑度 | |
| 困惑度阈值(默认3.0) | |
| 指示函数 |
样本选择机制
# 样本选择伪代码
def select_samples(logits, labels, threshold=3.0, lamb=0.1):
"""
基于困惑度选择高价值样本
Args:
logits: 模型输出logits
labels: 真实标签
threshold: 困惑度阈值
lamb: 加权温度参数
Returns:
mask: 样本选择掩码
weights: 样本权重
"""
# 计算交叉熵(等价于困惑度)
cross_entropy = cal_cross_entropy(logits, labels)
# 筛选高于阈值的样本
mask = cross_entropy > threshold
# 指数加权
weights = lamb * torch.exp(cross_entropy - threshold)
weights = weights * mask
return mask, weights轻量化参数更新:LoRA的应用
灾难性遗忘问题
在测试时学习场景中,全参数微调(Full Fine-tuning)面临严重的灾难性遗忘(Catastrophic Forgetting)问题:
- 测试时数据量有限,容易过拟合
- 参数更新可能破坏预训练知识
- 推理延迟增加,影响实时应用
LoRA解决方案
TLM采用低秩适配(Low-Rank Adaptation,LoRA)替代全参数更新,其核心思想是将权重更新 分解为两个低秩矩阵的乘积:
其中:
- :预训练权重(冻结)
- :可训练的低秩适配矩阵
- :秩(通常为4-64)
LoRA抗遗忘的理论解释
LoRA之所以能有效缓解灾难性遗忘,原因在于:
- 冻结预训练权重:原始模型的知识得以保留
- 低秩约束:限制参数更新空间,避免剧烈波动
- 正交性:新增参数与原始参数空间的干扰最小化
实验表明,LoRA在连续微调任务中的遗忘率仅为0.6%,而全参数微调的遗忘率高达19.9%。2
AdaptEval基准测试
基准设计
TLM团队构建了AdaptEval基准,专门用于评估LLM的测试时学习能力。该基准包含三个子集:
| 子集 | 领域 | 任务类型 | 数据规模 |
|---|---|---|---|
| DomainBench | 地理、医学、金融、农业 | 领域问答 | 各5000条 |
| InstructionBench | 多领域指令跟随 | 指令执行 | 10000条 |
| ReasoningBench | 数学、逻辑推理 | 推理问答 | 各3000条 |
评估指标
AdaptEval采用多种评估指标:
- ROUGE-L:衡量生成文本与参考文本的n-gram重叠
- BERTScore:基于语义嵌入的相似度度量
- 精确匹配准确率:对于有确定性答案的任务
- 困惑度改善率:域适应前后的困惑度变化
实验结果
TLM在AdaptEval上的核心实验结果:
| 数据集 | 原始LLM | TLM | 提升幅度 |
|---|---|---|---|
| 地理问答 | 45.2% | 56.8% | +25.7% |
| 医学问答 | 38.7% | 48.3% | +24.8% |
| 金融分析 | 42.1% | 51.9% | +23.3% |
| GSM8K数学 | 52.3% | 63.1% | +20.7% |
| 平均提升 | - | - | ≥20% |
实验结果表明,TLM在所有测试领域均实现至少20%的性能提升。
算法流程
离线测试时学习
离线设置(Offline TTL)适用于有完整测试数据集的场景:
输入:无标注测试数据集 D_test,模型 M,预训练权重 W_0
输出:域适应后的模型 M'
1. 初始化 M,复制预训练权重 W_0
2. for epoch in range(E):
3. for batch in D_test:
4. # 前向传播
5. logits = M(batch.input_ids)
6.
7. # 计算困惑度
8. pp = cal_perplexity(logits, batch.labels)
9.
10. # 样本选择与加权
11. mask, weights = select_samples(pp, threshold, lambda)
12.
13. # 计算加权KL散度损失
14. loss = weighted_kl_divergence(logits, batch.labels, weights, mask)
15.
16. # 反向传播更新LoRA参数
17. loss.backward()
18. optimizer.step()
19. optimizer.zero_grad()
20.
21. return M'
在线测试时学习
在线设置(Online TTL)适用于流式数据场景,边推理边适应:
输入:流式数据样本 x_i,模型 M
输出:更新后的模型
1. for each sample x in stream:
2. # 使用当前模型预测
3. y_pred = M.generate(x)
4.
5. # 计算输入困惑度
6. pp = cal_input_perplexity(M, x)
7.
8. if pp > threshold:
9. # 高困惑度样本触发更新
10. loss = input_perplexity_loss(M, x)
11. loss.backward()
12. update_lora_parameters(M)
13.
14. return M
PyTorch实现
核心组件实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
class TLMTrainer:
"""
Test-Time Learning (TLM) Trainer for LLMs
核心思想:通过最小化输入困惑度实现自监督域适应
"""
def __init__(
self,
model_name: str,
lora_rank: int = 8,
lora_alpha: int = 16,
perplexity_threshold: float = 3.0,
lambda_weight: float = 0.1,
learning_rate: float = 5e-5,
):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
# 冻结预训练权重
for param in self.model.parameters():
param.requires_grad = False
# 初始化LoRA
self._init_lora(lora_rank, lora_alpha)
# TLM超参数
self.ppl_threshold = perplexity_threshold
self.lambda_weight = lambda_weight
self.optimizer = torch.optim.AdamW(
self.get_lora_parameters(),
lr=learning_rate
)
# 设备配置
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
def _init_lora(self, rank: int, alpha: int):
"""初始化LoRA适配器"""
from peft import get_peft_model, LoraConfig, TaskType
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=rank,
lora_alpha=alpha,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj"], # 针对attention层
bias="none",
)
self.model = get_peft_model(self.model, lora_config)
self.model.print_trainable_parameters()
def get_lora_parameters(self):
"""获取LoRA可训练参数"""
return filter(lambda p: p.requires_grad, self.model.parameters())
def calculate_perplexity(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
计算句子级别的困惑度
Args:
logits: [batch_size, seq_len, vocab_size]
labels: [batch_size, seq_len]
Returns:
sentence_ppl: [batch_size]
"""
# 移位以对齐logits和labels
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 创建掩码(忽略padding)
mask = (shift_labels != -100).float()
# 计算token级别的负对数似然
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
reduction="none"
)
# 掩码并求平均
loss = loss.view(shift_logits.size(0), -1)
masked_loss = loss * mask
sentence_loss = masked_loss.sum(dim=-1) / mask.sum(dim=-1)
# 转换为困惑度
sentence_ppl = torch.exp(sentence_loss)
return sentence_ppl
def select_samples(
self,
perplexity: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
基于困惑度选择高价值样本
Args:
perplexity: [batch_size] 句子困惑度
Returns:
mask: 选择掩码
weights: 样本权重
"""
# 阈值过滤:只选择高困惑度样本
mask = (perplexity > self.ppl_threshold).float()
# 指数加权
weights = self.lambda_weight * torch.exp(perplexity - self.ppl_threshold)
weights = weights * mask
return mask, weights
def compute_kl_divergence(
self,
logits_new: torch.Tensor,
logits_old: torch.Tensor,
labels: torch.Tensor
) -> torch.Tensor:
"""
计算KL散度损失
使用参考模型(冻结)和当前模型的输出分布计算KL散度
"""
shift_logits_new = logits_new[..., :-1, :].contiguous()
shift_logits_old = logits_old[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
mask = (shift_labels != -100).float()
# 计算新模型的log softmax
log_prob_new = F.log_softmax(shift_logits_new, dim=-1)
# 计算参考模型的softmax
prob_old = F.softmax(shift_logits_old, dim=-1).detach()
# KL散度:KL(old || new) = sum(old * (log(old) - log(new)))
kl_div = prob_old * (torch.log(prob_old + 1e-8) - log_prob_new)
kl_div = kl_div.sum(dim=-1) # 在vocab维度求和
# 应用掩码
masked_kl = kl_div * mask
sentence_kl = masked_kl.sum(dim=-1) / mask.sum(dim=-1)
return sentence_kl
def forward_pass(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor,
setting: str = "offline"
) -> Tuple[torch.Tensor, Dict]:
"""
前向传播与损失计算
Args:
input_ids: 输入token IDs
attention_mask: 注意力掩码
labels: 标签
setting: "offline" 或 "online"
Returns:
loss: 总损失
metrics: 诊断指标
"""
batch_size = input_ids.size(0)
if setting == "offline":
# 离线设置:使用冻结模型计算参考困惑度
with torch.no_grad():
# 禁用LoRA获取原始模型输出
original_output = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
)
ref_logits = original_output.logits
ref_perplexity = self.calculate_perplexity(ref_logits, labels)
# 启用LoRA获取当前模型输出
adapted_output = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
)
adapted_logits = adapted_output.logits
else: # online
# 在线设置:使用当前模型计算参考
adapted_output = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
)
adapted_logits = adapted_output.logits
ref_perplexity = self.calculate_perplexity(adapted_logits, labels)
# 样本选择
mask, weights = self.select_samples(ref_perplexity)
# 计算KL散度损失
kl_div = self.compute_kl_divergence(adapted_logits, ref_logits, labels)
# 加权求和
weighted_kl = kl_div * weights
if mask.sum() > 0:
loss = weighted_kl.sum() / mask.sum()
else:
loss = weighted_kl.mean()
# 诊断指标
metrics = {
"loss": loss.item(),
"avg_perplexity": ref_perplexity.mean().item(),
"selected_ratio": mask.sum().item() / batch_size,
"avg_weight": weights.mean().item(),
}
return loss, metrics
def train_step(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor,
setting: str = "offline"
) -> Dict:
"""
单步训练
Args:
input_ids: 输入token IDs
attention_mask: 注意力掩码
labels: 标签
setting: "offline" 或 "online"
Returns:
metrics: 训练指标
"""
self.model.train()
self.optimizer.zero_grad()
# 前向传播
loss, metrics = self.forward_pass(
input_ids, attention_mask, labels, setting
)
# 反向传播
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(
self.get_lora_parameters(),
max_norm=1.0
)
# 更新参数
self.optimizer.step()
return metrics
class TLMInference:
"""
TLM推理引擎:支持流式推理和批量推理
"""
def __init__(self, trainer: TLMTrainer):
self.trainer = trainer
self.tokenizer = trainer.tokenizer
self.model = trainer.model
@torch.no_grad()
def generate(
self,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.0,
) -> str:
"""
生成文本
Args:
prompt: 输入提示
max_new_tokens: 最大生成token数
temperature: 采样温度
Returns:
生成的文本
"""
inputs = self.tokenizer(
prompt,
return_tensors="pt"
).to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=temperature > 0,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
generated_text = self.tokenizer.decode(
outputs[0][inputs["input_ids"].size(1):],
skip_special_tokens=True
)
return generated_text
@torch.no_grad()
def batch_generate(
self,
prompts: list,
max_new_tokens: int = 512,
temperature: float = 0.0,
) -> list:
"""批量生成"""
inputs = self.tokenizer(
prompts,
return_tensors="pt",
padding=True,
).to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=temperature > 0,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
results = []
input_len = inputs["input_ids"].size(1)
for i, output in enumerate(outputs):
generated = self.tokenizer.decode(
output[input_len:],
skip_special_tokens=True
)
results.append(generated)
return results使用示例
from transformers import AutoTokenizer, AutoModelForCausalLM
def main():
# 初始化TLM训练器
trainer = TLMTrainer(
model_name="meta-llama/Llama-3-8B-Instruct",
lora_rank=8,
lora_alpha=16,
perplexity_threshold=3.0,
lambda_weight=0.1,
learning_rate=5e-5,
)
# 准备数据(无标注测试数据)
from datasets import load_dataset
dataset = load_dataset("your-domain-dataset")
def preprocess(example):
return tokenizer(example["text"], truncation=True, max_length=512)
dataset = dataset.map(preprocess, batched=True)
# 离线训练
print("Starting Offline TTL Training...")
for epoch in range(3):
for batch in dataloader:
metrics = trainer.train_step(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"],
setting="offline"
)
print(f"Epoch {epoch}: Loss={metrics['loss']:.4f}, "
f"Selected={metrics['selected_ratio']:.2%}")
# 推理
inference = TLMInference(trainer)
prompt = "请解释量子计算的基本原理:"
response = inference.generate(prompt)
print(f"Generated: {response}")
if __name__ == "__main__":
main()与相关方法的关系
与测试时计算扩展的区别
测试时计算扩展(Test-Time Compute Scaling)通过增加推理时的计算资源(如更多采样、更深的思考链)来提升模型能力。而TLM的核心区别在于:
| 维度 | 测试时计算扩展 | TLM |
|---|---|---|
| 优化对象 | 推理过程 | 模型参数 |
| 数据依赖 | 无需额外数据 | 无标注测试数据 |
| 计算开销 | 每次推理增加 | 仅适应时增加 |
| 知识获取 | 推理时计算 | 隐式学习域知识 |
与参数高效微调的关系
TLM使用LoRA作为参数高效微调技术,适用于测试时学习场景。这种设计体现了PEFT概述中讨论的高效微调思想。
与持续学习的关系
TLM与LLM持续学习共享缓解灾难性遗忘的核心挑战。TLM通过冻结预训练权重和低秩约束,有效保护了原始知识。
实验分析
域适应效果
在多个专业领域的实验结果表明,TLM能显著提升LLM的域适应能力:
数据集 | 原始模型 | TLM | 提升
--------------|----------|---------|------
地理问答 | 45.2% | 56.8% | +25.7%
医学问答 | 38.7% | 48.3% | +24.8%
金融分析 | 42.1% | 51.9% | +23.3%
农业知识 | 51.3% | 62.7% | +22.2%
GSM8K数学 | 52.3% | 63.1% | +20.7%
消融实验
关键组件的消融实验结果:
| 组件 | 变体 | 性能 | 贡献 |
|---|---|---|---|
| 样本选择 | 无选择 | 50.2% | baseline |
| 样本选择 | 阈值过滤 | 54.8% | +4.6% |
| 样本选择 | 阈值+加权 | 56.8% | +2.0% |
| 参数更新 | 全参数 | 55.1% | -1.7% (遗忘) |
| 参数更新 | LoRA | 56.8% | +1.7% |
计算效率
TLM的计算开销分析:
| 指标 | 数值 |
|---|---|
| LoRA参数量 | ~0.1% 原模型 |
| 单样本适应延迟 | ~50ms |
| 内存增量 | ~500MB |
| 无标注数据需求 | 5000条 |
应用场景
医疗领域
在医疗问答场景中,TLM可以帮助模型快速适应:
- 医学术语理解
- 临床诊断推理
- 药物相互作用查询
法律领域
- 法律条文检索
- 案例分析推理
- 合同条款解析
金融领域
- 财务报表分析
- 市场趋势预测
- 风险评估
科学研究
- 论文摘要生成
- 实验数据分析
- 跨学科知识整合
局限性
尽管TLM展现出优异性能,仍存在以下局限:
- 计算开销:每次域适应仍需一定的梯度计算
- 阈值敏感性:性能对超参数(, )敏感
- 分布假设:假设高困惑度样本更具信息性,可能不适用于所有场景
- 长期适应:多次迭代适应可能导致域过拟合
总结
TLM(Test-Time Learning for LLMs)提出了一种创新的自监督域适应范式,通过以下核心机制实现LLM的动态域适应:
- 输入困惑度最小化:利用输入-输出困惑度的强相关性作为优化目标
- 样本高效学习:主动选择高困惑度样本进行训练,提高学习效率
- LoRA参数更新:轻量级参数更新缓解灾难性遗忘
- AdaptEval基准:标准化的评估框架促进领域研究
实验证明,TLM在多个专业领域实现至少20%的性能提升,为LLM的部署应用提供了新的技术路径。
参考文献
相关词条
Footnotes
-
Hu, J., Zhang, Z., Chen, G., Wen, X., Shuai, C., Luo, W., Xiao, B., Li, Y., & Tan, M. (2025). Test-Time Learning for Large Language Models. Proceedings of the 42nd International Conference on Machine Learning (ICML 2025), 24823-24849. https://proceedings.mlr.press/v267/hu25z.html ↩
-
Zhao, S., Wei, J., & Wang, Y. (2024). Analyzing and Reducing Catastrophic Forgetting in Parameter Efficient Tuning. arXiv:2402.18865. https://arxiv.org/abs/2402.18865 ↩