Seeing Through Circuits:Vision Transformer的忠实机械可解释性

1. 问题背景

1.1 机械可解释性的目标

机械可解释性旨在通过识别神经网络的电路(computational circuits)来理解其内部决策过程:

  • 边级电路:定义组件间连接(边)
  • 节点级电路:定义神经元激活

1.2 现有方法的局限性

在NLP领域的LLM中,边级电路已被成功识别。但在计算机视觉领域,现有方法仅考虑神经元级电路

领域电路类型局限性
NLP/LLM边级电路已成功实现
CV/ViT神经元级电路仅告知”什么信息被编码”,但不能解释”信息如何通过复杂连接路由”

1.3 核心挑战

视觉任务的特点

  1. 空间结构:图像具有2D空间结构
  2. 多尺度特征:边缘、纹理、对象部分
  3. 注意力复杂:多头自注意力的交互模式
  4. 类别多样性:细粒度分类需要区分微妙差异

2. Vi-CD方法详解

2.1 核心思想

**自动视觉电路发现(Vi-CD)**方法扩展了NLP领域的边级电路发现技术:

  1. 边级电路识别:不仅识别激活的神经元,还识别它们之间的连接
  2. 类别特定分析:为每个类别发现专门的电路
  3. 行为纠正:通过操控电路实现对有害行为的纠正

2.2 方法框架

┌──────────────────────────────────────────────────────────────┐
│                      Vi-CD Framework                          │
├──────────────────────────────────────────────────────────────┤
│                                                               │
│   Vision Transformer (ViT)                                    │
│   ┌────────────────────────────────────────────────────┐    │
│   │  Patch Embed → Transformer Layers → Classification │    │
│   └────────────────────────────────────────────────────┘    │
│                         │                                    │
│                         ▼                                    │
│   ┌────────────────────────────────────────────────────┐    │
│   │         Edge-Level Circuit Discovery                │    │
│   │                                                     │    │
│   │   1. Activation Analysis                            │    │
│   │   2. Patch Attribution                              │    │
│   │   3. Edge Importance Scoring                       │    │
│   │   4. Circuit Extraction                            │    │
│   └────────────────────────────────────────────────────┘    │
│                         │                                    │
│                         ▼                                    │
│   ┌────────────────────────────────────────────────────┐    │
│   │              Visual Circuits                        │    │
│   │                                                     │    │
│   │   • Class-specific circuits                        │    │
│   │   • Typographic attack detection                  │    │
│   │   • Harmful behavior correction                   │    │
│   └────────────────────────────────────────────────────┘    │
└──────────────────────────────────────────────────────────────┘

3. 技术细节

3.1 边级重要性评分

3.1.1 激活 patching

对于边 (从组件 到组件 的连接):

3.1.2 Patch级归因

Patch归因(Patch Attribution)识别对分类决策重要的图像区域:

def patch_attribution(model, image, target_class):
    """
    Compute patch importance for target class.
    
    Returns:
        attribution_scores: Importance score per patch
    """
    # Get model predictions
    logits = model(image)
    target_logit = logits[0, target_class]
    
    # Compute gradient w.r.t. patches
    target_logit.backward()
    
    # Aggregate gradient magnitudes
    grad = model.patch_embed.weight.grad
    attribution = grad.abs().mean(dim=[1, 2, 3])
    
    return attribution

3.2 电路发现算法

3.2.1 边级注意力 patching

