概述
测试时学习(Test-Time Learning, TTL)是一种在推理阶段动态调整模型行为的范式。FTTT(Feedback Test-Time Training)提出将测试时学习重新定义为**反馈利用(Feedback Utilization)**优化问题,通过可学习的测试时优化器(OpTune)实现高效的自适应推理增强。[1]
1. 问题背景
1.1 测试时学习的挑战
传统的测试时适应方法面临以下挑战:
| 方法 | 核心思想 | 局限性 |
|---|---|---|
| TTT | 训练时相同的梯度下降 | 需要大量测试数据 |
| Test-Time Dropout | Monte Carlo Dropout | 效果有限 |
| 特征归一化 | BN统计量调整 | 依赖预定义变换 |
| FTTT | 反馈驱动的自适应 | 通用且高效 |
1.2 FTTT的核心洞察
FTTT的核心观察是:
推理失败往往不是“不知道答案”,而是“不知道如何表达正确的推理过程”。
这意味着可以通过反馈机制引导模型生成更好的推理链。
class FeedbackDrivenReasoning:
    """Refine a model's output with critiques produced by a feedback model.

    Attributes:
        model: Generator exposing ``generate(text)``.
        feedback_model: Evaluator exposing ``evaluate(prompt, output)``. The
            object it returns must provide ``is_good_enough`` and
            ``feedback_text``.
        feedback: Feedback text of the most recently rejected attempt, or
            ``None`` when no attempt has been rejected yet.
    """

    def __init__(self, model, feedback_model):
        self.model = model
        self.feedback_model = feedback_model  # judges the generated outputs
        # Initialised here so reading the attribute before the first call
        # cannot raise AttributeError.
        self.feedback = None

    def generate_with_feedback(self, prompt, max_iterations=3):
        """Run a generate -> evaluate -> revise loop.

        Args:
            prompt: Original task prompt.
            max_iterations: Maximum number of generation attempts.

        Returns:
            The last generated output, or ``None`` if ``max_iterations`` is
            not positive (no generation happens in that case).
        """
        current_output = None
        # Loop state is kept in a local; ``self.feedback`` only mirrors it
        # for callers that want to inspect the last critique.
        feedback_text = None
        for _ in range(max_iterations):
            if current_output is None:
                input_text = prompt
            else:
                # Inject the previous attempt and its critique into the input.
                input_text = (
                    f"{prompt}\n\nPrevious attempt:\n{current_output}"
                    f"\n\nFeedback: {feedback_text}"
                )
            current_output = self.model.generate(input_text)

            feedback_result = self.feedback_model.evaluate(prompt, current_output)
            if feedback_result.is_good_enough:
                break
            feedback_text = feedback_result.feedback_text
            self.feedback = feedback_text
        return current_output


## 2. OpTune:可学习测试时优化器
2.1 设计动机
传统的测试时适应使用固定的优化器(如SGD、Adam),这假设测试数据和训练数据具有相同的分布特性。OpTune提出学习测试时优化器来适应测试数据的特点。
2.2 OpTune架构
class OpTuneOptimizer(nn.Module):
    """Learned test-time optimizer: neural networks parameterise the update rule.

    Every gradient coordinate is treated as an independent scalar feature, so
    gradients of any shape (including 0-dim tensors and Python floats) are
    accepted.
    """

    def __init__(self, hidden_dim=64):
        super().__init__()
        # Encodes one scalar gradient coordinate.
        self.gradient_encoder = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )
        # Encodes one scalar of optimizer state.
        self.state_encoder = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )
        # Maps the two encodings to one raw step value per coordinate.
        self.update_policy = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )
        # Summarises the encoded gradient into a momentum-like hidden vector.
        self.momentum_estimator = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        # Projects the GRU hidden vector back to one scalar per coordinate so
        # that the returned state can be passed as ``state`` on the next call
        # (``state_encoder`` takes scalar inputs).
        self.state_decoder = nn.Linear(hidden_dim, 1)

    def forward(self, grad, state, step):
        """Produce an update and the next optimizer state.

        Args:
            grad: Current gradient; tensor of any shape or a Python number.
            state: Optimizer state holding either one element (shared by all
                coordinates) or as many elements as ``grad``.
            step: Current step index. Accepted for interface compatibility;
                not used by the computation.

        Returns:
            A tuple ``(update, new_state)``, both shaped like ``grad``. Each
            update coordinate equals ``-tanh(m) * grad`` and is therefore
            bounded in magnitude by the matching gradient coordinate.

        Raises:
            RuntimeError: If ``state`` has more than one element and its
                element count differs from that of ``grad``.
        """
        reference = self.state_decoder.weight
        grad = torch.as_tensor(grad, dtype=reference.dtype, device=reference.device)
        state = torch.as_tensor(state, dtype=reference.dtype, device=reference.device)

        grad_shape = grad.shape
        flat_grad = grad.reshape(-1, 1)
        flat_state = state.reshape(-1, 1)
        if flat_state.shape[0] != flat_grad.shape[0]:
            # A single-element state (e.g. a zero initialisation) is shared
            # across all coordinates; any other mismatch raises here.
            flat_state = flat_state.expand(flat_grad.shape[0], 1)

        grad_encoded = self.gradient_encoder(flat_grad)
        state_encoded = self.state_encoder(flat_state)

        combined = torch.cat([grad_encoded, state_encoded], dim=-1)
        update_magnitude = self.update_policy(combined)

        # Each coordinate is a batch item with a length-1 sequence. The GRU
        # returns its hidden state with the layer axis first, so that axis
        # (index 0) is the one to drop.
        _, momentum_hidden = self.momentum_estimator(grad_encoded.unsqueeze(1))
        new_state = self.state_decoder(momentum_hidden.squeeze(0)).reshape(grad_shape)

        # Multiply column by column; mixing a (N, 1) magnitude with a (N,)
        # gradient would broadcast to an (N, N) matrix.
        update = (-torch.tanh(update_magnitude) * flat_grad).reshape(grad_shape)
        return update, new_state


### 2.3 训练OpTune
OpTune通过元学习(Meta-Learning)训练:
class OpTuneTrainer:
    """Meta-learning trainer for an OpTune-style learned optimizer.

    Subclasses must override :meth:`evaluate` with the task loss.
    """

    def __init__(self, model, op_tune, inner_lr=0.01, outer_lr=1e-4,
                 num_inner_steps=1):
        """Store the components and build the outer-loop optimizer.

        Args:
            model: Base model that is cloned for every meta step.
            op_tune: Learned optimizer, called as ``op_tune(grad, state, step)``
                and returning ``(update, new_state)``.
            inner_lr: Scale applied to every update produced by ``op_tune``.
            outer_lr: Learning rate of the Adam optimizer over ``op_tune``.
            num_inner_steps: Number of inner-loop adaptation steps.
        """
        self.model = model
        self.op_tune = op_tune
        self.inner_lr = inner_lr
        self.num_inner_steps = num_inner_steps
        self.optimizer = torch.optim.Adam(op_tune.parameters(), lr=outer_lr)

    def meta_train_step(self, support_batch, query_batch):
        """Adapt a clone on the support set and score it on the query set.

        Args:
            support_batch: Data used by the inner loop.
            query_batch: Data used to evaluate the adapted clone.

        Returns:
            The query-set loss of the adapted clone.

        NOTE(review): updates are applied in place without gradient tracking,
        so the returned loss is not differentiable with respect to
        ``op_tune``. This method therefore does not call ``backward()`` or
        ``optimizer.step()``; a functional parameter update (for example via
        ``torch.func.functional_call``) is needed for a true meta-gradient.
        """
        adapted_model = self.clone_model(self.model)
        # One flat state vector per parameter tensor.
        optimizer_state = [
            torch.zeros_like(param).reshape(-1)
            for param in adapted_model.parameters()
        ]

        for step in range(self.num_inner_steps):
            params = list(adapted_model.parameters())
            loss = self.evaluate(adapted_model, support_batch)
            grads = torch.autograd.grad(loss, params, create_graph=True)

            updates = []
            new_state = []
            for param, grad, state in zip(params, grads, optimizer_state):
                update, new_s = self.op_tune(grad.reshape(-1), state, step)
                updates.append(update.reshape(param.shape) * self.inner_lr)
                new_state.append(new_s)

            adapted_model = self.apply_updates(adapted_model, updates)
            optimizer_state = new_state

        meta_loss = self.evaluate(adapted_model, query_batch)
        self.optimizer.zero_grad()
        return meta_loss

    def clone_model(self, model):
        """Return an independent deep copy of ``model``."""
        import copy

        return copy.deepcopy(model)

    def apply_updates(self, model, updates):
        """Add ``updates`` to the parameters of ``model`` in place.

        Args:
            model: Model whose parameters are modified.
            updates: One tensor per parameter, in ``model.parameters()`` order.

        Returns:
            The same ``model`` object.
        """
        with torch.no_grad():
            for param, update in zip(model.parameters(), updates):
                param.add_(update)
        return model

    def evaluate(self, model, batch):
        """Return the loss of ``model`` on ``batch``; task specific.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        raise NotImplementedError(
            "OpTuneTrainer.evaluate must be overridden with the task loss"
        )


## 3. 反馈利用作为优化
3.1 形式化定义
FTTT将测试时学习形式化为优化问题:
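目标函数:

$$\theta^{*} = \arg\min_{\theta}\; \mathcal{L}_{\text{test}}(\theta) + \mathcal{R}(\theta)$$

其中 $\mathcal{L}_{\text{test}}$ 是测试时损失,$\mathcal{R}$ 是正则化项。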
反馈驱动的梯度估计:
FTTT使用反馈模型来估计梯度:
class FeedbackGradientEstimator:
    """Estimate how parameter changes affect the feedback score.

    Uses central finite differences. Each parameter tensor is shifted as a
    whole (the same ``epsilon`` is added to every entry), so the estimate is
    a single scalar per tensor: the derivative of the score along the
    all-ones direction of that tensor.
    """

    def __init__(self, model, feedback_model, epsilon=1e-3):
        self.model = model
        self.feedback_model = feedback_model
        self.epsilon = epsilon

    def estimate_gradient(self, prompt, response, target_response):
        """Estimate d(score)/d(theta) for every trainable parameter tensor.

        For each tensor: ``(score(theta + eps) - score(theta - eps)) / (2 * eps)``.

        Args:
            prompt: Prompt used to regenerate responses under perturbation.
            response: Current response. Accepted for interface compatibility;
                not used by the estimate.
            target_response: Accepted for interface compatibility; not used.

        Returns:
            Dict mapping parameter name to its scalar estimate. Parameters
            with ``requires_grad=False`` are omitted.
        """
        gradients = {}
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            # Back up only the tensor being perturbed instead of cloning the
            # whole model up front.
            backup = param.detach().clone()
            try:
                with torch.no_grad():
                    param.add_(self.epsilon)
                pos_feedback = self._score(prompt)

                with torch.no_grad():
                    param.sub_(2 * self.epsilon)
                neg_feedback = self._score(prompt)
            finally:
                # Restore exactly, even if generation or scoring raised.
                with torch.no_grad():
                    param.copy_(backup)

            gradients[name] = (pos_feedback - neg_feedback) / (2 * self.epsilon)
        return gradients

    def _score(self, prompt):
        """Generate with the current parameters and return the feedback score."""
        candidate = self.generate_response(prompt)
        return self.feedback_model.evaluate(prompt, candidate).score

    def generate_response(self, prompt):
        """Generate a response with the current model, without autograd."""
        with torch.no_grad():
            return self.model.generate(prompt)


### 3.2 反馈类型
FTTT支持多种反馈类型:
class FeedbackTypes:
    """String identifiers of the supported feedback kinds."""

    SCORE = "score"  # a numeric quality score
    COMPARISON = "comparison"  # which of two candidates is better
    NATURAL_LANGUAGE = "natural_language"  # a free-text critique
    CONSTRAINT = "constraint"  # which constraints were violated
    EXAMPLE = "example"  # a correct reference example


class UnifiedFeedback:
    """Uniform constructors for feedback payloads.

    The ``from_*`` factories return plain dicts tagged with a
    ``FeedbackTypes`` identifier under the ``"type"`` key.
    """

    def __init__(self, feedback_type):
        self.type = feedback_type

    # The identifiers live on FeedbackTypes, not on this class, so the
    # factories must look them up there rather than on ``cls``.

    @classmethod
    def from_score(cls, score):
        """Build a score feedback payload."""
        return {"type": FeedbackTypes.SCORE, "value": score}

    @classmethod
    def from_comparison(cls, is_better):
        """Build a comparison feedback payload."""
        return {"type": FeedbackTypes.COMPARISON, "better": is_better}

    @classmethod
    def from_natural_language(cls, text):
        """Build a natural-language feedback payload."""
        return {"type": FeedbackTypes.NATURAL_LANGUAGE, "text": text}


### 3.3 优化算法
FTTT的优化算法:
class FTTTOptimizer:
    """Single-step FTTT optimizer driven by feedback-estimated gradients."""

    def __init__(self, model, feedback_model, op_tune, lr=0.01,
                 gradient_estimator=None):
        """Store the components used by :meth:`step`.

        Args:
            model: Model exposing ``named_parameters()``.
            feedback_model: Evaluator exposing ``evaluate(prompt, response)``;
                its result must provide ``target``.
            op_tune: Callable ``op_tune(grad, state, step=...)`` returning
                ``(update, new_state)``.
            lr: Scale applied to every update.
            gradient_estimator: Object exposing
                ``estimate_gradient(prompt, response, target)`` and returning
                a dict of per-parameter gradients. Defaults to a
                ``FeedbackGradientEstimator`` over ``model`` and
                ``feedback_model``.
        """
        self.model = model
        self.feedback_model = feedback_model
        self.op_tune = op_tune
        self.lr = lr
        self.state = {}  # per-parameter optimizer state, keyed by name
        self._step_count = 0
        if gradient_estimator is None:
            gradient_estimator = FeedbackGradientEstimator(model, feedback_model)
        self.feedback_gradient_estimator = gradient_estimator

    def step(self, prompt, response):
        """Run one optimization step.

        The step evaluates the current feedback, estimates gradients,
        converts them to updates with ``op_tune`` and applies the updates in
        place.

        Args:
            prompt: Input prompt.
            response: Current response to ``prompt``.

        Returns:
            The feedback object obtained for ``response``.

        NOTE(review): the estimated gradient is that of the feedback *score*;
        whether the applied update raises or lowers the score depends on the
        sign produced by ``op_tune`` -- confirm against its training objective.
        """
        feedback = self.feedback_model.evaluate(prompt, response)

        gradients = self.feedback_gradient_estimator.estimate_gradient(
            prompt, response, feedback.target
        )

        params = dict(self.model.named_parameters())
        updates = {}
        for name, grad in gradients.items():
            param = params[name]
            if name not in self.state:
                self.state[name] = torch.zeros(1, device=param.device)
            # The estimator may return plain Python numbers.
            grad_tensor = torch.as_tensor(
                grad, dtype=param.dtype, device=param.device
            )
            update, new_s = self.op_tune(
                grad_tensor, self.state[name], step=self._step_count
            )
            updates[name] = update * self.lr
            self.state[name] = new_s

        # Only parameters that received a gradient are touched; parameters
        # the estimator skipped are left unchanged.
        with torch.no_grad():
            for name, update in updates.items():
                params[name].add_(update)

        self._step_count += 1
        return feedback


## 4. 完整FTTT框架
4.1 主算法
class FTTT:
    """Feedback Test-Time Training: per-prompt adaptation driven by feedback."""

    def __init__(
        self,
        model,
        feedback_model,
        op_tune=None,
        max_iterations=5,
        early_stop_threshold=0.95,
    ):
        """Store the components and loop settings.

        Args:
            model: Model exposing ``generate(prompt)`` and
                ``named_parameters()``.
            feedback_model: Evaluator exposing ``evaluate(prompt, response)``;
                its result must provide ``score``.
            op_tune: Learned optimizer; a default ``OpTuneOptimizer`` is
                created when ``None``.
            max_iterations: Maximum number of responses that are scored.
            early_stop_threshold: Score at or above which the loop stops.
        """
        self.model = model
        self.feedback_model = feedback_model
        # Compare with None explicitly: a module may define its own truthiness.
        self.op_tune = OpTuneOptimizer() if op_tune is None else op_tune
        self.max_iterations = max_iterations
        self.early_stop_threshold = early_stop_threshold

    def infer(self, prompt, return_iterations=False):
        """Adapt the model on one prompt and return the best-scoring response.

        Model parameters are restored to their initial values before this
        method returns, including when an exception is raised.

        Args:
            prompt: Input prompt.
            return_iterations: When true, also return per-iteration records.

        Returns:
            ``best_response``, or ``(best_response, iteration_info)`` when
            ``return_iterations`` is true. Each record holds ``iteration``,
            ``response``, ``score`` and ``feedback``.
        """
        best_response = None
        best_score = -float("inf")
        iteration_info = []
        optimizer = FTTTOptimizer(self.model, self.feedback_model, self.op_tune)

        # Snapshot used to undo the test-time updates afterwards.
        initial_params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }

        try:
            current_response = self.model.generate(prompt)
            for iteration in range(self.max_iterations):
                feedback = self.feedback_model.evaluate(prompt, current_response)
                iteration_info.append({
                    'iteration': iteration,
                    'response': current_response,
                    'score': feedback.score,
                    'feedback': feedback,
                })

                # Record the best response BEFORE checking the stop
                # condition, so the response that triggers early stopping is
                # the one returned.
                if feedback.score > best_score:
                    best_score = feedback.score
                    best_response = current_response

                if feedback.score >= self.early_stop_threshold:
                    break
                if iteration == self.max_iterations - 1:
                    # A further update and generation would never be scored.
                    break

                optimizer.step(prompt, current_response)
                current_response = self.model.generate(prompt)

            if best_response is None:
                # Nothing was scored (e.g. max_iterations <= 0).
                best_response = current_response
        finally:
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.copy_(initial_params[name])

        if return_iterations:
            return best_response, iteration_info
        return best_response

    def batch_infer(self, prompts, parallel=True):
        """Run :meth:`infer` on every prompt and return the results in order.

        Both branches process the prompts one after another; with
        ``parallel=False`` a ``tqdm`` progress bar is shown as well.

        NOTE(review): ``parallel=True`` does not run anything concurrently.
        """
        if parallel:
            results = [self.infer(p) for p in prompts]
        else:
            results = []
            for p in tqdm(prompts, desc="FTTT Inference"):
                results.append(self.infer(p))
        return results


### 4.2 反馈模型
class SimpleScoringFeedbackModel:
    """Score-only feedback model for tasks with a well-defined correct answer."""

    def __init__(self, reward_model=None):
        self.reward_model = reward_model

    def evaluate(self, prompt, response):
        """Score ``response`` and wrap the result in a ``FeedbackResult``.

        The score comes from ``reward_model.get_score`` when a reward model
        was supplied, otherwise from :meth:`rule_based_score`. The response
        counts as good enough when the score is strictly above 0.8.
        """
        if self.reward_model is not None:
            score = self.reward_model.get_score(prompt, response)
        else:
            score = self.rule_based_score(prompt, response)
        return FeedbackResult(
            score=score,
            is_good_enough=score > 0.8,
            target=self._generate_target(prompt),
        )

    def rule_based_score(self, prompt, response):
        """Rule-based fallback score; this placeholder always returns 0.5."""
        return 0.5

    def _generate_target(self, prompt):
        """Return the reference target for ``prompt``.

        The base implementation has no reference answers and returns
        ``None``; override it when targets are available.
        """
        return None
class LLMScoringFeedbackModel:
    """Feedback model that asks a judge LLM for a score; for open-ended tasks."""

    def __init__(self, judge_model):
        self.judge_model = judge_model

    def evaluate(self, prompt, response):
        """Ask the judge model to rate ``response`` and parse its verdict.

        Returns:
            A ``FeedbackResult`` whose ``target`` is the raw judgment text.
            The response counts as good enough when the parsed score is
            strictly above 0.7.
        """
        # Same text as the original triple-quoted template, including the
        # leading and trailing newline.
        judge_prompt = (
            "\n"
            "请评估以下回答的质量:\n"
            f"问题:{prompt}\n"
            f"回答:{response}\n"
            "请从以下维度评分(0-1):\n"
            "1. 准确性\n"
            "2. 完整性\n"
            "3. 清晰度\n"
            "4. 相关性\n"
            "最终综合评分:\n"
        )
        with torch.no_grad():
            judgment = self.judge_model.generate(judge_prompt)

        score = self._parse_score(judgment)
        return FeedbackResult(
            score=score,
            is_good_enough=score > 0.7,
            target=judgment,
        )

    def _parse_score(self, text):
        """Extract a score in ``[0, 1]`` from the judge's reply.

        The number following the last "综合评分" marker is preferred so that
        list numbering or per-dimension scores are not mistaken for the final
        score; without a marker the first number in the text is used. The
        prompt asks for a 0-1 score, so values up to 1 are taken as is.
        Larger values are assumed to be on a 0-10 scale and divided by 10,
        then capped at 1.0.

        Returns:
            The parsed score, or 0.5 when the text contains no number.
        """
        import re

        number = r"\d+(?:\.\d+)?"
        _, marker, tail = text.rpartition("综合评分")
        match = re.search(number, tail) if marker else None
        if match is None:
            match = re.search(number, text)
        if match is None:
            return 0.5

        value = float(match.group())
        if value > 1.0:
            value /= 10
        return min(value, 1.0)


## 5. 实验结果
5.1 主要结果
在四个推理数据集上的结果:
| 数据集 | 任务类型 | Base | +FTTT | 提升 |
|---|---|---|---|---|
| MATH | 数学推理 | 52.8% | 61.4% | +8.6% |
| GSM8K | 数学应用题 | 76.3% | 83.1% | +6.8% |
| HellaSwag | 常识推理 | 79.2% | 81.5% | +2.3% |
| BIG-Bench Hard | 复杂推理 | 68.4% | 74.2% | +5.8% |
5.2 迭代分析
FTTT的迭代效果分析:
# Accuracy (%) after each FTTT iteration, per dataset; "convergence" names
# the iteration at which the article considers the run converged.
ITERATION_ANALYSIS = {
    "MATH": dict(
        iter1=52.8,
        iter2=56.2,
        iter3=58.9,
        iter4=60.5,
        iter5=61.4,
        convergence="iter4",
    ),
    "GSM8K": dict(
        iter1=76.3,
        iter2=79.8,
        iter3=81.9,
        iter4=82.8,
        iter5=83.1,
        convergence="iter3",
    ),
}
# Observation: most tasks converge after 3-4 iterations.

### 5.3 OpTune vs 固定优化器
| 优化器 | MATH | GSM8K | 平均提升 |
|---|---|---|---|
| SGD | 58.9% | 81.2% | +5.5% |
| Adam | 60.1% | 82.4% | +6.6% |
| OpTune | 61.4% | 83.1% | +7.5% |
6. 与其他方法的对比
6.1 方法对比表
| 方法 | 计算开销 | 需要训练 | 反馈需求 | 通用性 |
|---|---|---|---|---|
| TTT | 高 | 是 | 无 | 低 |
| MC Dropout | 中 | 否 | 无 | 中 |
| Test-Time BN | 低 | 是 | 无 | 低 |
| Self-Consistency | 高 | 否 | 无 | 高 |
| FTTT | 中 | 是(轻量) | 是 | 高 |
6.2 互补性
FTTT可以与多种方法组合:
class FTTTWithSelfConsistency:
    """Combine Self-Consistency sampling with FTTT refinement."""

    def __init__(self, fttt, num_samples=8):
        self.fttt = fttt
        self.num_samples = num_samples

    def infer(self, prompt):
        """Sample candidates, pick the most consistent one, then run FTTT.

        1. Draw ``num_samples`` candidate responses from ``fttt.model``.
        2. Select the candidate that occurs most often.
        3. Run ``fttt.infer`` on the prompt with that candidate appended.

        NOTE(review): the candidate is appended to the prompt without any
        separator, exactly as in the original code -- confirm this is the
        intended prompt format.

        Raises:
            ValueError: If no candidates were sampled (``num_samples <= 0``).
        """
        candidates = [
            self.fttt.model.generate(prompt) for _ in range(self.num_samples)
        ]
        best_candidate = self.select_most_consistent(candidates)
        return self.fttt.infer(prompt + best_candidate)

    def select_most_consistent(self, candidates):
        """Return the candidate that occurs most often (majority vote).

        Ties are resolved in favour of the candidate that appeared first.
        Candidates must be hashable.

        Raises:
            ValueError: If ``candidates`` is empty.
        """
        from collections import Counter

        if not candidates:
            raise ValueError("at least one candidate is required")
        counts = Counter(candidates)
        # Counter keeps first-seen order and max() returns the first maximum,
        # which yields the documented tie-break.
        return max(counts, key=counts.get)


## 7. 实践指南
7.1 何时使用FTTT
FTTT适合以下场景:
- 推理质量不足:模型基本正确但表达不够清晰
- 有可用的反馈信号:评分模型、验证器等
- 允许一定的计算开销:每次推理允许额外2-5次生成
- 任务有明确的正确性标准:可以设计反馈模型
7.2 超参数设置
# Recommended FTTT hyperparameters: default value, sensible range and a
# short description for each setting.
FTTT_HYPERPARAMETERS = {
    "max_iterations": {
        "default": 5,
        "range": [3, 10],
        "description": "最大迭代次数",
    },
    "early_stop_threshold": {
        "default": 0.95,
        "range": [0.8, 0.99],
        "description": "早停阈值",
    },
    "learning_rate": {
        "default": 0.01,
        "range": [0.001, 0.1],
        "description": "优化学习率",
    },
    "num_samples": {
        "default": 8,
        "range": [4, 32],
        "description": "候选样本数量(用于Self-Consistency)",
    },
}

### 7.3 反馈模型设计
反馈模型是FTTT效果的关键:
class _FeedbackFunction:
    """Wrap an ``evaluate(prompt, response)`` function as a feedback model.

    Instances can still be called directly, and they also expose the
    ``evaluate`` method that the FTTT classes call on a feedback model.
    """

    def __init__(self, func):
        self._func = func

    def __call__(self, prompt, response):
        return self._func(prompt, response)

    def evaluate(self, prompt, response):
        return self._func(prompt, response)


class TaskSpecificFeedbackModel:
    """Factories showing how to build task-specific feedback models."""

    @staticmethod
    def for_math_problems(verifier):
        """Feedback for math problems, based on answer verification.

        Args:
            verifier: Object exposing ``verify(prompt, answer)``.

        Returns:
            A feedback model scoring 1.0 for a verified answer, else 0.0.
        """
        def evaluate(prompt, response):
            extracted = extract_answer(response)
            is_correct = verifier.verify(prompt, extracted)
            return FeedbackResult(
                score=1.0 if is_correct else 0.0,
                is_good_enough=is_correct,
                target=extracted,
            )
        return _FeedbackFunction(evaluate)

    @staticmethod
    def for_code_generation(executor):
        """Feedback for code generation, based on the test pass rate.

        Args:
            executor: Object exposing ``run_tests(code)``; its result must
                provide ``pass_rate`` and ``expected_output``.

        Returns:
            A feedback model scoring the pass rate; good enough above 0.8.
        """
        def evaluate(prompt, response):
            code = extract_code(response)
            result = executor.run_tests(code)
            return FeedbackResult(
                score=result.pass_rate,
                is_good_enough=result.pass_rate > 0.8,
                target=result.expected_output,
            )
        return _FeedbackFunction(evaluate)

    @staticmethod
    def for_open_ended(judge_model):
        """Feedback for open-ended tasks, delegated to a judge model.

        Args:
            judge_model: Object exposing ``judge(prompt, response)``; its
                result must provide ``score`` and ``feedback``.

        Returns:
            A feedback model scoring the judge's score; good enough above 0.8.
        """
        def evaluate(prompt, response):
            judgment = judge_model.judge(prompt, response)
            return FeedbackResult(
                score=judgment.score,
                is_good_enough=judgment.score > 0.8,
                target=judgment.feedback,
            )
        return _FeedbackFunction(evaluate)


## 8. 代码示例
8.1 完整使用示例
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# 1. Load the model and its tokenizer.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


class HFTextGenerator:
    """Text-in/text-out adapter around a Hugging Face causal LM.

    FTTT calls ``model.generate(prompt)`` with a string and expects a string
    back, whereas a Hugging Face model's ``generate`` works on token ids.
    This adapter tokenizes the prompt, generates, and decodes only the newly
    generated tokens. It also forwards the parameter accessors.
    """

    def __init__(self, model, tokenizer, max_new_tokens=512):
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens

    def generate(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs, max_new_tokens=self.max_new_tokens
            )
        # Drop the echoed prompt tokens and keep only the continuation.
        new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    def named_parameters(self):
        return self.model.named_parameters()

    def parameters(self):
        return self.model.parameters()


# 2. Create the feedback model (math problems as the example).
feedback_model = SimpleScoringFeedbackModel()

# 3. Create the FTTT wrapper.
fttt = FTTT(
    model=HFTextGenerator(model, tokenizer),
    feedback_model=feedback_model,
    max_iterations=5,
    early_stop_threshold=0.95,
)

# 4. Run inference.
test_prompts = [
    "求 x^2 - 5x + 6 = 0 的解",
    "一个三角形,边长分别为3, 4, 5,求其面积",
]
results = fttt.batch_infer(test_prompts, parallel=False)
for prompt, result in zip(test_prompts, results):
    print(f"Prompt: {prompt}")
    print(f"Result: {result}")
    print("-" * 50)

### 8.2 自定义反馈模型
class CustomFeedbackModel:
    """Example feedback model that blends rule checks with a reward model."""

    def __init__(self, reward_model=None, rules=None):
        """Store the optional reward model and the rule list.

        Args:
            reward_model: Object exposing ``get_score(prompt, response)``.
            rules: Iterable of rule objects exposing ``check(response)`` and
                ``penalty``; treated as empty when omitted.
        """
        self.reward_model = reward_model
        self.rules = rules or []

    def evaluate(self, prompt, response):
        """Score ``response`` and attach a textual critique.

        The final score is ``0.3 * rule_score + 0.7 * reward_score``, where
        the reward score defaults to 0.5 when no reward model is set. The
        response counts as good enough when the final score is strictly
        above 0.8.

        NOTE(review): the critique is stored in ``target``; code elsewhere in
        this document reads ``feedback_text`` from feedback results --
        confirm which field ``FeedbackResult`` is meant to carry it in.
        """
        rule_score = self._apply_rules(response)

        if self.reward_model is not None:
            reward_score = self.reward_model.get_score(prompt, response)
        else:
            reward_score = 0.5

        final_score = 0.3 * rule_score + 0.7 * reward_score

        feedback_text = self._generate_feedback(
            prompt, response, rule_score, reward_score
        )
        return FeedbackResult(
            score=final_score,
            is_good_enough=final_score > 0.8,
            target=feedback_text,
        )

    def _apply_rules(self, response):
        """Start from 1.0 and multiply in the penalty of every failed rule."""
        score = 1.0
        for rule in self.rules:
            if not rule.check(response):
                score *= rule.penalty
        return score

    def _generate_feedback(self, prompt, response, rule_score, reward_score):
        """Pick a critique message from the two partial scores.

        A low rule score takes precedence over a low reward score.
        """
        if rule_score < 0.5:
            return "请检查答案的格式和完整性"
        if reward_score < 0.5:
            return "答案正确但表达不够清晰,请更详细地解释推理过程"
        return "回答质量良好"


## 参考
相关阅读
Footnotes

[1] 本文档基于FTTT(Feedback Test-Time Training)论文整理。相关论文发表在ICLR/NeurIPS 2025。