Seeing Through Circuits:Vision Transformer的忠实机械可解释性
1. 问题背景
1.1 机械可解释性的目标
机械可解释性旨在通过识别神经网络的电路(computational circuits)来理解其内部决策过程:
- 边级电路:定义组件间连接(边)
- 节点级电路:定义神经元激活
1.2 现有方法的局限性
在NLP领域的LLM中,边级电路已被成功识别。但在计算机视觉领域,现有方法仅考虑神经元级电路:
| 领域 | 电路类型 | 局限性 |
|---|---|---|
| NLP/LLM | 边级电路 | 已成功实现 |
| CV/ViT | 神经元级电路 | 仅告知”什么信息被编码”,但不能解释”信息如何通过复杂连接路由” |
1.3 核心挑战
视觉任务的特点:
- 空间结构:图像具有2D空间结构
- 多尺度特征:边缘、纹理、对象部分
- 注意力复杂:多头自注意力的交互模式
- 类别多样性:细粒度分类需要区分微妙差异
2. Vi-CD方法详解
2.1 核心思想
**自动视觉电路发现(Vi-CD)**方法扩展了NLP领域的边级电路发现技术:
- 边级电路识别:不仅识别激活的神经元,还识别它们之间的连接
- 类别特定分析:为每个类别发现专门的电路
- 行为纠正:通过操控电路实现对有害行为的纠正
2.2 方法框架
┌──────────────────────────────────────────────────────────────┐
│ Vi-CD Framework │
├──────────────────────────────────────────────────────────────┤
│ │
│ Vision Transformer (ViT) │
│ ┌────────────────────────────────────────────────────┐ │
│ │ Patch Embed → Transformer Layers → Classification │ │
│ └────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌────────────────────────────────────────────────────┐ │
│ │ Edge-Level Circuit Discovery │ │
│ │ │ │
│ │ 1. Activation Analysis │ │
│ │ 2. Patch Attribution │ │
│ │ 3. Edge Importance Scoring │ │
│ │ 4. Circuit Extraction │ │
│ └────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌────────────────────────────────────────────────────┐ │
│ │ Visual Circuits │ │
│ │ │ │
│ │ • Class-specific circuits │ │
│ │ • Typographic attack detection │ │
│ │ • Harmful behavior correction │ │
│ └────────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────────────┘
3. 技术细节
3.1 边级重要性评分
3.1.1 激活 patching
对于边 (从组件 到组件 的连接):
3.1.2 Patch级归因
Patch归因(Patch Attribution)识别对分类决策重要的图像区域:
def patch_attribution(model, image, target_class):
"""
Compute patch importance for target class.
Returns:
attribution_scores: Importance score per patch
"""
# Get model predictions
logits = model(image)
target_logit = logits[0, target_class]
# Compute gradient w.r.t. patches
target_logit.backward()
# Aggregate gradient magnitudes
grad = model.patch_embed.weight.grad
attribution = grad.abs().mean(dim=[1, 2, 3])
return attribution3.2 电路发现算法
3.2.1 边级注意力 patching
def edge_level_patch_intervention(model, image, layer_idx, edge):
"""
Perform edge-level patching intervention.
Args:
model: Vision Transformer
image: Input image
layer_idx: Layer to intervene
edge: Tuple (from_idx, to_idx) specifying the edge
from_idx, to_idx = edge
# Get original activations
with torch.no_grad():
original_activations = get_activations(model, image, layer_idx)
# Create intervention: zero out specific edge
def intervention_hook(module, input, output):
modified_output = output.clone()
# Modify the edge connection
modified_output[:, to_idx] = 0
return modified_output
# Register hook
hook = register_hook(model, layer_idx, intervention_hook)
# Forward pass with intervention
with_intervention = model(image)
# Remove hook
remove_hook(hook)
return with_intervention
def compute_edge_importance(model, image, target_class):
"""
Compute importance score for each edge.
"""
edge_importances = {}
# For each layer
for layer_idx in range(model.n_layers):
# For each edge in the layer
n_edges = get_num_edges(model, layer_idx)
for edge_idx in range(n_edges):
# Perform patching
logits_with = edge_level_patch_intervention(
model, image, layer_idx, edge_idx
)
# Compute importance as difference
original_score = logits_with[0, target_class].item()
# Patch edge to zero
logits_without = patch_edge_to_zero(...)
importance = original_score - logits_without[0, target_class].item()
edge_importances[(layer_idx, edge_idx)] = importance
return edge_importances3.3 类别特定电路
3.3.1 类别分离电路
不同类别的视觉概念需要不同的电路处理:
类别 "猫" 的电路:
┌─────────────────────────────────────────┐
│ 边缘检测 → 纹理分析 → 形状识别 → 分类 │
│ [Gabor filters] → [HOG] → [Shape] │
└─────────────────────────────────────────┘
类别 "狗" 的电路:
┌─────────────────────────────────────────┐
│ 边缘检测 → 纹理分析 → 毛发检测 → 分类 │
│ [Gabor filters] → [Texture] → [Fur] │
└─────────────────────────────────────────┘
3.3.2 电路可视化
def visualize_circuit(model, image, target_class, top_k_edges=50):
"""
Visualize the top-k edges in a circuit.
"""
# Compute edge importances
edge_importances = compute_edge_importance(
model, image, target_class
)
# Select top-k edges
top_edges = sorted(
edge_importances.items(),
key=lambda x: abs(x[1]),
reverse=True
)[:top_k_edges]
# Create visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
# Show original image
axes[0].imshow(image)
axes[0].set_title('Original Image')
axes[0].axis('off')
# Show circuit overlay
ax = axes[1]
ax.imshow(image)
# Draw edges
for (layer_idx, edge_idx), importance in top_edges:
# Get spatial locations
from_layer, from_patch = decode_edge_location(layer_idx, edge_idx)
to_layer, to_patch = decode_edge_location(layer_idx + 1, edge_idx)
# Draw edge with alpha based on importance
alpha = min(abs(importance) / max_importance, 1.0)
color = 'green' if importance > 0 else 'red'
# Get patch coordinates
x1, y1 = patch_to_coords(from_patch)
x2, y2 = patch_to_coords(to_patch)
ax.arrow(x1, y1, x2-x1, y2-y1,
alpha=alpha, color=color,
head_width=5, head_length=3)
ax.set_title(f'Circuit for Class {target_class}')
ax.axis('off')
return fig4. 应用场景
4.1 类别特定电路发现
4.1.1 细粒度分类
对于细粒度分类任务(如鸟类识别),Vi-CD能够发现:
- 细微纹理差异:羽毛图案的检测路径
- 形状变体:不同鸟喙形状的识别电路
- 颜色线索:特定颜色模式的检测
4.1.2 场景理解
对于场景分类,Vi-CD揭示了:
- 对象-背景分离:前景对象与背景的区分电路
- 空间关系:对象间空间关系的编码方式
- 纹理-形状平衡:不同场景类型依赖不同的特征组合
4.2 对抗攻击检测
4.2.1 排版攻击(Typographic Attacks)
Vi-CD能够识别排版攻击的电路:
排版攻击示例:
[正常图像] → 电路激活正常特征
[添加文字] → 电路错误激活文本检测模块
Vi-CD发现的关键边:
• 文本区域 → CLIP文本编码器
• 文本激活 → 错误分类输出
4.2.2 电路级防御
def circuit_level_defense(model, image):
"""
Use circuit analysis for adversarial defense.
"""
# Analyze circuit for text-like activation
text_edges = find_text_activation_edges(model, image)
if len(text_edges) > threshold:
# Suppress text-related edges
for edge in text_edges:
suppress_edge(model, edge)
# Re-classify with suppressed circuit
return model(image)
else:
return model(image)4.3 有害行为纠正
4.3.1 偏见识别
Vi-CD可以识别导致偏见的电路:
- 种族偏见:某些族裔面孔过度激活特定模式
- 性别偏见:性别刻板印象的视觉线索
- 文化偏见:特定文化符号的错误关联
4.3.2 电路级干预
def circuit_level_intervention(model, image, harmful_circuit):
"""
Correct harmful behavior by intervening on circuits.
"""
# Get activations through harmful circuit
harmful_activations = trace_circuit(model, image, harmful_circuit)
# Compute correction signal
correction = -harmful_activations * intervention_strength
# Apply correction at circuit boundaries
intervention_hook = lambda module, input, output: output + correction
# Register intervention
register_permanent_hook(model, harmful_circuit.target_layer,
intervention_hook)
# Verify correction
corrected_output = model(image)
return corrected_output5. 实验结果
5.1 电路质量评估
| 任务 | 基准方法 | Vi-CD | 提升 |
|---|---|---|---|
| 类别分类 | 神经元级 | 边级 | +12.3% |
| OOD检测 | 全部组件 | 核心边 | -45%组件 |
5.2 对抗鲁棒性
| 攻击类型 | 无防御 | 神经元防御 | 边级防御 |
|---|---|---|---|
| FGSM | 42.1% | 65.3% | 78.9% |
| PGD | 28.7% | 51.2% | 69.4% |
| 排版攻击 | 35.6% | 62.1% | 81.2% |
5.3 可解释性评估
| 指标 | 神经元级 | 边级 |
|---|---|---|
| 因果清晰度 | 0.45 | 0.82 |
| 干预有效性 | 0.52 | 0.88 |
| 人类可理解性 | 0.61 | 0.76 |
6. 与NLP电路发现的对比
6.1 方法迁移
| 方面 | NLP电路 | CV电路 |
|---|---|---|
| 基本电路发现 | Activation patching | Activation patching |
| 重要性度量 | Logit difference | Attention contribution |
| 边粒度 | Token-to-token | Patch-to-patch |
| 可视化 | Attention patterns | Spatial overlays |
6.2 视觉特有的挑战
- 空间连续性:图像Patch间的关系更复杂
- 多尺度特征:需要处理不同尺度的模式
- 平移不变性:期望对平移具有鲁棒性
7. 总结与展望
7.1 核心贡献
- 边级电路:首次在Vision Transformer中实现边级电路发现
- Vi-CD方法:自动视觉电路发现的系统方法
- 应用扩展:从分类到对抗检测到行为纠正
7.2 局限性
- 计算成本较高
- 对复杂场景的处理能力有限
- 需要人工验证电路语义
7.3 未来方向
- 多模态电路发现
- 视频理解中的时空电路
- 实时电路分析