Overview

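Test-Time Learning (TTL) is a paradigm that dynamically adjusts model behavior at inference time. FTTT (Feedback Test-Time Training) recasts test-time learning as a feedback utilization optimization problem and uses a learnable test-time optimizer (OpTune) to make adaptive inference-time improvement efficient. [1]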


1. Background

1.1 Challenges of Test-Time Learning

Traditional test-time adaptation methods face the following challenges:

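| Method | Core idea | Limitations / notes |
|---|---|---|
| TTT | Same gradient descent as at training time | Needs a large amount of test data |
| Test-Time Dropout | Monte Carlo Dropout | Limited gains |
| Feature normalization | Adjusting BN statistics | Relies on predefined transformations |
| FTTT | Feedback-driven adaptation | General and efficient |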

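1.2 The Core Insight of FTTT

The central observation behind FTTT is:

Reasoning failures are often not a matter of "not knowing the answer" but of "not knowing how to express the correct reasoning process".

This means a feedback mechanism can be used to steer the model toward a better reasoning chain.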

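class FeedbackDrivenReasoning:
    """
    Feedback-driven reasoning enhancement.

    Assumes `model.generate(text) -> str` and that
    `feedback_model.evaluate(prompt, output)` returns an object with
    `is_good_enough` and `feedback_text` attributes.
    """
    def __init__(self, model, feedback_model):
        self.model = model
        self.feedback_model = feedback_model  # model that scores generation quality

    def generate_with_feedback(self, prompt, max_iterations=3):
        """Iterative generate -> feedback -> revise loop."""
        current_output = None
        feedback_text = None  # kept local so nothing leaks between calls

        for iteration in range(max_iterations):
            # Generate: the first pass uses the bare prompt
            if current_output is None:
                input_text = prompt
            else:
                # Inject the previous attempt and its feedback into the input
                input_text = (
                    f"{prompt}\n\nPrevious attempt:\n{current_output}"
                    f"\n\nFeedback: {feedback_text}"
                )

            current_output = self.model.generate(input_text)

            # Evaluate the new attempt
            feedback_result = self.feedback_model.evaluate(
                prompt,
                current_output
            )

            if feedback_result.is_good_enough:
                break

            feedback_text = feedback_result.feedback_text

        return current_output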
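
To see the loop terminate on its own, the following toy run swaps in stub components. It is a minimal sketch: `StubModel`, `StubJudge`, `StubFeedback` and the standalone `refine` function are hypothetical stand-ins written for this illustration, not part of FTTT.

```python
from dataclasses import dataclass


@dataclass
class StubFeedback:
    is_good_enough: bool
    feedback_text: str


class StubModel:
    """Hypothetical model: shows its work only once the input contains feedback."""
    def generate(self, text):
        if "Feedback:" in text:
            return "2 + 2 = 4, so the answer is 4"
        return "4"


class StubJudge:
    """Hypothetical judge: accepts an answer only if it contains an equation."""
    def evaluate(self, prompt, output):
        ok = "=" in output
        return StubFeedback(ok, "" if ok else "Show the intermediate steps.")


def refine(model, judge, prompt, max_iterations=3):
    """Generate -> feedback -> revise, recording every attempt."""
    output, feedback_text, trace = None, None, []
    for _ in range(max_iterations):
        if output is None:
            text = prompt
        else:
            text = f"{prompt}\n\nPrevious attempt:\n{output}\n\nFeedback: {feedback_text}"
        output = model.generate(text)
        trace.append(output)
        result = judge.evaluate(prompt, output)
        if result.is_good_enough:
            break
        feedback_text = result.feedback_text
    return output, trace


final, trace = refine(StubModel(), StubJudge(), "What is 2 + 2?")
print(trace)
```

The first attempt is rejected, the feedback is injected into the second input, and the loop stops after two generations instead of using all three.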

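2. OpTune: A Learnable Test-Time Optimizer

2.1 Design Motivation

Traditional test-time adaptation uses a fixed optimizer (such as SGD or Adam), which implicitly assumes that test data and training data share the same distributional properties. OpTune instead learns the test-time optimizer so that the update rule itself adapts to the characteristics of the test data.

2.2 OpTune Architecture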

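import torch
import torch.nn as nn


class OpTuneOptimizer(nn.Module):
    """
    Learnable test-time optimizer.

    Core idea: parameterize the update rule with a neural network that is
    applied element-wise to the gradient.
    """
    def __init__(self, hidden_dim=64):
        super().__init__()

        # Gradient encoder: one scalar gradient entry -> hidden vector
        self.gradient_encoder = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU()
        )

        # State encoder: one scalar optimizer-state entry -> hidden vector
        self.state_encoder = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU()
        )

        # Update policy: outputs a per-element step scale
        self.update_policy = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1)
        )

        # Momentum / second-order information
        self.momentum_estimator = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True
        )
        # Projects the GRU state back to a scalar so the returned state can be
        # fed to state_encoder on the next step
        self.state_decoder = nn.Linear(hidden_dim, 1)

    def forward(self, grad, state, step):
        """
        Produce an update from the current gradient and optimizer state.

        Args:
            grad: current gradient tensor (any shape, including a scalar)
            state: optimizer state (e.g. momentum); same shape as grad,
                or a single element that is broadcast
            step: current step index (not used by this sketch)

        Returns:
            update: update with the same shape as grad
            new_state: new optimizer state with the same shape as grad
        """
        flat_grad = grad.reshape(-1, 1)  # (N, 1)
        flat_state = torch.broadcast_to(state.reshape(-1, 1), flat_grad.shape)

        # Encode gradient and state
        grad_encoded = self.gradient_encoder(flat_grad)   # (N, H)
        state_encoded = self.state_encoder(flat_state)    # (N, H)

        # Update policy
        combined = torch.cat([grad_encoded, state_encoded], dim=-1)
        update_magnitude = self.update_policy(combined)   # (N, 1)

        # Momentum estimate: each element is treated as a length-1 sequence
        _, momentum_state = self.momentum_estimator(grad_encoded.unsqueeze(1))
        new_state = self.state_decoder(momentum_state.squeeze(0))  # (N, 1)

        # tanh bounds the learned step scale to (-1, 1)
        update = -torch.tanh(update_magnitude) * flat_grad

        return update.reshape(grad.shape), new_state.reshape(grad.shape)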
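
The final line of `forward` scales the raw gradient by `-tanh(update_magnitude)`. A scalar sketch makes the consequence visible: the learned step can never exceed the gradient in magnitude, and the sign of the policy output decides whether the step descends or ascends. The `policy_output` values below are made up for illustration.

```python
import math


def bounded_update(policy_output, grad):
    """Scalar mirror of `update = -tanh(update_magnitude) * grad`."""
    return -math.tanh(policy_output) * grad


grad = 2.0
for policy_output in (-5.0, 0.0, 0.5, 5.0):
    step = bounded_update(policy_output, grad)
    print(f"policy_output={policy_output:+.1f} -> update={step:+.4f}")
```

A policy output near zero effectively freezes the parameter, which gives the learned optimizer a cheap way to ignore uninformative gradients.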

2.3 Training OpTune

OpTune is trained with meta-learning:

import torch


class OpTuneTrainer:
    """Meta-learning trainer for OpTune."""
    def __init__(self, model, op_tune, inner_lr=0.01, outer_lr=1e-4,
                 num_inner_steps=3):
        self.model = model
        self.op_tune = op_tune
        self.inner_lr = inner_lr
        # Was read but never set in the original; the default is an assumption
        self.num_inner_steps = num_inner_steps
        self.optimizer = torch.optim.Adam(op_tune.parameters(), lr=outer_lr)

    def meta_train_step(self, support_batch, query_batch):
        """
        MAML-style inner/outer loop.

        Args:
            support_batch: support set (drives the inner loop)
            query_batch: query set (scores the adapted parameters in the outer loop)
        """
        # The inner loop works on detached copies, so the base model is never modified
        params = {
            name: p.detach().clone().requires_grad_(True)
            for name, p in self.model.named_parameters()
        }
        states = {name: torch.zeros_like(p) for name, p in params.items()}

        for step in range(self.num_inner_steps):
            # Support-set loss and its gradient; create_graph=True keeps the
            # graph so the outer loss can be differentiated through the updates
            loss = self.evaluate(params, support_batch)
            grads = torch.autograd.grad(
                loss, list(params.values()), create_graph=True
            )

            # Let OpTune turn each gradient into an update
            new_params, new_states = {}, {}
            for (name, param), grad in zip(params.items(), grads):
                update, new_state = self.op_tune(grad, states[name], step)
                # Out-of-place update: an in-place `.data` edit would cut the graph
                new_params[name] = param + self.inner_lr * update
                new_states[name] = new_state
            params, states = new_params, new_states

        # Outer loop: score the adapted parameters on the query set
        meta_loss = self.evaluate(params, query_batch)

        # Update OpTune by backpropagating through the inner loop
        self.optimizer.zero_grad()
        meta_loss.backward()
        self.optimizer.step()

        return meta_loss.detach()

    def evaluate(self, params, batch):
        """
        Return a scalar loss for `batch` under the parameter dict `params`.

        Task-specific. One option is to run the forward pass with
        torch.func.functional_call(self.model, params, inputs).
        """
        raise NotImplementedError("define the task loss here")
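
The inner/outer structure is easier to follow on a problem small enough to differentiate by hand. The sketch below meta-learns a single scalar step size `eta` instead of a network; the quadratic loss, the constants, and the use of the same loss for the support and query roles are all assumptions made for illustration.

```python
def inner_step(theta, eta, a):
    """One gradient step on the support loss 0.5 * a * theta**2."""
    return theta - eta * a * theta


def meta_gradient(theta0, eta, a):
    """d(query loss)/d(eta), differentiating through the inner step."""
    theta1 = inner_step(theta0, eta, a)
    dloss_dtheta1 = a * theta1   # derivative of 0.5 * a * theta1**2
    dtheta1_deta = -a * theta0   # derivative of the inner step w.r.t. eta
    return dloss_dtheta1 * dtheta1_deta


a, theta0 = 2.0, 1.0   # toy curvature and starting point
eta = 0.1              # the optimizer parameter being meta-learned
for _ in range(200):   # outer loop: plain gradient descent on eta
    eta -= 0.05 * meta_gradient(theta0, eta, a)

print(f"learned eta = {eta:.4f}, adapted theta = {inner_step(theta0, eta, a):.6f}")
```

The outer loop drives `eta` toward `1 / a`, the step size that solves this quadratic in one inner step. OpTune applies the same idea with a neural update rule in place of the scalar.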

3. Feedback Utilization as Optimization

3.1 Formal Definition

FTTT formalizes test-time learning as an optimization problem:

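Objective function

$$\theta^{*} = \arg\min_{\theta} \; \mathcal{L}_{\text{test}}(\theta) + \mathcal{R}(\theta)$$

where $\mathcal{L}_{\text{test}}$ is the test-time loss and $\mathcal{R}$ is the regularization term.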

Feedback-driven gradient estimation

FTTT uses the feedback model to estimate gradients:

import torch


class FeedbackGradientEstimator:
    """
    Feedback gradient estimator.

    Core idea: use finite differences to estimate how a parameter change
    affects the feedback score. Each parameter tensor is shifted uniformly,
    so the result is one directional derivative per tensor, not a
    per-element gradient.
    """
    def __init__(self, model, feedback_model, epsilon=1e-3):
        self.model = model
        self.feedback_model = feedback_model
        self.epsilon = epsilon

    def estimate_gradient(self, prompt, response, target_response=None):
        """
        Estimate the direction in which to adjust the parameters.

        Central difference:
        ∂feedback/∂θ ≈ (feedback(θ+ε) - feedback(θ-ε)) / (2ε)

        `response` and `target_response` are kept for interface
        compatibility; the estimate itself does not use them.
        Generation should be deterministic (e.g. greedy decoding), otherwise
        sampling noise dominates the difference.
        """
        gradients = {}

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            original = param.detach().clone()
            try:
                # Positive perturbation: θ + ε
                with torch.no_grad():
                    param.add_(self.epsilon)
                pos_feedback = self._score(prompt)

                # Negative perturbation: θ - ε
                with torch.no_grad():
                    param.sub_(2 * self.epsilon)
                neg_feedback = self._score(prompt)
            finally:
                # Always restore the original values
                with torch.no_grad():
                    param.copy_(original)

            # Finite-difference estimate (a Python float)
            gradients[name] = (pos_feedback - neg_feedback) / (2 * self.epsilon)

        return gradients

    def _score(self, prompt):
        """Generate with the current weights and return the feedback score."""
        response = self.generate_response(prompt)
        return self.feedback_model.evaluate(prompt, response).score

    def generate_response(self, prompt):
        """Generate a response with the current model."""
        with torch.no_grad():
            return self.model.generate(prompt)
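
The central-difference formula in the docstring can be checked on a function whose gradient is known in closed form. This is a minimal sketch: `toy_feedback` is a hypothetical smooth score, and the perturbation here is applied per coordinate, whereas the class above shifts a whole tensor at once.

```python
def central_difference(f, point, epsilon=1e-3):
    """Estimate the gradient of f at `point`, one coordinate at a time."""
    grads = []
    for i in range(len(point)):
        plus = list(point)
        plus[i] += epsilon
        minus = list(point)
        minus[i] -= epsilon
        grads.append((f(plus) - f(minus)) / (2 * epsilon))
    return grads


def toy_feedback(p):
    """Hypothetical feedback score, maximal at (1, -2)."""
    x, y = p
    return -(x - 1.0) ** 2 - (y + 2.0) ** 2


grad = central_difference(toy_feedback, [0.0, 0.0])
print(grad)  # analytic gradient at the origin is (2, -4)
```

Each coordinate costs two evaluations of the score. For a language model every evaluation is a full generation plus a feedback call, so the number of perturbed directions dominates the cost.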

3.2 Feedback Types

FTTT supports several types of feedback:

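class FeedbackTypes:
    """Enumeration of feedback types."""

    # 1. Score feedback: a direct quality score
    SCORE = "score"

    # 2. Comparison feedback: which of two outputs is better
    COMPARISON = "comparison"

    # 3. Natural-language feedback: detailed written feedback
    NATURAL_LANGUAGE = "natural_language"

    # 4. Constraint feedback: which constraints were violated
    CONSTRAINT = "constraint"

    # 5. Example feedback: a correct example
    EXAMPLE = "example"


class UnifiedFeedback:
    """Unified feedback interface."""

    def __init__(self, feedback_type):
        self.type = feedback_type

    # The constants live on FeedbackTypes, not on this class, so they are
    # referenced through FeedbackTypes below.
    @classmethod
    def from_score(cls, score):
        """Create feedback from a score."""
        return {"type": FeedbackTypes.SCORE, "value": score}

    @classmethod
    def from_comparison(cls, is_better):
        """Create feedback from a comparison."""
        return {"type": FeedbackTypes.COMPARISON, "better": is_better}

    @classmethod
    def from_natural_language(cls, text):
        """Create feedback from natural-language text."""
        return {"type": FeedbackTypes.NATURAL_LANGUAGE, "text": text}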

3.3 Optimization Algorithm

The FTTT optimization algorithm:

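import torch


class FTTTOptimizer:
    """FTTT optimizer."""
    def __init__(self, model, feedback_model, op_tune, lr=0.01):
        self.model = model
        self.feedback_model = feedback_model
        self.op_tune = op_tune
        self.lr = lr
        self.state = {}     # per-parameter optimizer state
        self.num_steps = 0  # number of completed steps
        # Was referenced but never created in the original
        self.feedback_gradient_estimator = FeedbackGradientEstimator(
            model, feedback_model
        )

    def step(self, prompt, response):
        """
        Run one optimization step.

        Procedure:
        1. Evaluate the current feedback
        2. Estimate the gradient
        3. Generate the update with OpTune
        """
        # Evaluate feedback
        feedback = self.feedback_model.evaluate(prompt, response)

        # Estimate the gradient of the feedback score
        gradients = self.feedback_gradient_estimator.estimate_gradient(
            prompt, response, getattr(feedback, "target", None)
        )

        params = dict(self.model.named_parameters())
        opt_device = next(self.op_tune.parameters()).device

        with torch.no_grad():
            for name, grad in gradients.items():
                # Lazily initialize the state for this parameter
                if name not in self.state:
                    self.state[name] = torch.zeros(1, device=opt_device)

                # A higher score is better, so OpTune (which emits descent
                # steps) receives the gradient of the negated score
                loss_grad = -torch.as_tensor(
                    grad, dtype=torch.float32, device=opt_device
                ).reshape(1)

                update, new_state = self.op_tune(
                    loss_grad, self.state[name], step=self.num_steps
                )

                # Apply the update (broadcast over the whole tensor)
                param = params[name]
                param.add_((update * self.lr).to(device=param.device, dtype=param.dtype))
                self.state[name] = new_state

        self.num_steps += 1

        return feedback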

4. The Complete FTTT Framework

4.1 Main Algorithm

import torch
from tqdm import tqdm


class FTTT:
    """
    Main class for Feedback Test-Time Training.
    """
    def __init__(
        self,
        model,
        feedback_model,
        op_tune=None,
        max_iterations=5,
        early_stop_threshold=0.95
    ):
        self.model = model
        self.feedback_model = feedback_model
        self.op_tune = op_tune if op_tune is not None else OpTuneOptimizer()
        self.max_iterations = max_iterations
        self.early_stop_threshold = early_stop_threshold

    def infer(self, prompt, return_iterations=False):
        """
        Test-time inference.

        Args:
            prompt: input prompt
            return_iterations: whether to also return per-iteration information

        Returns:
            best_response: the best-scoring response
            metadata: per-iteration information (optional)
        """
        best_response = None
        best_score = -float('inf')
        iteration_info = []

        optimizer = FTTTOptimizer(
            self.model,
            self.feedback_model,
            self.op_tune
        )

        # Save the initial parameters so adaptation never leaks into the next prompt
        initial_params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }

        try:
            current_response = self.model.generate(prompt)

            for iteration in range(self.max_iterations):
                # Evaluate the current response
                feedback = self.feedback_model.evaluate(prompt, current_response)

                iteration_info.append({
                    'iteration': iteration,
                    'response': current_response,
                    'score': feedback.score,
                    'feedback': feedback
                })

                # Track the best response before the early-stop check, so a
                # response that triggers early stopping is not lost
                if feedback.score > best_score:
                    best_score = feedback.score
                    best_response = current_response

                # Early stop
                if feedback.score >= self.early_stop_threshold:
                    break

                # No update after the last evaluation: its result would never be scored
                if iteration == self.max_iterations - 1:
                    break

                # One FTTT optimization step, then regenerate with the updated model
                optimizer.step(prompt, current_response)
                current_response = self.model.generate(prompt)
        finally:
            # Always restore the initial parameters
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.copy_(initial_params[name])

        if return_iterations:
            return best_response, iteration_info
        return best_response

    def batch_infer(self, prompts, parallel=True):
        """
        Batch inference.

        Each prompt adapts the shared model weights, so prompts are processed
        one at a time in both modes; `parallel=False` only adds a progress bar.
        """
        iterator = prompts if parallel else tqdm(prompts, desc="FTTT Inference")
        return [self.infer(p) for p in iterator]

4.2 Feedback Models

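import re
from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class FeedbackResult:
    """Result container used throughout this document (it was referenced but never defined)."""
    score: float
    is_good_enough: bool
    target: Any = None
    feedback_text: Optional[str] = None


class SimpleScoringFeedbackModel:
    """
    Simple scoring feedback model.

    Suited to tasks with a well-defined correct answer.
    """
    def __init__(self, reward_model=None):
        self.reward_model = reward_model

    def evaluate(self, prompt, response):
        """Score the quality of a response."""
        if self.reward_model:
            score = self.reward_model.get_score(prompt, response)
        else:
            # Fall back to rule-based scoring
            score = self.rule_based_score(prompt, response)

        return FeedbackResult(
            score=score,
            is_good_enough=score > 0.8,
            target=self._generate_target(prompt)
        )

    def rule_based_score(self, prompt, response):
        """Rule-based score (placeholder: always neutral)."""
        return 0.5

    def _generate_target(self, prompt):
        """Reference answer for the prompt, if one is available (none in this sketch)."""
        return None


class LLMScoringFeedbackModel:
    """
    Scoring feedback from an LLM judge.

    Suited to open-ended tasks.
    """
    def __init__(self, judge_model):
        self.judge_model = judge_model

    def evaluate(self, prompt, response):
        """Use an LLM to judge generation quality."""
        judge_prompt = (
            "Please assess the quality of the following answer.\n\n"
            f"Question: {prompt}\n"
            f"Answer: {response}\n\n"
            "Score each dimension from 0 to 1:\n"
            "1. Accuracy\n2. Completeness\n3. Clarity\n4. Relevance\n\n"
            "End with one line of the form 'Overall score: <number from 0 to 1>'."
        )

        with torch.no_grad():
            judgment = self.judge_model.generate(judge_prompt)

        # Parse the overall score
        score = self._parse_score(judgment)

        return FeedbackResult(
            score=score,
            is_good_enough=score > 0.7,
            target=judgment
        )

    def _parse_score(self, text):
        """Parse the overall score (0-1, matching the scale requested in the prompt)."""
        match = re.search(r"Overall score:\s*([0-9]*\.?[0-9]+)", text)
        if match:
            return min(max(float(match.group(1)), 0.0), 1.0)
        return 0.5  # neutral fallback when no score can be parsed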

5. Experimental Results

5.1 Main Results

Results on four reasoning datasets:

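| Dataset | Task type | Base | +FTTT | Gain |
|---|---|---|---|---|
| MATH | Mathematical reasoning | 52.8% | 61.4% | +8.6 pp |
| GSM8K | Math word problems | 76.3% | 83.1% | +6.8 pp |
| HellaSwag | Commonsense reasoning | 79.2% | 81.5% | +2.3 pp |
| BIG-Bench Hard | Complex reasoning | 68.4% | 74.2% | +5.8 pp |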

5.2 Iteration Analysis

Effect of the number of FTTT iterations:

ITERATION_ANALYSIS = {
    "MATH": {
        "iter1": 52.8,
        "iter2": 56.2,
        "iter3": 58.9,
        "iter4": 60.5,
        "iter5": 61.4,
        "convergence": "iter4"
    },
    "GSM8K": {
        "iter1": 76.3,
        "iter2": 79.8,
        "iter3": 81.9,
        "iter4": 82.8,
        "iter5": 83.1,
        "convergence": "iter3"
    }
}
# Observation: most tasks converge after 3-4 iterations
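
The convergence claim can be read directly from the per-iteration differences. The sketch below recomputes them from the numbers in `ITERATION_ANALYSIS`; the helper name `marginal_gains` is introduced here for illustration.

```python
math_scores = [52.8, 56.2, 58.9, 60.5, 61.4]    # MATH, iter1..iter5
gsm8k_scores = [76.3, 79.8, 81.9, 82.8, 83.1]   # GSM8K, iter1..iter5


def marginal_gains(scores):
    """Gain contributed by each additional iteration, in percentage points."""
    return [round(later - earlier, 1) for earlier, later in zip(scores, scores[1:])]


print("MATH: ", marginal_gains(math_scores))
print("GSM8K:", marginal_gains(gsm8k_scores))
```

On both datasets every additional iteration contributes less than the one before it, and the final iteration adds under one point.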

5.3 OpTune vs. Fixed Optimizers

| Optimizer | MATH | GSM8K | Average gain |
|---|---|---|---|
| SGD | 58.9% | 81.2% | +5.5% |
| Adam | 60.1% | 82.4% | +6.6% |
| OpTune | 61.4% | 83.1% | +7.5% |

6. Comparison with Other Methods

6.1 Method Comparison Table

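| Method | Compute overhead | Requires training | Feedback required | Generality |
|---|---|---|---|---|
| TTT | | | | |
| MC Dropout | | | | |
| Test-Time BN | | | | |
| Self-Consistency | | | | |
| FTTT | | Yes (lightweight) | | |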

6.2 Complementarity

FTTT can be combined with a range of other methods:

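from collections import Counter


class FTTTWithSelfConsistency:
    """
    FTTT combined with Self-Consistency.
    """
    def __init__(self, fttt, num_samples=8):
        self.fttt = fttt
        self.num_samples = num_samples

    def infer(self, prompt):
        """
        1. Generate several candidate responses (Self-Consistency sampling)
        2. Pick the best candidate and refine it with FTTT
        """
        # Sample several responses; generation must be stochastic
        # (temperature > 0) for the candidates to differ
        candidates = []
        for _ in range(self.num_samples):
            candidate = self.fttt.model.generate(prompt)
            candidates.append(candidate)

        # Pick the most consistent response
        best_candidate = self.select_most_consistent(candidates)

        # Refine with FTTT, keeping prompt and candidate clearly separated
        return self.fttt.infer(f"{prompt}\n\nCandidate answer:\n{best_candidate}")

    def select_most_consistent(self, candidates):
        """
        Majority vote over whitespace-normalized responses.

        A simple stand-in for the undefined original helper; voting on
        extracted final answers is the more common choice.
        """
        normalized = [" ".join(c.split()) for c in candidates]
        winner, _ = Counter(normalized).most_common(1)[0]
        return candidates[normalized.index(winner)]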
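
Self-Consistency usually votes on the final answer rather than on the full response text. The sketch below illustrates that selection step on its own; the `Answer:` marker convention and the helper `extract_final_answer` are assumptions made for this example.

```python
from collections import Counter


def extract_final_answer(response):
    """Hypothetical convention: the final answer follows the last 'Answer:' marker."""
    return response.rsplit("Answer:", 1)[-1].strip()


def select_most_consistent(candidates):
    """Return the first response carrying the majority answer, plus its vote count."""
    answers = [extract_final_answer(c) for c in candidates]
    winner, votes = Counter(answers).most_common(1)[0]
    best = next(c for c, a in zip(candidates, answers) if a == winner)
    return best, votes


candidates = [
    "3 * 4 = 12. Answer: 12",
    "4 + 4 + 4 = 12. Answer: 12",
    "3 + 4 = 7. Answer: 7",
]
best, votes = select_most_consistent(candidates)
print(best, votes)
```

Two different reasoning paths agree on 12, so the first of them is selected even though the response strings differ.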

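7. Practical Guide

7.1 When to Use FTTT

FTTT is a good fit when:

  1. Reasoning quality falls short: the model is basically correct but does not express its reasoning clearly
  2. A feedback signal is available: a scoring model, a verifier, and so on
  3. Some compute overhead is acceptable: each inference can afford 2-5 extra generations
  4. The task has a clear correctness criterion: a feedback model can be designed for it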

7.2 Hyperparameter Settings

FTTT_HYPERPARAMETERS = {
    "max_iterations": {
        "default": 5,
        "range": [3, 10],
        "description": "Maximum number of iterations"
    },
    "early_stop_threshold": {
        "default": 0.95,
        "range": [0.8, 0.99],
        "description": "Early-stopping score threshold"
    },
    "learning_rate": {
        "default": 0.01,
        "range": [0.001, 0.1],
        "description": "Learning rate of the test-time update"
    },
    "num_samples": {
        "default": 8,
        "range": [4, 32],
        "description": "Number of candidate samples (for Self-Consistency)"
    }
}
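
A table like this is easy to turn into a sanity check before a run. The sketch below is one possible way to do it; `RANGES` repeats the ranges listed above, and the helper `out_of_range` is hypothetical.

```python
# Ranges copied from the hyperparameter table above
RANGES = {
    "max_iterations": (3, 10),
    "early_stop_threshold": (0.8, 0.99),
    "learning_rate": (0.001, 0.1),
    "num_samples": (4, 32),
}


def out_of_range(config):
    """Return the names of settings that fall outside their listed range."""
    bad = []
    for name, value in config.items():
        low, high = RANGES[name]
        if not low <= value <= high:
            bad.append(name)
    return bad


defaults = {
    "max_iterations": 5,
    "early_stop_threshold": 0.95,
    "learning_rate": 0.01,
    "num_samples": 8,
}
print(out_of_range(defaults))
print(out_of_range({"max_iterations": 20, "learning_rate": 0.01}))
```

Values outside a range are not necessarily wrong, so reporting the names rather than raising an error leaves the decision to the caller.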

7.3 Feedback Model Design

The feedback model is the key factor in how well FTTT works:

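from types import SimpleNamespace

# extract_answer and extract_code are task-specific helpers assumed to be
# defined elsewhere; the verifier, executor and judge interfaces are likewise
# assumptions of this guide.


class TaskSpecificFeedbackModel:
    """
    Design guide for task-specific feedback models.

    Each factory returns an object with an `evaluate(prompt, response)`
    method, which is the interface FTTT calls.
    """

    @staticmethod
    def for_math_problems(verifier):
        """Feedback model for math problems."""
        def evaluate(prompt, response):
            # Extract the final answer
            extracted = extract_answer(response)
            # Verify correctness
            is_correct = verifier.verify(prompt, extracted)
            return FeedbackResult(
                score=1.0 if is_correct else 0.0,
                is_good_enough=is_correct,
                target=extracted
            )
        return SimpleNamespace(evaluate=evaluate)

    @staticmethod
    def for_code_generation(executor):
        """Feedback model for code generation."""
        def evaluate(prompt, response):
            # Extract the code
            code = extract_code(response)
            # Run the test cases
            result = executor.run_tests(code)
            return FeedbackResult(
                score=result.pass_rate,
                is_good_enough=result.pass_rate > 0.8,
                target=result.expected_output
            )
        return SimpleNamespace(evaluate=evaluate)

    @staticmethod
    def for_open_ended(judge_model):
        """Feedback model for open-ended tasks."""
        def evaluate(prompt, response):
            # Judge with an LLM
            judgment = judge_model.judge(prompt, response)
            return FeedbackResult(
                score=judgment.score,
                is_good_enough=judgment.score > 0.8,
                target=judgment.feedback
            )
        return SimpleNamespace(evaluate=evaluate)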

8. Code Examples

8.1 End-to-End Usage Example

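import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class TextGenerator(nn.Module):
    """Adapter exposing the `generate(prompt: str) -> str` interface that FTTT expects."""
    def __init__(self, model, tokenizer, max_new_tokens=512):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens

    @torch.no_grad()
    def generate(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # Greedy decoding keeps the finite-difference estimate deterministic
        output_ids = self.model.generate(
            **inputs, max_new_tokens=self.max_new_tokens, do_sample=False
        )
        new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)


# 1. Load the model
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
generator = TextGenerator(model, tokenizer)

# 2. Initialize the feedback model (math problems as the example)
feedback_model = SimpleScoringFeedbackModel()

# 3. Initialize FTTT
# Note: the estimator generates twice per trainable tensor, so freeze all but
# a small subset of parameters before running this on a large model
fttt = FTTT(
    model=generator,
    feedback_model=feedback_model,
    max_iterations=5,
    early_stop_threshold=0.95
)

# 4. Run inference
test_prompts = [
    "Solve x^2 - 5x + 6 = 0",
    "A triangle has side lengths 3, 4 and 5. Find its area."
]

results = fttt.batch_infer(test_prompts, parallel=False)

for prompt, result in zip(test_prompts, results):
    print(f"Prompt: {prompt}")
    print(f"Result: {result}")
    print("-" * 50)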

8.2 Custom Feedback Model

class CustomFeedbackModel:
    """Example of a custom feedback model."""

    def __init__(self, reward_model=None, rules=None):
        self.reward_model = reward_model
        self.rules = rules or []

    def evaluate(self, prompt, response):
        """
        Custom evaluation logic.
        """
        # 1. Rule checks
        rule_score = self._apply_rules(response)

        # 2. Reward-model score (neutral 0.5 when no reward model is given)
        if self.reward_model:
            reward_score = self.reward_model.get_score(prompt, response)
        else:
            reward_score = 0.5

        # 3. Combined score
        final_score = 0.3 * rule_score + 0.7 * reward_score

        # 4. Generate the feedback text
        feedback_text = self._generate_feedback(
            prompt, response, rule_score, reward_score
        )

        return FeedbackResult(
            score=final_score,
            is_good_enough=final_score > 0.8,
            target=feedback_text
        )

    def _apply_rules(self, response):
        """Apply the rules; each failed rule multiplies the score by its penalty."""
        score = 1.0
        for rule in self.rules:
            if not rule.check(response):
                score *= rule.penalty
        return score

    def _generate_feedback(self, prompt, response, rule_score, reward_score):
        """Generate feedback text targeted at the weakest score."""
        if rule_score < 0.5:
            return "Please check the format and completeness of the answer"
        if reward_score < 0.5:
            return "The answer is correct but not clearly expressed; please explain the reasoning in more detail"
        return "The answer is of good quality"
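
`_apply_rules` only requires each rule to expose a `check(response)` callable and a `penalty` factor. One possible shape for such rules is sketched below; the `Rule` dataclass and both example rules are hypothetical, and `apply_rules` repeats the multiplicative logic as a standalone function.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Rule:
    """Hypothetical rule: a predicate plus the factor applied when it fails."""
    check: Callable[[str], bool]
    penalty: float


rules = [
    Rule(check=lambda r: "Answer:" in r, penalty=0.5),  # must state a final answer
    Rule(check=lambda r: len(r) >= 20, penalty=0.8),    # must not be too terse
]


def apply_rules(response, rules):
    """Multiply the penalties of all failed rules, starting from 1.0."""
    score = 1.0
    for rule in rules:
        if not rule.check(response):
            score *= rule.penalty
    return score


for response in ("12", "Answer: 12", "3 * 4 = 12. Answer: 12"):
    print(repr(response), apply_rules(response, rules))
```

Because penalties multiply, a response that fails several rules is pushed toward zero quickly, while a single minor failure costs comparatively little.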

Footnotes

  1. This document is compiled from the FTTT (Feedback Test-Time Training) paper. The paper was published at ICLR/NeurIPS 2025.