TF-TTCL:LLM免训练测试时对比学习
概述
TF-TTCL(Training-Free Test-Time Contrastive Learning,arXiv:2604.13552)提出了一种无需梯度更新的测试时对比学习方法,使冻结的大语言模型能够在线自我改进。其核心框架是**“探索-反思-引导”(Explore-Reflect-Steer)**,从模型自身的推理经验中蒸馏对比监督信号。
核心贡献
- 完全免训练:无需梯度更新,无需外部知识
- 在线适应:在推理过程中实时改进
- 多代理协作:Teacher-Tutor-Student三重角色
- 跨任务泛化:闭集推理和开放评估均有效
问题背景
现有LLM的测试时适应方法存在局限:
| 方法 | 梯度更新 | 白盒访问 | 外部知识 | 效果 |
|---|---|---|---|---|
| Fine-tuning | ✅ | ✅ | 可用 | 好 |
| LoRA | ✅ | ✅ | 可用 | 好 |
| TLM | ✅ | ✅ | 不可用 | 中等 |
| TF-TTCL | ❌ | ❌ | ❌ | 好 |
方法详解
1. 问题形式化
给定测试分布 (训练分布),目标:
约束条件:
- 模型参数 冻结
- 在线单遍协议(严格测试时设置)
- 无外部反馈
2. 核心框架:Explore-Reflect-Steer
┌─────────────────────────────────────────────────────────────────┐
│ TF-TTCL 三阶段框架 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Stage 1: Semantic Query Augmentation (SQA) │
│ ┌────────────────────────────────────────────────────────┐ │
│ │ Teacher: 生成稳定锚点答案 │ │
│ │ Tutor: 重写查询生成语义变体 │ │
│ │ Student: 探索多样推理路径 │ │
│ └────────────────────────────────────────────────────────┘ │
│ ↓ │
│ Stage 2: Contrastive Experience Distillation (CED) │
│ ┌────────────────────────────────────────────────────────┐ │
│ │ 一致性划分候选集 │ │
│ │ Min-PPL选择正负样本 │ │
│ │ 蒸馏为显式规则 │ │
│ └────────────────────────────────────────────────────────┘ │
│ ↓ │
│ Stage 3: Contextual Rule Retrieval (CRR) │
│ ┌────────────────────────────────────────────────────────┐ │
│ │ 正负规则分离存储 │ │
│ │ 余弦相似度检索 │ │
│ │ 结构化上下文注入 │ │
│ └────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
Stage 1: 语义查询增强(SQA)
1.1 Teacher生成锚点
def teacher_generate(query, model, rules):
"""
Teacher: 使用贪婪解码生成高置信度锚点
"""
# 注入相关规则作为上下文
context = retrieve_relevant_rules(query, rules)
augmented_query = f"{context}\n\n问题: {query}"
# 贪婪解码(稳定输出)
answer = model.generate(
augmented_query,
strategy='greedy' # 确定性强
)
return answer关键:Teacher使用贪婪解码确保生成稳定、可信的锚点答案。
1.2 Tutor重写查询
def tutor_rewrite(query, model, temperature=0.7):
"""
Tutor: 使用温度采样重写查询
"""
prompt = f"""
请用不同的方式重写以下问题,保持语义不变:
问题: {query}
重写版本:
"""
# 温度采样生成语义变体
rewritten = model.generate(
prompt,
temperature=temperature,
max_tokens=100
)
return rewritten关键:温度0.7平衡多样性和语义一致性。
1.3 Student采样响应
def student_explore(rewritten_queries, model, rules, n_samples=5):
"""
Student: 对每个重写查询采样多个响应
"""
candidates = []
for q in rewritten_queries:
context = retrieve_relevant_rules(q, rules)
augmented_q = f"{context}\n\n问题: {q}"
# 采样n_samples个响应
for _ in range(n_samples):
response = model.generate(
augmented_q,
temperature=0.9,
do_sample=True
)
candidates.append(response)
return candidatesStage 2: 对比经验蒸馏(CED)
2.1 候选集划分
闭集推理任务(CRT):使用多数投票聚类
def partition_candidates_by_voting(candidates):
"""
使用多数投票划分正负候选集
"""
# 统计各类别票数
votes = {}
for cand in candidates:
pred = extract_answer(cand)
votes[pred] = votes.get(pred, 0) + 1
# 多数类为正,少数类为负
sorted_votes = sorted(votes.items(), key=lambda x: -x[1])
positive_class = sorted_votes[0][0]
Y_plus = [c for c in candidates if extract_answer(c) == positive_class]
Y_minus = [c for c in candidates if extract_answer(c) != positive_class]
return Y_plus, Y_minus开放评估任务(OET):使用嵌入相似度排序
def partition_candidates_by_similarity(candidates, teacher_answer):
"""
使用嵌入相似度划分候选集
"""
embeddings = [model.embed(cand) for cand in candidates]
teacher_emb = model.embed(teacher_answer)
# 计算与Teacher答案的相似度
similarities = [
cosine_sim(emb, teacher_emb)
for emb in embeddings
]
# 高相似度为正,低相似度为负
threshold = np.median(similarities)
Y_plus = [c for c, s in zip(candidates, similarities) if s >= threshold]
Y_minus = [c for c, s in zip(candidates, similarities) if s < threshold]
return Y_plus, Y_minus2.2 Min-PPL选择
**困惑度(PPL)**是选择正负样本的关键:
def min_ppl_selection(Y_plus, Y_minus, model):
"""
使用Min-PPL选择最可靠的样本
"""
# 计算所有候选的PPL
ppl_scores = {}
for cand in Y_plus + Y_minus:
ppl = compute_ppl(cand, model)
ppl_scores[cand] = ppl
# 正样本:PPL最低(最自信的正确答案)
y_pos = min(Y_plus, key=lambda c: ppl_scores[c])
# 负样本:PPL最低(最自信的错误答案)
y_neg = min(Y_minus, key=lambda c: ppl_scores[c])
return y_pos, y_neg
def compute_ppl(text, model):
"""
计算文本的困惑度
"""
tokens = model.tokenize(text)
log_probs = model.get_log_probs(tokens)
# 归一化的负对数似然
ppl = np.exp(-np.mean(log_probs))
return ppl2.3 规则蒸馏
def distill_rule(query, y_pos, y_neg, model):
"""
蒸馏为显式正负规则
"""
prompt = f"""
给定问题和两个答案:
问题: {query}
正确答案: {y_pos}
错误答案: {y_neg}
请总结:
1. 为什么{y_pos}是正确的(正规则)
2. {y_neg}哪里错了(负规则)
用一句话简洁回答。
"""
# 提取正负规则
r_pos = model.generate(f"正规则: 问题'{query}'的答案应该是...",
strategy='greedy')
r_neg = model.generate(f"负规则: 问题'{query}'不应该...",
strategy='greedy')
return r_pos, r_negStage 3: 上下文规则检索(CRR)
3.1 规则存储结构
class RuleRepository:
"""
分离存储正负规则的仓库
"""
def __init__(self):
self.pos_rules = [] # (rule, embedding, query_embedding)
self.neg_rules = []
self.max_size = 10000
def add(self, rule_pos, rule_neg, query_emb):
"""添加规则"""
rule_pos_emb = embed(rule_pos)
self.pos_rules.append({
'rule': rule_pos,
'embedding': rule_pos_emb,
'query_emb': query_emb
})
rule_neg_emb = embed(rule_neg)
self.neg_rules.append({
'rule': rule_neg,
'embedding': rule_neg_emb,
'query_emb': query_emb
})
# FIFO修剪
if len(self.pos_rules) > self.max_size:
self.pos_rules.pop(0)
if len(self.neg_rules) > self.max_size:
self.neg_rules.pop(0)3.2 余弦相似度检索
def retrieve_relevant_rules(query, repository, top_k=3):
"""
检索与当前查询最相关的正负规则
"""
query_emb = embed(query)
# 检索正规则
pos_scores = [
cosine_sim(query_emb, r['embedding'])
for r in repository.pos_rules
]
top_pos_indices = np.argsort(pos_scores)[-top_k:]
relevant_pos = [repository.pos_rules[i]['rule'] for i in top_pos_indices]
# 检索负规则
neg_scores = [
cosine_sim(query_emb, r['embedding'])
for r in repository.neg_rules
]
top_neg_indices = np.argsort(neg_scores)[-top_k:]
relevant_neg = [repository.neg_rules[i]['rule'] for i in top_neg_indices]
return relevant_pos, relevant_neg3.3 结构化上下文注入
def build_context(queries, rules_pos, rules_neg):
"""
构建结构化上下文
"""
context_parts = []
if rules_pos:
context_parts.append("【应该做的事】")
for r in rules_pos:
context_parts.append(f"- {r}")
if rules_neg:
context_parts.append("【不应该做的事】")
for r in rules_neg:
context_parts.append(f"- {r}")
if queries:
context_parts.append("【类似问题参考】")
for q in queries[-2:]: # 最近2个
context_parts.append(f"- {q}")
return "\n".join(context_parts)完整推理流程
def tf_ttcl_inference(query, model, repository):
"""
TF-TTCL 完整推理流程
"""
# 1. 检索规则
rules_pos, rules_neg = retrieve_relevant_rules(query, repository)
# 2. Teacher生成锚点
anchor = teacher_generate(query, model,
build_context(None, rules_pos, rules_neg))
candidates = [anchor]
# 3. Tutor重写查询
rewritten_queries = [tutor_rewrite(query, model)]
# 4. Student采样多个响应
candidates.extend(
student_explore(rewritten_queries, model,
build_context(None, rules_pos, rules_neg),
n_samples=5)
)
# 5. 划分正负候选集
Y_plus, Y_minus = partition_candidates_by_voting(candidates)
# 6. Min-PPL选择
y_pos, y_neg = min_ppl_selection(Y_plus, Y_minus, model)
# 7. 蒸馏规则
r_pos, r_neg = distill_rule(query, y_pos, y_neg, model)
# 8. 更新仓库
repository.add(r_pos, r_neg, embed(query))
# 9. 返回正样本作为最终答案
return y_pos实验结果
闭集推理任务(Llama-3.1-8B-Instruct)
| 方法 | GSM8k | MATH-500 | AIME24 | Minerva | 平均 |
|---|---|---|---|---|---|
| Base LLM | 82.49 | 49.20 | 3.33 | 20.96 | 39.00 |
| Tent | 70.20 | 49.20 | 10.00 | 21.32 | 37.68 |
| TLM | 85.06 | 50.00 | 6.67 | 19.49 | 40.31 |
| TF-GRPO | 86.49 | 53.00 | 3.33 | 21.69 | 41.13 |
| TF-TTCL | 87.49 | 54.00 | 13.33 | 24.63 | 44.86 |
开放评估任务(DomainBench)
| 方法 | Geography | Agriculture | Medicine | Finance | 平均 |
|---|---|---|---|---|---|
| Base LLM | 0.2441 | 0.0876 | 0.1356 | 0.2251 | 0.1731 |
| TF-GRPO | 0.2260 | 0.0993 | 0.1147 | 0.2071 | 0.1618 |
| TF-TTCL | 0.2798 | 0.1095 | 0.2018 | 0.2863 | 0.2194 |
消融实验
| 组件 | GSM8k | Finance |
|---|---|---|
| 完整TF-TTCL | 87.49 | 0.2863 |
| 无SQA | 87.11 | 0.2851 |
| 无CED | 85.97 | 0.2639 |
| 无CRR | 87.34 | 0.2596 |
发现:CED是最关键组件,负规则贡献大于正规则。
与其他方法的对比
方法分类
| 方法 | 梯度更新 | 白盒访问 | 外部知识 | 适用场景 |
|---|---|---|---|---|
| Fine-tuning | ✅ | ✅ | 可用 | 长期部署 |
| LoRA | ✅ | ✅ | 可用 | 参数高效 |
| Tent | ✅ | ✅ | 不可用 | 域适应 |
| TLM | ✅ | ✅ | 不可用 | 分布外 |
| TF-GRPO | ✅ | ❌ | 可用 | 推理增强 |
| TF-TTCL | ❌ | ❌ | ❌ | 在线改进 |
TF-TTCL的独特优势
- 无需梯度:适合黑盒API访问的模型
- 无需外部知识:完全依赖模型自身
- 在线增量:边推理边学习
- 资源高效:无需重新训练
总结
TF-TTCL的核心贡献是提出了首个完全免训练的LLM测试时学习方法。
关键创新:
- Explore-Reflect-Steer框架:三重角色协作
- Min-PPL选择:最自信的正确/错误答案
- 规则蒸馏:将对比经验转化为显式规则
- 分离存储:正负规则分别检索
性能提升:
- GSM8K: 82.49% → 87.49% (+5.0%)
- MATH-500: 49.20% → 54.00% (+4.8%)
- AIME24: 3.33% → 13.33% (+10.0%)