抽取式问答电路分析
1. 概述
抽取式问答(Extractive Question Answering)是NLP中的核心任务,需要模型从给定上下文中提取答案。论文《On Mechanistic Circuits for Extractive Question-Answering》首次系统性地分析了这一任务的电路机制。
2. 任务定义
2.1 形式化
给定:
- 问题
- 上下文
- 答案边界 其中
目标:预测答案范围 使得答案 是 的正确答案。
2.2 模型架构
class ExtractiveQA:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def forward(self, question, context):
"""
前向传播
"""
# 拼接问题和上下文
tokens = self.tokenizer(question, context, return_tensors='pt')
# 获取隐藏状态
outputs = self.model(**tokens)
hidden_states = outputs.last_hidden_state
# 预测开始和结束位置
start_logits = self.start_head(hidden_states)
end_logits = self.end_head(hidden_states)
return start_logits, end_logits3. 电路组件识别
3.1 关键组件
通过激活修补识别出以下关键组件:
| 组件类型 | 层/头位置 | 功能 |
|---|---|---|
| 问题编码器 | (1-3, *) | 理解问题语义 |
| 上下文选择器 | (4-5, 2) | 定位相关上下文 |
| 答案边界检测器 | (6-7, 5) | 识别答案边界 |
| 复制机制 | (8, 3) | 从上下文复制答案 |
3.2 组件功能详解
# 识别的电路组件
CIRCUIT_COMPONENTS = {
# 问题理解组件
'question_encoder': {
'layers': list(range(1, 4)),
'type': 'mlp_attention',
'function': 'question_semantic_parsing'
},
# 上下文选择组件
'context_selector': {
'layer': 4,
'head': 2,
'type': 'attention_head',
'function': 'query_context_matching'
},
# 边界检测组件
'boundary_detector': {
'layers': [6, 7],
'head': 5,
'type': 'attention_head',
'function': 'span_boundary_detection'
},
# 复制组件
'copy_mechanism': {
'layer': 8,
'head': 3,
'type': 'attention_head',
'function': 'token_copying'
}
}4. 复制机制分析
4.1 复制头的发现
复制机制(Copy Mechanism)是问答电路的核心,负责将答案从上下文复制到输出。
class CopyMechanism:
def __init__(self, model):
self.model = model
def analyze_copy_head(self, layer=8, head=3):
"""
分析复制头的工作机制
"""
# 获取注意力权重
attention_pattern = self.get_attention(layer, head)
# 分析注意力分布
# 复制头应该高度关注上下文中的答案位置
answer_attention = self.compute_answer_attention(attention_pattern)
return {
'mean_attention_to_answer': answer_attention.mean(),
'max_attention_to_answer': answer_attention.max(),
'attention_concentration': self.compute_concentration(answer_attention)
}4.2 复制条件
复制机制的条件:
P(\text{copy } c_i | Q, C) = \text{softmax}(W \cdot h_i + b) \cdot I(\text{is_in_context}(c_i))def copy_probability(context_token, question_representation):
"""
计算复制概率
"""
# 计算上下文token与问题的相关性
relevance = context_token @ question_representation.T
# 计算复制分数
copy_score = self.copy_head.weight @ context_token + self.copy_head.bias
# 组合分数
combined_score = alpha * relevance + beta * copy_score
# 归一化
probs = F.softmax(combined_score)
return probs4.3 归纳头的作用
在问答任务中,Induction Head 发挥关键作用:
def induction_head_in_qa(layer1=5, layer2=7):
"""
分析Induction Head在QA中的作用
"""
# Induction Head 1:定位相同pattern
pattern_head = self.get_attention(layer1, pattern_head_idx)
# Induction Head 2:复制匹配内容
copy_head = self.get_attention(layer2, copy_head_idx)
# 分析信息流
information_flow = self.trace_information_flow(
from_layer=layer1,
from_head=pattern_head_idx,
to_layer=layer2,
to_head=copy_head_idx
)
return information_flow5. 边界检测机制
5.1 边界检测器架构
class BoundaryDetector:
def __init__(self, model):
self.model = model
def detect_boundaries(self, hidden_states):
"""
边界检测
"""
# 开始边界检测
start_scores = self.start_head(hidden_states)
# 结束边界检测
end_scores = self.end_head(hidden_states)
# 约束:结束位置必须大于开始位置
valid_pairs = self.constrain_boundaries(start_scores, end_scores)
return valid_pairs
def constrain_boundaries(self, start_scores, end_scores):
"""
边界约束
"""
batch_size, seq_len = start_scores.shape
# 初始化
best_starts = torch.zeros(batch_size, dtype=torch.long)
best_ends = torch.zeros(batch_size, dtype=torch.long)
best_scores = torch.zeros(batch_size)
for i in range(seq_len):
for j in range(i, seq_len):
score = start_scores[:, i] + end_scores[:, j]
mask = score > best_scores
best_starts[mask] = i
best_ends[mask] = j
best_scores[mask] = score[mask]
return best_starts, best_ends5.2 边界检测的电路机制
def boundary_detection_circuit(context_repr, question_repr):
"""
边界检测电路机制
"""
# 问题感知:识别问题询问的是什么
question_aware = fuse_question_context(
context_repr,
question_repr,
fusion_type='cross_attention'
)
# 边界评分:评估每个位置作为边界的可能性
boundary_scores = {}
for position in range(seq_len):
# 开始边界评分
start_score = score_start_boundary(
question_aware[position],
context_repr,
position
)
# 结束边界评分
end_score = score_end_boundary(
question_aware[position],
context_repr,
position
)
boundary_scores[position] = {
'start': start_score,
'end': end_score
}
return boundary_scores6. 因果追溯分析
6.1 追溯实验设计
class QACausalTracer:
def __init__(self, model):
self.model = model
def full_trace(self, question, context, answer_span):
"""
完整因果追溯
"""
# 干净输入
clean_tokens = self.tokenize(question, context)
# 损坏输入(答案位置置零)
corrupt_tokens = self.corrupt_answer(clean_tokens, answer_span)
# 逐层追溯
effects = {}
for layer in range(self.model.config.n_layers):
for head in range(self.model.config.n_heads):
effect = self.measure_effect(
clean_tokens,
corrupt_tokens,
layer,
head
)
effects[(layer, head)] = effect
return effects
def measure_effect(self, clean, corrupt, layer, head):
"""
测量特定注意力头的效应
"""
# 修补该头
patched_logits = self.patch_head(clean, corrupt, layer, head)
# 计算答案预测概率变化
clean_prob = F.softmax(clean_logits)[..., answer_position]
patched_prob = F.softmax(patched_logits)[..., answer_position]
return clean_prob - patched_prob6.2 追溯结果
因果追溯显示:
层 头 因果效应 功能
────────────────────────────────────
1 * 0.12 问题编码
4 2 0.45 上下文选择 ★★★
5 7 0.38 模式匹配
6 5 0.52 边界检测 ★★★
7 5 0.48 边界检测 ★★
8 3 0.61 答案复制 ★★★
7. 电路验证
7.1 完整性验证
def verify_comprehensiveness(circuit, test_cases):
"""
验证电路完整性
"""
# 完整模型性能
full_performance = evaluate_model(circuit.model, test_cases)
# 仅使用电路组件
circuit_performance = evaluate_circuit(circuit, test_cases)
# 计算完整性分数
comprehensiveness = circuit_performance / full_performance
return {
'full_performance': full_performance,
'circuit_performance': circuit_performance,
'comprehensiveness': comprehensiveness
}7.2 充分性验证
def verify_sufficiency(circuit, test_cases):
"""
验证电路充分性
"""
# 仅使用电路组件
circuit_outputs = [circuit.compute(case.input) for case in test_cases]
# 与完整模型对比
full_outputs = [circuit.model(case.input) for case in test_cases]
# 计算一致率
agreement = compute_agreement(circuit_outputs, full_outputs)
return {
'circuit_outputs': circuit_outputs,
'full_outputs': full_outputs,
'agreement_rate': agreement
}8. 错误分析
8.1 错误类型分布
| 错误类型 | 比例 | 主要原因 |
|---|---|---|
| 边界错误 | 35% | 边界检测器不准确 |
| 复制失败 | 28% | 复制头注意力分散 |
| 上下文混淆 | 22% | 选错相关上下文 |
| 问题误解 | 15% | 问题编码器理解错误 |
8.2 错误案例分析
def error_case_analysis(circuit, error_cases):
"""
错误案例深入分析
"""
for case in error_cases:
# 识别失败原因
failure_reasons = []
# 检查问题编码
if circuit.question_encoder.is_weak(case.question):
failure_reasons.append('weak_question_encoding')
# 检查上下文选择
if circuit.context_selector.is_wrong(case.context):
failure_reasons.append('wrong_context_selection')
# 检查边界检测
if circuit.boundary_detector.is_inaccurate(case.answer):
failure_reasons.append('inaccurate_boundary')
# 检查复制机制
if circuit.copy_mechanism.is_failing(case.answer):
failure_reasons.append('copy_failure')
print(f"Case: {case.id}, Reasons: {failure_reasons}")9. 与其他任务的电路对比
9.1 与 Induction Head 的对比
| 维度 | 抽取式问答电路 | Induction Head |
|---|---|---|
| 目标 | 提取答案 | 复制匹配内容 |
| 位置匹配 | 语义匹配 | 精确匹配 |
| 边界约束 | 显式边界检测 | 隐式边界 |
9.2 与命名实体识别的对比
问答电路与NER电路共享以下组件:
- 边界检测机制
- 上下文编码器
差异在于:
- 问题理解层(问答特有)
- 复制机制(问答特有)
10. 总结
10.1 主要发现
- 问答电路由多个专用组件组成
- 复制机制是答案生成的核心
- 边界检测与上下文选择协同工作
- 存在层次化的信息处理流程
10.2 应用价值
- 模型调试:定位错误来源
- 模型改进:针对性优化特定组件
- 知识迁移:将问答电路迁移到其他任务