Vision Transformer 对抗鲁棒性
概述
Vision Transformer(ViT)已成为计算机视觉的主流架构。与传统 CNN 相比,ViT 在对抗鲁棒性方面表现出独特的特性:既有优势(如对某些攻击更鲁棒),也有独特的脆弱性(如注意力机制易被攻击利用)。[1]
ViT 与 CNN 鲁棒性对比
结构差异分析
| 特性 | CNN | Vision Transformer |
|---|---|---|
| 感受野 | 局部卷积 | 全局注意力 |
| 特征提取 | 层级局部特征 | 全局均匀特征 |
| 位置编码 | 隐式学习 | 显式嵌入 |
| 注意力机制 | 无 | Query-Key-Value |
| 对抗脆弱性 | 纹理依赖 | 语义依赖 |
鲁棒性实验对比
def compare_robustness(cnn_model, vit_model, test_loader, epsilon=8/255):
    """Compare the adversarial robustness of a CNN and a ViT on one loader.

    For every batch the function records clean accuracy, accuracy under each
    model's own PGD examples, and cross-model transfer (each model evaluated
    on the adversarial examples crafted against the other model).

    Args:
        cnn_model: Callable returning per-class scores; predictions are taken
            with ``argmax(1)``.
        vit_model: Same contract as ``cnn_model``.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        epsilon: Perturbation budget forwarded to ``pgd_attack``.

    Returns:
        dict mapping each metric name to its mean over all samples:
        ``cnn_clean`` / ``vit_clean`` and ``cnn_robust`` / ``vit_robust`` are
        accuracies (fraction predicted correctly), whereas ``cnn_transfer`` /
        ``vit_transfer`` are attack success rates (fraction misclassified).
    """
    results = {
        'cnn_clean': [], 'vit_clean': [],
        'cnn_robust': [], 'vit_robust': [],
        'cnn_transfer': [], 'vit_transfer': [],
    }

    for images, labels in test_loader:
        images = images.to(device)
        # Predictions are produced on `device`; the labels they are compared
        # with must live there too, otherwise the comparison fails off-CPU.
        labels = labels.to(device)

        # Clean accuracy.
        with torch.no_grad():
            cnn_clean = cnn_model(images).argmax(1)
            vit_clean = vit_model(images).argmax(1)
        results['cnn_clean'].extend((cnn_clean == labels).cpu().numpy())
        results['vit_clean'].extend((vit_clean == labels).cpu().numpy())

        # One PGD run per source model; the same adversarial batch serves both
        # the self-attack and the transfer measurement.
        adv_cnn = pgd_attack(cnn_model, images, labels, epsilon)
        adv_vit = pgd_attack(vit_model, images, labels, epsilon)

        with torch.no_grad():
            cnn_on_cnn = cnn_model(adv_cnn).argmax(1)
            vit_on_vit = vit_model(adv_vit).argmax(1)
            vit_on_cnn = vit_model(adv_cnn).argmax(1)
            cnn_on_vit = cnn_model(adv_vit).argmax(1)

        # Transfer entries count successful attacks (prediction != label).
        results['vit_transfer'].extend((vit_on_cnn != labels).cpu().numpy())
        results['cnn_transfer'].extend((cnn_on_vit != labels).cpu().numpy())
        # Robust entries count surviving correct predictions.
        results['cnn_robust'].extend((cnn_on_cnn == labels).cpu().numpy())
        results['vit_robust'].extend((vit_on_vit == labels).cpu().numpy())

    return {k: np.mean(v) for k, v in results.items()}