def edge_level_patch_intervention(model, image, layer_idx, edge):
    """
    Perform edge-level patching intervention.
    
    Args:
        model: Vision Transformer
        image: Input image
        layer_idx: Layer to intervene
        edge: Tuple (from_idx, to_idx) specifying the edge
    
    from_idx, to_idx = edge
    
    # Get original activations
    with torch.no_grad():
        original_activations = get_activations(model, image, layer_idx)
    
    # Create intervention: zero out specific edge
    def intervention_hook(module, input, output):
        modified_output = output.clone()
        # Modify the edge connection
        modified_output[:, to_idx] = 0
        return modified_output
    
    # Register hook
    hook = register_hook(model, layer_idx, intervention_hook)
    
    # Forward pass with intervention
    with_intervention = model(image)
    
    # Remove hook
    remove_hook(hook)
    
    return with_intervention
 
 
def compute_edge_importance(model, image, target_class):
    """
    Compute importance score for each edge.
    """
    edge_importances = {}
    
    # For each layer
    for layer_idx in range(model.n_layers):
        # For each edge in the layer
        n_edges = get_num_edges(model, layer_idx)
        
        for edge_idx in range(n_edges):
            # Perform patching
            logits_with = edge_level_patch_intervention(
                model, image, layer_idx, edge_idx
            )
            
            # Compute importance as difference
            original_score = logits_with[0, target_class].item()
            
            # Patch edge to zero
            logits_without = patch_edge_to_zero(...)
            
            importance = original_score - logits_without[0, target_class].item()
            
            edge_importances[(layer_idx, edge_idx)] = importance
    
    return edge_importances

3.3 类别特定电路

3.3.1 类别分离电路

不同类别的视觉概念需要不同的电路处理:

类别 "猫" 的电路:
┌─────────────────────────────────────────┐
│  边缘检测 → 纹理分析 → 形状识别 → 分类  │
│  [Gabor filters] → [HOG] → [Shape]     │
└─────────────────────────────────────────┘

类别 "狗" 的电路:
┌─────────────────────────────────────────┐
│  边缘检测 → 纹理分析 → 毛发检测 → 分类  │
│  [Gabor filters] → [Texture] → [Fur]   │
└─────────────────────────────────────────┘

3.3.2 电路可视化

def visualize_circuit(model, image, target_class, top_k_edges=50):
    """
    Visualize the top-k edges in a circuit.
    """
    # Compute edge importances
    edge_importances = compute_edge_importance(
        model, image, target_class
    )
    
    # Select top-k edges
    top_edges = sorted(
        edge_importances.items(),
        key=lambda x: abs(x[1]),
        reverse=True
    )[:top_k_edges]
    
    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    
    # Show original image
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    # Show circuit overlay
    ax = axes[1]
    ax.imshow(image)
    
    # Draw edges
    for (layer_idx, edge_idx), importance in top_edges:
        # Get spatial locations
        from_layer, from_patch = decode_edge_location(layer_idx, edge_idx)
        to_layer, to_patch = decode_edge_location(layer_idx + 1, edge_idx)
        
        # Draw edge with alpha based on importance
        alpha = min(abs(importance) / max_importance, 1.0)
        color = 'green' if importance > 0 else 'red'
        
        # Get patch coordinates
        x1, y1 = patch_to_coords(from_patch)
        x2, y2 = patch_to_coords(to_patch)
        
        ax.arrow(x1, y1, x2-x1, y2-y1, 
                 alpha=alpha, color=color, 
                 head_width=5, head_length=3)
    
    ax.set_title(f'Circuit for Class {target_class}')
    ax.axis('off')
    
    return fig

4. 应用场景

4.1 类别特定电路发现

4.1.1 细粒度分类

对于细粒度分类任务(如鸟类识别),Vi-CD能够发现:

  • 细微纹理差异:羽毛图案的检测路径
  • 形状变体:不同鸟喙形状的识别电路
  • 颜色线索:特定颜色模式的检测

4.1.2 场景理解

对于场景分类,Vi-CD揭示了:

  • 对象-背景分离:前景对象与背景的区分电路
  • 空间关系:对象间空间关系的编码方式
  • 纹理-形状平衡:不同场景类型依赖不同的特征组合

4.2 对抗攻击检测

4.2.1 排版攻击(Typographic Attacks)

Vi-CD能够识别排版攻击的电路:

排版攻击示例:
[正常图像] → 电路激活正常特征
[添加文字] → 电路错误激活文本检测模块

Vi-CD发现的关键边:
• 文本区域 → CLIP文本编码器
• 文本激活 → 错误分类输出

4.2.2 电路级防御

def circuit_level_defense(model, image):
    """
    Use circuit analysis for adversarial defense.
    """
    # Analyze circuit for text-like activation
    text_edges = find_text_activation_edges(model, image)
    
    if len(text_edges) > threshold:
        # Suppress text-related edges
        for edge in text_edges:
            suppress_edge(model, edge)
        
        # Re-classify with suppressed circuit
        return model(image)
    else:
        return model(image)

4.3 有害行为纠正

4.3.1 偏见识别

Vi-CD可以识别导致偏见的电路:

  • 种族偏见:某些族裔面孔过度激活特定模式
  • 性别偏见:性别刻板印象的视觉线索
  • 文化偏见:特定文化符号的错误关联

4.3.2 电路级干预

def circuit_level_intervention(model, image, harmful_circuit):
    """
    Correct harmful behavior by intervening on circuits.
    """
    # Get activations through harmful circuit
    harmful_activations = trace_circuit(model, image, harmful_circuit)
    
    # Compute correction signal
    correction = -harmful_activations * intervention_strength
    
    # Apply correction at circuit boundaries
    intervention_hook = lambda module, input, output: output + correction
    
    # Register intervention
    register_permanent_hook(model, harmful_circuit.target_layer, 
                           intervention_hook)
    
    # Verify correction
    corrected_output = model(image)
    
    return corrected_output

5. 实验结果

5.1 电路质量评估

任务基准方法Vi-CD提升
类别分类神经元级边级+12.3%
OOD检测全部组件核心边-45%组件

5.2 对抗鲁棒性

攻击类型无防御神经元防御边级防御
FGSM42.1%65.3%78.9%
PGD28.7%51.2%69.4%
排版攻击35.6%62.1%81.2%

5.3 可解释性评估

指标神经元级边级
因果清晰度0.450.82
干预有效性0.520.88
人类可理解性0.610.76

6. 与NLP电路发现的对比

6.1 方法迁移

方面NLP电路CV电路
基本电路发现Activation patchingActivation patching
重要性度量Logit differenceAttention contribution
边粒度Token-to-tokenPatch-to-patch
可视化Attention patternsSpatial overlays

6.2 视觉特有的挑战

  1. 空间连续性:图像Patch间的关系更复杂
  2. 多尺度特征:需要处理不同尺度的模式
  3. 平移不变性:期望对平移具有鲁棒性

7. 总结与展望

7.1 核心贡献

  1. 边级电路:首次在Vision Transformer中实现边级电路发现
  2. Vi-CD方法:自动视觉电路发现的系统方法
  3. 应用扩展:从分类到对抗检测到行为纠正

7.2 局限性

  • 计算成本较高
  • 对复杂场景的处理能力有限
  • 需要人工验证电路语义

7.3 未来方向

  • 多模态电路发现
  • 视频理解中的时空电路
  • 实时电路分析

参考资料