TF-TTCL:LLM免训练测试时对比学习

概述

TF-TTCL(Training-Free Test-Time Contrastive Learning,arXiv:2604.13552)提出了一种无需梯度更新的测试时对比学习方法,使冻结的大语言模型能够在线自我改进。其核心框架是**“探索-反思-引导”(Explore-Reflect-Steer)**,从模型自身的推理经验中蒸馏对比监督信号。

核心贡献

  1. 完全免训练:无需梯度更新,无需外部知识
  2. 在线适应:在推理过程中实时改进
  3. 多代理协作:Teacher-Tutor-Student三重角色
  4. 跨任务泛化:闭集推理和开放评估均有效

问题背景

现有LLM的测试时适应方法存在局限:

方法梯度更新白盒访问外部知识效果
Fine-tuning可用
LoRA可用
TLM不可用中等
TF-TTCL

方法详解

1. 问题形式化

给定测试分布 (训练分布),目标:

约束条件:

  • 模型参数 冻结
  • 在线单遍协议(严格测试时设置)
  • 无外部反馈

2. 核心框架:Explore-Reflect-Steer

┌─────────────────────────────────────────────────────────────────┐
│                    TF-TTCL 三阶段框架                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  Stage 1: Semantic Query Augmentation (SQA)                     │
│  ┌────────────────────────────────────────────────────────┐     │
│  │  Teacher: 生成稳定锚点答案                              │     │
│  │  Tutor: 重写查询生成语义变体                            │     │
│  │  Student: 探索多样推理路径                              │     │
│  └────────────────────────────────────────────────────────┘     │
│                          ↓                                       │
│  Stage 2: Contrastive Experience Distillation (CED)            │
│  ┌────────────────────────────────────────────────────────┐     │
│  │  一致性划分候选集                                        │     │
│  │  Min-PPL选择正负样本                                    │     │
│  │  蒸馏为显式规则                                        │     │
│  └────────────────────────────────────────────────────────┘     │
│                          ↓                                       │
│  Stage 3: Contextual Rule Retrieval (CRR)                      │
│  ┌────────────────────────────────────────────────────────┐     │
│  │  正负规则分离存储                                        │     │
│  │  余弦相似度检索                                        │     │
│  │  结构化上下文注入                                      │     │
│  └────────────────────────────────────────────────────────┘     │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

Stage 1: 语义查询增强(SQA)

1.1 Teacher生成锚点

def teacher_generate(query, model, rules):
    """
    Teacher: 使用贪婪解码生成高置信度锚点
    """
    # 注入相关规则作为上下文
    context = retrieve_relevant_rules(query, rules)
    augmented_query = f"{context}\n\n问题: {query}"
    
    # 贪婪解码(稳定输出)
    answer = model.generate(
        augmented_query,
        strategy='greedy'  # 确定性强
    )
    return answer

关键:Teacher使用贪婪解码确保生成稳定、可信的锚点答案。

1.2 Tutor重写查询

def tutor_rewrite(query, model, temperature=0.7):
    """
    Tutor: 使用温度采样重写查询
    """
    prompt = f"""
    请用不同的方式重写以下问题,保持语义不变:
    
    问题: {query}
    
    重写版本:
    """
    
    # 温度采样生成语义变体
    rewritten = model.generate(
        prompt,
        temperature=temperature,
        max_tokens=100
    )
    return rewritten

关键:温度0.7平衡多样性和语义一致性。

1.3 Student采样响应

def student_explore(rewritten_queries, model, rules, n_samples=5):
    """
    Student: 对每个重写查询采样多个响应
    """
    candidates = []
    
    for q in rewritten_queries:
        context = retrieve_relevant_rules(q, rules)
        augmented_q = f"{context}\n\n问题: {q}"
        
        # 采样n_samples个响应
        for _ in range(n_samples):
            response = model.generate(
                augmented_q,
                temperature=0.9,
                do_sample=True
            )
            candidates.append(response)
    
    return candidates

Stage 2: 对比经验蒸馏(CED)

2.1 候选集划分

闭集推理任务(CRT):使用多数投票聚类