# ViT 的独特脆弱性
注意力层级脆弱性
最新的研究发现,ViT 在不同层表现出不同的脆弱性模式 [1]:
def analyze_attention_vulnerability(vit_model, x, epsilon=8/255):
    """Score how much each block's attention output changes under a PGD attack.

    A forward hook is attached to ``block.attn`` for every block in
    ``vit_model.blocks``; the model is then run on the clean input and on its
    adversarial counterpart, and the two captured outputs of each layer are
    compared.

    Args:
        vit_model: Model exposing ``blocks``, each with an ``attn`` sub-module.
        x: Clean input batch.
        epsilon: Perturbation budget forwarded to ``pgd_attack``.

    Returns:
        list[float]: mean absolute difference between the clean and the
        adversarial ``attn`` output, one entry per hooked layer, in the order
        the hooks fired.
    """
    # NOTE(review): the attack is called without labels (y_true=None); confirm
    # that pgd_attack supports this before relying on the scores.
    x_adv = pgd_attack(vit_model, x, y_true=None, epsilon=epsilon)

    captured = []

    def hook_fn(module, input, output):
        # assumes the attn module returns a single tensor — TODO confirm
        captured.append(output)

    hooks = [
        block.attn.register_forward_hook(hook_fn) for block in vit_model.blocks
    ]
    try:
        with torch.no_grad():
            _ = vit_model(x)
            _ = vit_model(x_adv)
    finally:
        # Always detach the hooks, even if a forward pass raises.
        for hook in hooks:
            hook.remove()

    # The clean pass completes before the adversarial pass starts, so the
    # first half of `captured` is clean and the second half adversarial.
    # Interleaved slicing ([::2] / [1::2]) would pair neighbouring layers of
    # the same pass instead.
    per_pass = len(captured) // 2
    clean_outputs = captured[:per_pass]
    adv_outputs = captured[per_pass:2 * per_pass]

    return [
        (attn_clean - attn_adv).abs().mean().item()
        for attn_clean, attn_adv in zip(clean_outputs, adv_outputs)
    ]


# 关键发现:
- 浅层注意力变化较小
- 深层注意力高度集中于少数 token
- 这些关键 token 的扰动会放大并传播
位置编码脆弱性
def position_encoding_attack(vit_model, x, epsilon=0.1):
    """Return the input-gradient slice used as the position-encoding direction.

    The model is called with ``perturb_positions=True``, the summed output is
    back-propagated, and the gradient with respect to the input is sliced at
    index 0 of the third axis.

    Args:
        vit_model: Model accepting ``(x, perturb_positions=True)`` and exposing
            ``zero_grad()``.
        x: Input tensor with at least four dimensions. It is not modified; the
            gradient is taken on a detached copy.
        epsilon: Accepted for interface compatibility; not used by the
            computation below.

    Returns:
        Tensor: a copy of ``grad[:, :, 0, :]`` where ``grad`` is the gradient
        of ``vit_model(x, perturb_positions=True).sum()`` w.r.t. the input.
    """
    # Differentiate a private leaf copy so the caller's tensor keeps its own
    # requires_grad flag and .grad buffer (and non-leaf inputs are accepted).
    x = x.clone().detach().requires_grad_(True)

    # Only the position-encoding path is perturbed inside the model.
    output = vit_model(x, perturb_positions=True)
    loss = output.sum()

    # The parameter is named vit_model; the previous `model.zero_grad()`
    # referenced a name that does not exist in this function.
    vit_model.zero_grad()
    loss.backward()

    # assumes the third axis indexes rows of image patches — TODO confirm
    pos_grad = x.grad[:, :, 0, :].clone()
    return pos_grad


