抽取式问答电路分析

1. 概述

抽取式问答(Extractive Question Answering)是NLP中的核心任务,需要模型从给定上下文中提取答案。论文《On Mechanistic Circuits for Extractive Question-Answering》首次系统性地分析了这一任务的电路机制。

2. 任务定义

2.1 形式化

给定:

  • 问题
  • 上下文
  • 答案边界 其中

目标:预测答案范围 使得答案 的正确答案。

2.2 模型架构

class ExtractiveQA:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
    
    def forward(self, question, context):
        """
        前向传播
        """
        # 拼接问题和上下文
        tokens = self.tokenizer(question, context, return_tensors='pt')
        
        # 获取隐藏状态
        outputs = self.model(**tokens)
        hidden_states = outputs.last_hidden_state
        
        # 预测开始和结束位置
        start_logits = self.start_head(hidden_states)
        end_logits = self.end_head(hidden_states)
        
        return start_logits, end_logits

3. 电路组件识别

3.1 关键组件

通过激活修补识别出以下关键组件:

组件类型层/头位置功能
问题编码器(1-3, *)理解问题语义
上下文选择器(4-5, 2)定位相关上下文
答案边界检测器(6-7, 5)识别答案边界
复制机制(8, 3)从上下文复制答案

3.2 组件功能详解

# 识别的电路组件
CIRCUIT_COMPONENTS = {
    # 问题理解组件
    'question_encoder': {
        'layers': list(range(1, 4)),
        'type': 'mlp_attention',
        'function': 'question_semantic_parsing'
    },
    
    # 上下文选择组件
    'context_selector': {
        'layer': 4,
        'head': 2,
        'type': 'attention_head',
        'function': 'query_context_matching'
    },
    
    # 边界检测组件
    'boundary_detector': {
        'layers': [6, 7],
        'head': 5,
        'type': 'attention_head',
        'function': 'span_boundary_detection'
    },
    
    # 复制组件
    'copy_mechanism': {
        'layer': 8,
        'head': 3,
        'type': 'attention_head',
        'function': 'token_copying'
    }
}

4. 复制机制分析

4.1 复制头的发现

复制机制(Copy Mechanism)是问答电路的核心,负责将答案从上下文复制到输出。

class CopyMechanism:
    def __init__(self, model):
        self.model = model
    
    def analyze_copy_head(self, layer=8, head=3):
        """
        分析复制头的工作机制
        """
        # 获取注意力权重
        attention_pattern = self.get_attention(layer, head)
        
        # 分析注意力分布
        # 复制头应该高度关注上下文中的答案位置
        answer_attention = self.compute_answer_attention(attention_pattern)
        
        return {
            'mean_attention_to_answer': answer_attention.mean(),
            'max_attention_to_answer': answer_attention.max(),
            'attention_concentration': self.compute_concentration(answer_attention)
        }

4.2 复制条件

复制机制的条件:

P(\text{copy } c_i | Q, C) = \text{softmax}(W \cdot h_i + b) \cdot I(\text{is_in_context}(c_i))
def copy_probability(context_token, question_representation):
    """
    计算复制概率
    """
    # 计算上下文token与问题的相关性
    relevance = context_token @ question_representation.T
    
    # 计算复制分数
    copy_score = self.copy_head.weight @ context_token + self.copy_head.bias
    
    # 组合分数
    combined_score = alpha * relevance + beta * copy_score
    
    # 归一化
    probs = F.softmax(combined_score)
    
    return probs

4.3 归纳头的作用

在问答任务中,Induction Head 发挥关键作用:

def induction_head_in_qa(layer1=5, layer2=7):
    """
    分析Induction Head在QA中的作用
    """
    # Induction Head 1:定位相同pattern
    pattern_head = self.get_attention(layer1, pattern_head_idx)
    
    # Induction Head 2:复制匹配内容
    copy_head = self.get_attention(layer2, copy_head_idx)
    
    # 分析信息流
    information_flow = self.trace_information_flow(
        from_layer=layer1,
        from_head=pattern_head_idx,
        to_layer=layer2,
        to_head=copy_head_idx
    )
    
    return information_flow

5. 边界检测机制

5.1 边界检测器架构

class BoundaryDetector:
    def __init__(self, model):
        self.model = model
    
    def detect_boundaries(self, hidden_states):
        """
        边界检测
        """
        # 开始边界检测
        start_scores = self.start_head(hidden_states)
        
        # 结束边界检测
        end_scores = self.end_head(hidden_states)
        
        # 约束:结束位置必须大于开始位置
        valid_pairs = self.constrain_boundaries(start_scores, end_scores)
        
        return valid_pairs
    
    def constrain_boundaries(self, start_scores, end_scores):
        """
        边界约束
        """
        batch_size, seq_len = start_scores.shape
        
        # 初始化
        best_starts = torch.zeros(batch_size, dtype=torch.long)
        best_ends = torch.zeros(batch_size, dtype=torch.long)
        best_scores = torch.zeros(batch_size)
        
        for i in range(seq_len):
            for j in range(i, seq_len):
                score = start_scores[:, i] + end_scores[:, j]
                mask = score > best_scores
                best_starts[mask] = i
                best_ends[mask] = j
                best_scores[mask] = score[mask]
        
        return best_starts, best_ends