def partition_candidates_by_voting(candidates):
    """
    使用多数投票划分正负候选集
    """
    # 统计各类别票数
    votes = {}
    for cand in candidates:
        pred = extract_answer(cand)
        votes[pred] = votes.get(pred, 0) + 1
    
    # 多数类为正,少数类为负
    sorted_votes = sorted(votes.items(), key=lambda x: -x[1])
    positive_class = sorted_votes[0][0]
    
    Y_plus = [c for c in candidates if extract_answer(c) == positive_class]
    Y_minus = [c for c in candidates if extract_answer(c) != positive_class]
    
    return Y_plus, Y_minus

开放评估任务(OET):使用嵌入相似度排序

def partition_candidates_by_similarity(candidates, teacher_answer):
    """
    使用嵌入相似度划分候选集
    """
    embeddings = [model.embed(cand) for cand in candidates]
    teacher_emb = model.embed(teacher_answer)
    
    # 计算与Teacher答案的相似度
    similarities = [
        cosine_sim(emb, teacher_emb) 
        for emb in embeddings
    ]
    
    # 高相似度为正,低相似度为负
    threshold = np.median(similarities)
    Y_plus = [c for c, s in zip(candidates, similarities) if s >= threshold]
    Y_minus = [c for c, s in zip(candidates, similarities) if s < threshold]
    
    return Y_plus, Y_minus

2.2 Min-PPL选择

**困惑度(PPL)**是选择正负样本的关键:

def min_ppl_selection(Y_plus, Y_minus, model):
    """
    使用Min-PPL选择最可靠的样本
    """
    # 计算所有候选的PPL
    ppl_scores = {}
    for cand in Y_plus + Y_minus:
        ppl = compute_ppl(cand, model)
        ppl_scores[cand] = ppl
    
    # 正样本:PPL最低(最自信的正确答案)
    y_pos = min(Y_plus, key=lambda c: ppl_scores[c])
    
    # 负样本:PPL最低(最自信的错误答案)
    y_neg = min(Y_minus, key=lambda c: ppl_scores[c])
    
    return y_pos, y_neg
 
 
def compute_ppl(text, model):
    """
    计算文本的困惑度
    """
    tokens = model.tokenize(text)
    log_probs = model.get_log_probs(tokens)
    
    # 归一化的负对数似然
    ppl = np.exp(-np.mean(log_probs))
    return ppl

2.3 规则蒸馏

def distill_rule(query, y_pos, y_neg, model):
    """
    蒸馏为显式正负规则
    """
    prompt = f"""
    给定问题和两个答案:
    
    问题: {query}
    
    正确答案: {y_pos}
    错误答案: {y_neg}
    
    请总结:
    1. 为什么{y_pos}是正确的(正规则)
    2. {y_neg}哪里错了(负规则)
    
    用一句话简洁回答。
    """
    
    # 提取正负规则
    r_pos = model.generate(f"正规则: 问题'{query}'的答案应该是...", 
                           strategy='greedy')
    r_neg = model.generate(f"负规则: 问题'{query}'不应该...", 
                           strategy='greedy')
    
    return r_pos, r_neg

Stage 3: 上下文规则检索(CRR)

3.1 规则存储结构

class RuleRepository:
    """
    分离存储正负规则的仓库
    """
    def __init__(self):
        self.pos_rules = []  # (rule, embedding, query_embedding)
        self.neg_rules = []
        self.max_size = 10000
    
    def add(self, rule_pos, rule_neg, query_emb):
        """添加规则"""
        rule_pos_emb = embed(rule_pos)
        self.pos_rules.append({
            'rule': rule_pos,
            'embedding': rule_pos_emb,
            'query_emb': query_emb
        })
        
        rule_neg_emb = embed(rule_neg)
        self.neg_rules.append({
            'rule': rule_neg,
            'embedding': rule_neg_emb,
            'query_emb': query_emb
        })
        
        # FIFO修剪
        if len(self.pos_rules) > self.max_size:
            self.pos_rules.pop(0)
        if len(self.neg_rules) > self.max_size:
            self.neg_rules.pop(0)

3.2 余弦相似度检索