# NeuroShield-ViT
NeuroShield 是一种针对 ViT 的即插即用防御方法 [2]:
class NeuroShieldViT:
    """Wrap a classifier and neutralize high-gradient input positions.

    The wrapper predicts a pseudo-label, marks the input positions whose
    channel-averaged gradient magnitude exceeds ``threshold``, overwrites those
    positions with an average value and classifies the result.
    """

    def __init__(self, model, threshold=0.9):
        self.model = model
        self.threshold = threshold
        # Kept as part of the public attribute set; nothing in this class
        # writes to it.
        self.vulnerable_tokens = {}

    def identify_vulnerable_neurons(self, x, y):
        """Mark the positions whose loss gradient is larger than the threshold.

        Args:
            x: Input batch. It is left untouched; gradients are taken on a
                detached copy.
            y: Target class indices for the cross-entropy loss.

        Returns:
            Boolean tensor, ``x`` with dim 1 reduced: True where the mean over
            dim 1 of ``|d loss / d x|`` exceeds ``self.threshold``.
        """
        # enable_grad lets this work when the caller runs inference under
        # torch.no_grad(); autograd.grad only differentiates w.r.t. the probe,
        # so the wrapped model's parameter gradients are not touched.
        with torch.enable_grad():
            x_probe = x.clone().detach().requires_grad_(True)
            output = self.model(x_probe)
            loss = F.cross_entropy(output, y)
            (grad,) = torch.autograd.grad(loss, x_probe)

        grad_magnitude = grad.abs()
        return grad_magnitude.mean(dim=1) > self.threshold

    def neutralize(self, x, vulnerable_mask):
        """Overwrite the masked positions of ``x`` with their average value.

        Args:
            x: Input batch.
            vulnerable_mask: Boolean mask, either with the same shape as ``x``
                (element-wise) or with the shape of ``x`` minus dim 1, as
                returned by ``identify_vulnerable_neurons``.

        Returns:
            A modified copy of ``x``; ``x`` itself is unchanged. With an empty
            mask the copy equals ``x``.
        """
        x_neutralized = x.clone()
        if not vulnerable_mask.any():
            # Nothing selected: avoid averaging over an empty selection.
            return x_neutralized

        if vulnerable_mask.shape == x.shape:
            # Element-wise mask: a single average over all selected elements.
            x_neutralized[vulnerable_mask] = x[vulnerable_mask].mean()
            return x_neutralized

        # The mask lacks dim 1, so move that dim last to let the mask line up
        # with the leading dims. movedim returns a view, so the assignment
        # writes through to x_neutralized.
        channels_last = x_neutralized.movedim(1, -1)
        replacement = channels_last[vulnerable_mask].mean(dim=0, keepdim=True)
        channels_last[vulnerable_mask] = replacement
        return x_neutralized

    def forward(self, x):
        """Classify ``x`` after neutralizing its high-gradient positions.

        Args:
            x: Input batch accepted by ``F.interpolate`` and by the model.

        Returns:
            The wrapped model's output on the neutralized input.
        """
        # Pass 1: pseudo-label from a half-resolution copy of the input.
        # NOTE(review): this requires the wrapped model to accept inputs at
        # half the spatial size — confirm for fixed-resolution ViTs.
        with torch.no_grad():
            x_small = F.interpolate(x, scale_factor=0.5)
            output_small = self.model(x_small)
            y_pred = output_small.argmax(dim=1)

        vulnerable_mask = self.identify_vulnerable_neurons(x, y_pred)

        # Pass 2: neutralize and predict.
        x_neutralized = self.neutralize(x, vulnerable_mask)
        output = self.model(x_neutralized)
        return output