5.2 边界检测的电路机制

def boundary_detection_circuit(context_repr, question_repr):
    """
    边界检测电路机制
    """
    # 问题感知:识别问题询问的是什么
    question_aware = fuse_question_context(
        context_repr,
        question_repr,
        fusion_type='cross_attention'
    )
    
    # 边界评分:评估每个位置作为边界的可能性
    boundary_scores = {}
    for position in range(seq_len):
        # 开始边界评分
        start_score = score_start_boundary(
            question_aware[position],
            context_repr,
            position
        )
        
        # 结束边界评分
        end_score = score_end_boundary(
            question_aware[position],
            context_repr,
            position
        )
        
        boundary_scores[position] = {
            'start': start_score,
            'end': end_score
        }
    
    return boundary_scores

6. 因果追溯分析

6.1 追溯实验设计

class QACausalTracer:
    def __init__(self, model):
        self.model = model
    
    def full_trace(self, question, context, answer_span):
        """
        完整因果追溯
        """
        # 干净输入
        clean_tokens = self.tokenize(question, context)
        
        # 损坏输入(答案位置置零)
        corrupt_tokens = self.corrupt_answer(clean_tokens, answer_span)
        
        # 逐层追溯
        effects = {}
        for layer in range(self.model.config.n_layers):
            for head in range(self.model.config.n_heads):
                effect = self.measure_effect(
                    clean_tokens,
                    corrupt_tokens,
                    layer,
                    head
                )
                effects[(layer, head)] = effect
        
        return effects
    
    def measure_effect(self, clean, corrupt, layer, head):
        """
        测量特定注意力头的效应
        """
        # 修补该头
        patched_logits = self.patch_head(clean, corrupt, layer, head)
        
        # 计算答案预测概率变化
        clean_prob = F.softmax(clean_logits)[..., answer_position]
        patched_prob = F.softmax(patched_logits)[..., answer_position]
        
        return clean_prob - patched_prob

6.2 追溯结果

因果追溯显示:

层     头    因果效应    功能
────────────────────────────────────
1      *     0.12       问题编码
4      2     0.45       上下文选择 ★★★
5      7     0.38       模式匹配
6      5     0.52       边界检测 ★★★
7      5     0.48       边界检测 ★★
8      3     0.61       答案复制 ★★★

7. 电路验证

7.1 完整性验证

def verify_comprehensiveness(circuit, test_cases):
    """
    验证电路完整性
    """
    # 完整模型性能
    full_performance = evaluate_model(circuit.model, test_cases)
    
    # 仅使用电路组件
    circuit_performance = evaluate_circuit(circuit, test_cases)
    
    # 计算完整性分数
    comprehensiveness = circuit_performance / full_performance
    
    return {
        'full_performance': full_performance,
        'circuit_performance': circuit_performance,
        'comprehensiveness': comprehensiveness
    }

7.2 充分性验证

def verify_sufficiency(circuit, test_cases):
    """
    验证电路充分性
    """
    # 仅使用电路组件
    circuit_outputs = [circuit.compute(case.input) for case in test_cases]
    
    # 与完整模型对比
    full_outputs = [circuit.model(case.input) for case in test_cases]
    
    # 计算一致率
    agreement = compute_agreement(circuit_outputs, full_outputs)
    
    return {
        'circuit_outputs': circuit_outputs,
        'full_outputs': full_outputs,
        'agreement_rate': agreement
    }

8. 错误分析

8.1 错误类型分布

错误类型比例主要原因
边界错误35%边界检测器不准确
复制失败28%复制头注意力分散
上下文混淆22%选错相关上下文
问题误解15%问题编码器理解错误

8.2 错误案例分析

def error_case_analysis(circuit, error_cases):
    """
    错误案例深入分析
    """
    for case in error_cases:
        # 识别失败原因
        failure_reasons = []
        
        # 检查问题编码
        if circuit.question_encoder.is_weak(case.question):
            failure_reasons.append('weak_question_encoding')
        
        # 检查上下文选择
        if circuit.context_selector.is_wrong(case.context):
            failure_reasons.append('wrong_context_selection')
        
        # 检查边界检测
        if circuit.boundary_detector.is_inaccurate(case.answer):
            failure_reasons.append('inaccurate_boundary')
        
        # 检查复制机制
        if circuit.copy_mechanism.is_failing(case.answer):
            failure_reasons.append('copy_failure')
        
        print(f"Case: {case.id}, Reasons: {failure_reasons}")

9. 与其他任务的电路对比

9.1 与 Induction Head 的对比

维度抽取式问答电路Induction Head
目标提取答案复制匹配内容
位置匹配语义匹配精确匹配
边界约束显式边界检测隐式边界

9.2 与命名实体识别的对比

问答电路与NER电路共享以下组件:

  • 边界检测机制
  • 上下文编码器

差异在于:

  • 问题理解层(问答特有)
  • 复制机制(问答特有)

10. 总结

10.1 主要发现

  1. 问答电路由多个专用组件组成
  2. 复制机制是答案生成的核心
  3. 边界检测与上下文选择协同工作
  4. 存在层次化的信息处理流程

10.2 应用价值

  • 模型调试:定位错误来源
  • 模型改进:针对性优化特定组件
  • 知识迁移:将问答电路迁移到其他任务

参考资料