def retrieve_relevant_rules(query, repository, top_k=3):
    """
    检索与当前查询最相关的正负规则
    """
    query_emb = embed(query)
    
    # 检索正规则
    pos_scores = [
        cosine_sim(query_emb, r['embedding'])
        for r in repository.pos_rules
    ]
    top_pos_indices = np.argsort(pos_scores)[-top_k:]
    relevant_pos = [repository.pos_rules[i]['rule'] for i in top_pos_indices]
    
    # 检索负规则
    neg_scores = [
        cosine_sim(query_emb, r['embedding'])
        for r in repository.neg_rules
    ]
    top_neg_indices = np.argsort(neg_scores)[-top_k:]
    relevant_neg = [repository.neg_rules[i]['rule'] for i in top_neg_indices]
    
    return relevant_pos, relevant_neg

3.3 结构化上下文注入

def build_context(queries, rules_pos, rules_neg):
    """
    构建结构化上下文
    """
    context_parts = []
    
    if rules_pos:
        context_parts.append("【应该做的事】")
        for r in rules_pos:
            context_parts.append(f"- {r}")
    
    if rules_neg:
        context_parts.append("【不应该做的事】")
        for r in rules_neg:
            context_parts.append(f"- {r}")
    
    if queries:
        context_parts.append("【类似问题参考】")
        for q in queries[-2:]:  # 最近2个
            context_parts.append(f"- {q}")
    
    return "\n".join(context_parts)

完整推理流程

def tf_ttcl_inference(query, model, repository):
    """
    TF-TTCL 完整推理流程
    """
    # 1. 检索规则
    rules_pos, rules_neg = retrieve_relevant_rules(query, repository)
    
    # 2. Teacher生成锚点
    anchor = teacher_generate(query, model, 
                             build_context(None, rules_pos, rules_neg))
    candidates = [anchor]
    
    # 3. Tutor重写查询
    rewritten_queries = [tutor_rewrite(query, model)]
    
    # 4. Student采样多个响应
    candidates.extend(
        student_explore(rewritten_queries, model,
                       build_context(None, rules_pos, rules_neg),
                       n_samples=5)
    )
    
    # 5. 划分正负候选集
    Y_plus, Y_minus = partition_candidates_by_voting(candidates)
    
    # 6. Min-PPL选择
    y_pos, y_neg = min_ppl_selection(Y_plus, Y_minus, model)
    
    # 7. 蒸馏规则
    r_pos, r_neg = distill_rule(query, y_pos, y_neg, model)
    
    # 8. 更新仓库
    repository.add(r_pos, r_neg, embed(query))
    
    # 9. 返回正样本作为最终答案
    return y_pos

实验结果

闭集推理任务(Llama-3.1-8B-Instruct)

方法GSM8kMATH-500AIME24Minerva平均
Base LLM82.4949.203.3320.9639.00
Tent70.2049.2010.0021.3237.68
TLM85.0650.006.6719.4940.31
TF-GRPO86.4953.003.3321.6941.13
TF-TTCL87.4954.0013.3324.6344.86

开放评估任务(DomainBench)

方法GeographyAgricultureMedicineFinance平均
Base LLM0.24410.08760.13560.22510.1731
TF-GRPO0.22600.09930.11470.20710.1618
TF-TTCL0.27980.10950.20180.28630.2194

消融实验

组件GSM8kFinance
完整TF-TTCL87.490.2863
无SQA87.110.2851
无CED85.970.2639
无CRR87.340.2596

发现:CED是最关键组件,负规则贡献大于正规则。


与其他方法的对比

方法分类

方法梯度更新白盒访问外部知识适用场景
Fine-tuning可用长期部署
LoRA可用参数高效
Tent不可用域适应
TLM不可用分布外
TF-GRPO可用推理增强
TF-TTCL在线改进

TF-TTCL的独特优势

  1. 无需梯度:适合黑盒API访问的模型
  2. 无需外部知识:完全依赖模型自身
  3. 在线增量:边推理边学习
  4. 资源高效:无需重新训练

总结

TF-TTCL的核心贡献是提出了首个完全免训练的LLM测试时学习方法。

关键创新

  1. Explore-Reflect-Steer框架:三重角色协作
  2. Min-PPL选择:最自信的正确/错误答案
  3. 规则蒸馏:将对比经验转化为显式规则
  4. 分离存储:正负规则分别检索

性能提升

  • GSM8K: 82.49% → 87.49% (+5.0%)
  • MATH-500: 49.20% → 54.00% (+4.8%)
  • AIME24: 3.33% → 13.33% (+10.0%)

参考