# NeuroShield 效果
| 指标 | 原始 ViT | NeuroShield-ViT |
|---|---|---|
| Clean Acc | 82.1% | 81.8% |
| PGD Acc | 45.3% | 77.8% |
| AutoAttack | 43.1% | 75.2% |
| 推理开销 | 1x | ~1.1x |
Protego:注意力机制检测
Protego 利用 ViT 注意力机制的内在线索检测对抗样本 [3]:
class ProtegoDetector:
    """Detect adversarial inputs from per-layer attention statistics.

    For every block of the wrapped model four statistics are extracted per
    sample (entropy mean/std and max-attention mean/std) and fed to a small
    binary classifier.
    """

    # Statistics produced per hooked layer by extract_attention_features.
    _STATS_PER_LAYER = 4

    def __init__(self, vit_model, detector=None):
        self.model = vit_model
        self.detector = detector or self._build_detector()

    def _build_detector(self):
        """Build the default MLP, sized to the features actually extracted."""
        # Four statistics per block; a 12-block ViT therefore yields 48
        # inputs, not 12.
        num_layers = len(self.model.blocks)
        return nn.Sequential(
            nn.Linear(self._STATS_PER_LAYER * num_layers, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def _model_device(self):
        """Return the device of the wrapped model's first parameter (CPU if none)."""
        first_param = next(iter(self.model.parameters()), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def extract_attention_features(self, x):
        """Run the model on ``x`` and summarize each layer's attention map.

        Args:
            x: Input batch for the wrapped model.

        Returns:
            Tensor of shape ``[batch, 4 * num_hooked_layers]``. Per layer the
            columns are: entropy mean, entropy std, max-attention mean,
            max-attention std, each computed per sample.
        """
        attentions = []

        def hook_fn(module, input, output):
            # assumes block.attn returns a tuple whose second item is the
            # attention map, e.g. [B, num_heads, N, N] — TODO confirm
            attentions.append(output[1])

        hooks = [
            block.attn.register_forward_hook(hook_fn)
            for block in self.model.blocks
        ]
        try:
            with torch.no_grad():
                _ = self.model(x)
        finally:
            # Always detach the hooks, even if the forward pass raises.
            for hook in hooks:
                hook.remove()

        if not attentions:
            return x.new_zeros((x.size(0), 0))

        features = []
        for attn in attentions:
            # Entropy of the softmax over the last dim of the captured map.
            attn_prob = F.softmax(attn, dim=-1)
            entropy = -(attn_prob * torch.log(attn_prob + 1e-10)).sum(dim=-1)
            # Concentration: largest entry along the last dim.
            max_attn = attn.max(dim=-1)[0]
            # Reduce over everything but the batch axis so each sample keeps
            # its own feature row (labels in train() are per sample).
            for stat in (entropy, max_attn):
                per_sample = stat.flatten(start_dim=1)
                features.extend([per_sample.mean(dim=1), per_sample.std(dim=1)])
        return torch.stack(features, dim=1)

    def detect(self, x):
        """Decide whether ``x`` is adversarial.

        Args:
            x: Input batch; normally a single sample. If several samples are
                given, their scores are averaged into one.

        Returns:
            tuple(bool, float): ``(score > 0.5, score)``.
        """
        features = self.extract_attention_features(x)
        with torch.no_grad():
            score = self.detector(features).mean().item()
        return score > 0.5, score

    def train(self, clean_loader, adv_loader, epochs=50):
        """Fit the detector with binary cross-entropy.

        Args:
            clean_loader: Iterable of ``(x, _)`` batches labelled 0.
            adv_loader: Iterable of ``(x, _)`` batches labelled 1.
            epochs: Number of passes; each pass visits all clean batches and
                then all adversarial batches.
        """
        target_device = self._model_device()
        self.detector.to(target_device)
        optimizer = torch.optim.Adam(self.detector.parameters())
        criterion = nn.BCELoss()

        for _epoch in range(epochs):
            # Clean samples are labelled 0, adversarial samples 1.
            for loader, target in ((clean_loader, 0.0), (adv_loader, 1.0)):
                for x, _ in loader:
                    x = x.to(target_device)
                    features = self.extract_attention_features(x)
                    label = torch.full(
                        (features.size(0), 1), target, device=features.device
                    )
                    optimizer.zero_grad()
                    pred = self.detector(features)
                    loss = criterion(pred, label)
                    loss.backward()
                    optimizer.step()


# NeighborViT:对抗补丁防御
NeighborViT 专门针对对抗补丁攻击 [4]:
class NeighborViT:
    """Defend against adversarial patches using neighbouring image patches.

    The image is cut into non-overlapping patches of ``model.patch_size``.
    Each patch is compared with the mean of the patches around it (a
    ``window_size`` x ``window_size`` window on the patch grid, centre
    excluded); patches that deviate too much are replaced by that mean.
    """

    def __init__(self, model, window_size=3):
        self.model = model
        self.window_size = window_size

    def _patch_hw(self):
        """Return the patch height and width from ``model.patch_size`` (int or pair)."""
        size = self.model.patch_size
        if isinstance(size, int):
            return size, size
        patch_h, patch_w = size
        return patch_h, patch_w

    def _to_patch_grid(self, x):
        """Cut ``x`` [B, C, H, W] into a patch grid [B, rows, cols, C, ph, pw]."""
        patch_h, patch_w = self._patch_hw()
        grid = x.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)
        return grid.permute(0, 2, 3, 1, 4, 5)

    def _neighbor_mean(self, grid):
        """Average the neighbours of every patch in the grid.

        Returns:
            tuple: ``(mean, has_neighbors)`` where ``mean`` has the shape of
            ``grid`` and ``has_neighbors`` is a [rows, cols] boolean tensor
            that is False where a patch has no neighbour inside the window.
        """
        radius = self.window_size // 2
        rows, cols = grid.shape[1], grid.shape[2]
        total = torch.zeros_like(grid)
        count = torch.zeros(rows, cols, dtype=grid.dtype, device=grid.device)

        for dy in range(-radius, radius + 1):
            for dx in range(-radius, radius + 1):
                # Skip the centre itself and offsets larger than the grid.
                if (dy == 0 and dx == 0) or abs(dy) >= rows or abs(dx) >= cols:
                    continue
                # Destination cells are those whose (dy, dx) neighbour exists.
                dst_r = slice(max(0, -dy), rows - max(0, dy))
                src_r = slice(max(0, dy), rows - max(0, -dy))
                dst_c = slice(max(0, -dx), cols - max(0, dx))
                src_c = slice(max(0, dx), cols - max(0, -dx))
                total[:, dst_r, dst_c] += grid[:, src_r, src_c]
                count[dst_r, dst_c] += 1

        has_neighbors = count > 0
        divisor = count.clamp(min=1).view(1, rows, cols, 1, 1, 1)
        return total / divisor, has_neighbors

    def analyze_patch_neighbors(self, x):
        """Score every patch by its mean absolute difference to its neighbours.

        Args:
            x: Image batch of shape [B, C, H, W].

        Returns:
            Tensor [B, num_patches], patches in row-major grid order. Patches
            without any neighbour get a score of 0.
        """
        grid = self._to_patch_grid(x)
        neighbor_mean, has_neighbors = self._neighbor_mean(grid)
        diff = (grid - neighbor_mean).abs().mean(dim=(3, 4, 5))
        diff = torch.where(has_neighbors, diff, torch.zeros_like(diff))
        return diff.flatten(start_dim=1)

    def mitigate(self, x, anomaly_scores, threshold=0.5):
        """Replace anomalous patches by the mean of their neighbours.

        Args:
            x: Image batch of shape [B, C, H, W]; not modified.
            anomaly_scores: Tensor [B, num_patches] as returned by
                ``analyze_patch_neighbors``.
            threshold: Patches scoring strictly above this value are replaced.

        Returns:
            A copy of ``x`` with the flagged patches repaired, independently
            for each sample. Neighbour means are taken from the original
            image. Pixels outside the patch grid (when H or W is not a
            multiple of the patch size) are copied unchanged.
        """
        grid = self._to_patch_grid(x)
        batch, rows, cols, channels, patch_h, patch_w = grid.shape
        neighbor_mean, has_neighbors = self._neighbor_mean(grid)

        replace = (anomaly_scores > threshold).reshape(batch, rows, cols)
        replace = replace & has_neighbors
        repaired = torch.where(replace[..., None, None, None], neighbor_mean, grid)

        # Reassemble the patch grid into image layout.
        restored = repaired.permute(0, 3, 1, 4, 2, 5).reshape(
            batch, channels, rows * patch_h, cols * patch_w
        )
        x_patched = x.clone()
        x_patched[:, :, :rows * patch_h, :cols * patch_w] = restored
        return x_patched

    def forward(self, x):
        """Score, repair and classify ``x``; returns the wrapped model's output."""
        anomaly_scores = self.analyze_patch_neighbors(x)
        x_mitigated = self.mitigate(x, anomaly_scores)
        output = self.model(x_mitigated)
        return output


# ViT 对抗训练
def vit_adversarial_training(vit_model, train_loader, epsilon=4/255, epochs=100):
    """Adversarially train ``vit_model`` on PGD examples.

    Each batch is replaced by PGD adversarial examples (step size
    ``epsilon / 4``, 7 iterations) generated against the current model, and
    the model is updated with cross-entropy on those examples. AdamW
    (lr 1e-4, weight decay 0.05) is used with a cosine-annealing schedule
    stepped once per epoch; the mean batch loss is printed after every epoch.

    Args:
        vit_model: Model to train; put in training mode at every epoch start.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        epsilon: Perturbation budget for ``pgd_attack``.
        epochs: Number of epochs, also the scheduler's annealing period.

    Returns:
        The same ``vit_model`` object after training.
    """
    optimizer = torch.optim.AdamW(
        vit_model.parameters(),
        lr=1e-4,
        weight_decay=0.05,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    def train_on_batch(images, labels):
        """Run one adversarial optimisation step and return its scalar loss."""
        images = images.to(device)
        labels = labels.to(device)
        # Craft the PGD batch against the current weights.
        images_adv = pgd_attack(
            vit_model, images, labels,
            epsilon=epsilon, alpha=epsilon/4, num_iter=7,
        )
        optimizer.zero_grad()
        loss = F.cross_entropy(vit_model(images_adv), labels)
        loss.backward()
        optimizer.step()
        return loss.item()

    for epoch in range(epochs):
        vit_model.train()
        total_loss = 0
        for batch_images, batch_labels in train_loader:
            total_loss += train_on_batch(batch_images, batch_labels)
        scheduler.step()
        print(f"Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}")

    return vit_model


# 评估基准
ViT 鲁棒性评估套件
def evaluate_vit_robustness(vit_model, test_loader, epsilon=4/255, cnn_model=None):
    """Run the robustness evaluation suite on ``vit_model``.

    Args:
        vit_model: Model under evaluation.
        test_loader: Data loader passed to every evaluation helper.
        epsilon: Perturbation budget passed to every attack evaluation.
        cnn_model: Source model for the CNN -> ViT transfer attack. When
            omitted, a module-level ``cnn_model`` is used if one exists;
            otherwise the transfer evaluation is skipped.

    Returns:
        dict with the keys ``clean``, ``fgsm``, ``pgd``, ``autoattack``,
        ``transfer_cnn_to_vit`` (only when a source CNN is available) and
        ``eot``, each holding the value returned by the matching helper.
    """
    if cnn_model is None:
        # This function used to read an undeclared `cnn_model`; keep honouring
        # a module-level one so existing call sites behave as before.
        cnn_model = globals().get("cnn_model")

    # 1. Clean accuracy.
    results = {'clean': evaluate_accuracy(vit_model, test_loader)}

    # 2-4. White-box attacks, evaluated in this order.
    white_box_attacks = (
        ('fgsm', fgsm_attack),
        ('pgd', pgd_attack),
        ('autoattack', autoattack),
    )
    for key, attack in white_box_attacks:
        results[key] = evaluate_attack(vit_model, test_loader, attack, epsilon)

    # 5. Transfer attack (CNN -> ViT); needs a source model.
    if cnn_model is not None:
        results['transfer_cnn_to_vit'] = evaluate_transfer_attack(
            cnn_model, vit_model, test_loader, epsilon
        )

    # 6. EOT attack.
    results['eot'] = evaluate_attack(vit_model, test_loader, eot_attack, epsilon)
    return results


# 常用数据集结果
| 模型 | Clean | PGD | AutoAttack | EOT |
|---|---|---|---|---|
| ResNet-50 | 76.1% | 42.3% | 41.8% | 38.5% |
| DeiT-S | 79.8% | 48.2% | 47.1% | 44.3% |
| ViT-B/16 | 81.2% | 45.3% | 43.9% | 40.1% |
| Swin-T | 83.1% | 52.1% | 50.8% | 48.2% |
相关主题
- adversarial-robustness-fundamentals — 对抗鲁棒性基础
- adversarial-training-methods — 对抗训练方法
- certified-robustness-theory — 认证鲁棒性理论
- swin-transformer — Swin Transformer 架构
- vision-transformer-vit — ViT 基础
参考文献
[1] Mao, C., et al. (2025). Adversarial Threats to Vision Transformers: Evaluating Robustness Beyond CNNs. Neural Computing and Applications. https://link.springer.com/article/10.1007/s00521-025-11734-0
[2] Chen, L., et al. (2025). NeuroShield-ViT: Protecting Vision Transformers from Adversarial Attacks. arXiv:2502.04679. https://arxiv.org/abs/2502.04679
[3] Liu, Y., et al. (2025). Protego: Detecting Adversarial Examples for Vision Transformers via Intrinsic Capabilities. arXiv:2501.07044. https://arxiv.org/abs/2501.07044
[4] Wang, H., et al. (2025). NeighborViT: Defending Against Adversarial Patches via Neighbor Patch Information. ICLR 2025. https://openreview.net/pdf?id=vOSwtXGSA2