Adversarial Training Methods
Overview
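Adversarial training is currently one of the most effective defenses against adversarial examples. Its core idea is to use adversarial examples as data augmentation during training, so that the model learns to be robust to adversarial perturbations. In doing so, adversarial training turns the standard learning problem into a min-max robust optimization problem.[1]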
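Min-Max Optimization Framework
- Inner maximization: find the perturbation that is most damaging to the current model
- Outer minimization: adjust the model parameters to minimize this worst-case loss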
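Written out, the first line below is the saddle-point objective of Madry et al.[1] for an L∞ threat model of radius ε. Here f_θ is the classifier and 𝓛 is the cross-entropy loss. The second line is the projected gradient ascent (PGD) step that approximates the inner maximization. In that step, α is the step size and Π is the projection onto the ε-ball around x intersected with the valid pixel range [0, 1].

$$
\min_{\theta}\ \mathbb{E}_{(x,y)\sim\mathcal{D}}\left[\max_{\|\delta\|_\infty \le \epsilon} \mathcal{L}\big(f_\theta(x+\delta),\, y\big)\right]
$$

$$
x^{t+1} = \Pi_{\mathcal{B}_\epsilon(x)\cap[0,1]^d}\Big(x^{t} + \alpha\,\operatorname{sign}\big(\nabla_{x^{t}}\mathcal{L}(f_\theta(x^{t}), y)\big)\Big),\qquad \mathcal{B}_\epsilon(x) = \{x' : \|x'-x\|_\infty \le \epsilon\}
$$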
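Alternating Solution via Gradient Descent
```python
import torch
import torch.nn.functional as F


def adversarial_training_loop(model, dataloader, optimizer, epsilon, alpha, num_iter, device="cuda"):
    """
    Basic adversarial training loop (one epoch).
    Inner: num_iter PGD steps under an L-inf constraint; outer: one optimizer update.
    """
    model.train()
    total_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        # Inner maximization: craft adversarial examples with PGD
        x_adv = images.clone().detach()
        for _ in range(num_iter):
            x_adv.requires_grad_(True)
            loss = F.cross_entropy(model(x_adv), labels)
            # Gradient w.r.t. the input only; parameter .grad fields are left untouched
            grad = torch.autograd.grad(loss, x_adv)[0]
            with torch.no_grad():
                x_adv = x_adv + alpha * grad.sign()
                # Project back onto the epsilon-ball around the clean images, then onto [0, 1]
                x_adv = images + torch.clamp(x_adv - images, -epsilon, epsilon)
                x_adv = torch.clamp(x_adv, 0, 1)
        x_adv = x_adv.detach()

        # Outer minimization: update the model parameters on the adversarial batch
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x_adv), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)
```
PGD-AT (Madry's Adversarial Training)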
Algorithm Description
PGD-AT, proposed by Madry et al., is the baseline method for adversarial training.[1]
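Several code blocks in this section (PGD-AT, MART, GRACE, S2O, and the robustness evaluation) call a helper `pgd_attack(model, x, y, epsilon, alpha, num_iter)` that the text never defines. The following is a minimal sketch of one possible implementation. It assumes an L∞ threat model, inputs scaled to [0, 1], a uniform random start as in Madry et al., and the CIFAR-10 defaults `alpha=2/255` and `num_iter=10`. The defaults and the choice to attack in eval mode are assumptions, not part of the original text.

```python
import torch
import torch.nn.functional as F


def pgd_attack(model, x, y, epsilon=8/255, alpha=2/255, num_iter=10, random_start=True):
    """L-inf PGD (sketch): maximize cross-entropy inside the epsilon-ball around x."""
    was_training = model.training
    model.eval()  # assumption: craft the attack with fixed BatchNorm statistics
    x = x.detach()
    x_adv = x.clone()
    if random_start:
        # Uniform random start inside the epsilon-ball, clipped to the valid pixel range
        x_adv = torch.clamp(x + torch.empty_like(x).uniform_(-epsilon, epsilon), 0, 1)
    # enable_grad lets the attack run even if the caller is inside torch.no_grad()
    with torch.enable_grad():
        for _ in range(num_iter):
            x_adv.requires_grad_(True)
            loss = F.cross_entropy(model(x_adv), y)
            grad = torch.autograd.grad(loss, x_adv)[0]
            x_adv = x_adv.detach() + alpha * grad.sign()
            x_adv = x + torch.clamp(x_adv - x, -epsilon, epsilon)  # project onto the ball
            x_adv = torch.clamp(x_adv, 0, 1)                       # keep a valid image
    model.train(was_training)  # restore the caller's mode
    return x_adv.detach()
```

Because the helper uses `torch.autograd.grad` instead of `loss.backward()`, crafting the attack does not leave stale gradients in the model's parameters.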
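```python
import torch
import torch.nn.functional as F


def madry_adversarial_training(model, dataloader, epsilon=8/255, alpha=2/255,
                               num_iter=10, epochs=200, device="cuda"):
    """
    Madry's PGD Adversarial Training.
    Relies on a pgd_attack(model, x, y, epsilon, alpha, num_iter) helper defined elsewhere.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # Decay the learning rate 10x at epochs 100 and 150: 0.1 -> 0.01 -> 0.001
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            # Generate PGD adversarial examples against the current model
            x_adv = pgd_attack(model, images, labels, epsilon, alpha, num_iter)
            # Standard training step, but on the adversarial batch
            optimizer.zero_grad()
            output = model(x_adv)
            loss = F.cross_entropy(output, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = output.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
        scheduler.step()
        # Accuracy is measured on the PGD examples, i.e. robust accuracy on the training set
        print(f"Epoch {epoch}: Loss={total_loss/len(dataloader):.4f}, "
              f"Train Adv Acc={100*correct/total:.2f}%")
```
Training Strategy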
| Phase | Learning rate | Epochs | ε (L∞) |
|---|---|---|---|
| Warm-up | 0.01 | 5 | 0 |
| Main | 0.1 | 100 | 8/255 |
| Decay | 0.01 | 50 | 8/255 |
| Fine-tuning | 0.001 | 50 | 8/255 |
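The table can be read as a piecewise schedule over epochs. The sketch below is in plain Python. It assumes the phases run back to back in the order listed (5 + 100 + 50 + 50 = 205 epochs) and that ε is the L∞ radius the attack uses in that epoch. The names `PHASES` and `phase_schedule` are hypothetical.

```python
# Phases from the training-strategy table: (name, learning rate, epochs, epsilon)
PHASES = [
    ("warmup", 0.01, 5, 0.0),
    ("main", 0.1, 100, 8 / 255),
    ("decay", 0.01, 50, 8 / 255),
    ("finetune", 0.001, 50, 8 / 255),
]


def phase_schedule(epoch, phases=PHASES):
    """Return (phase name, learning rate, epsilon) for a 0-indexed epoch."""
    start = 0
    for name, lr, length, eps in phases:
        if epoch < start + length:
            return name, lr, eps
        start += length
    # Past the last phase: keep the final phase's settings
    name, lr, _, eps = phases[-1]
    return name, lr, eps
```

In a training loop, you would write the returned rate into each `param_group["lr"]` and pass the returned ε to the attack. A warm-up with ε = 0 is simply standard training on clean images.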
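TRADES (TRadeoff-inspired Adversarial DEfense via Surrogate-loss minimization)
TRADES adds a KL-divergence regularizer so that clean accuracy and robustness are optimized together. The weight β controls the trade-off between them.[2]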
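Written as an objective (with f_θ(·) denoting the softmax output), the first term below fits the clean data. The second term pushes the prediction at any nearby x′ toward the clean prediction. Note that the inner maximization targets the KL term, not the cross-entropy with the label.

$$
\min_{\theta}\ \mathbb{E}_{(x,y)}\left[\mathcal{L}_{\mathrm{CE}}\big(f_\theta(x), y\big) + \beta \max_{\|x'-x\|_\infty \le \epsilon} \mathrm{KL}\big(f_\theta(x)\,\|\,f_\theta(x')\big)\right]
$$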
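```python
import torch
import torch.nn.functional as F


def trades_loss(model, x, y, beta=6.0, epsilon=8/255, alpha=2/255, num_iter=10):
    """
    TRADES loss.
    Args:
        beta: trade-off parameter; larger values put more weight on robustness
    """
    # Inner maximization (eval mode): find x_adv maximizing KL(f(x) || f(x_adv))
    model.eval()
    with torch.no_grad():
        prob_clean = F.softmax(model(x), dim=1)
    x_adv = (x + 0.001 * torch.randn_like(x)).detach()  # small random start
    for _ in range(num_iter):
        x_adv.requires_grad_(True)
        kl = F.kl_div(F.log_softmax(model(x_adv), dim=1), prob_clean, reduction='sum')
        grad = torch.autograd.grad(kl, x_adv)[0]
        with torch.no_grad():
            x_adv = x_adv + alpha * grad.sign()
            x_adv = torch.clamp(x + torch.clamp(x_adv - x, -epsilon, epsilon), 0, 1)
    x_adv = x_adv.detach()

    # Outer minimization (train mode): both terms must carry gradients
    model.train()
    logits_clean = model(x)
    logits_adv = model(x_adv)
    ce_loss = F.cross_entropy(logits_clean, y)
    # F.kl_div(log q, p) computes KL(p || q); here KL(f(x) || f(x_adv))
    kl_loss = F.kl_div(
        F.log_softmax(logits_adv, dim=1),
        F.softmax(logits_clean, dim=1),
        reduction='batchmean'
    )
    return ce_loss + beta * kl_loss


def trades_training(model, dataloader, optimizer, beta=6.0, epsilon=8/255, device="cuda"):
    """
    TRADES adversarial training (one epoch)
    """
    model.train()
    total_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = trades_loss(model, images, labels, beta, epsilon)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)
```
MART (Misclassification Aware adveRsarial Training)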
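MART explicitly distinguishes misclassified examples from correctly classified ones. Its adversarial loss adds a margin-boosting term, and its KL regularizer is weighted more heavily for clean examples that the model gets wrong.[3]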
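The formulas below use the paper's notation: p(·) is the softmax output, x̂ is the PGD adversarial example, and λ is the regularization weight (`beta` in code). The first line is the per-example MART objective and the second defines its boosted cross-entropy term. The factor 1 − p_y(x) makes the KL term larger for examples the model misclassifies.

$$
\mathcal{L}^{\mathrm{MART}}(x, y) = \mathrm{BCE}\big(p(\hat{x}), y\big) + \lambda\,\mathrm{KL}\big(p(x)\,\|\,p(\hat{x})\big)\,\big(1 - p_y(x)\big)
$$

$$
\mathrm{BCE}\big(p(\hat{x}), y\big) = -\log p_y(\hat{x}) - \log\Big(1 - \max_{k \neq y} p_k(\hat{x})\Big)
$$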
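```python
import torch
import torch.nn.functional as F


def mart_loss(model, x, y, epsilon=8/255, beta=6.0):
    """
    MART loss.
    Key insight: misclassified examples deserve extra weight, so the KL term is
    scaled by (1 - p_y(x)), the clean model's shortfall on the true label.
    """
    # Generate PGD adversarial examples (cross-entropy attack)
    x_adv = pgd_attack(model, x, y, epsilon)
    logits_clean = model(x)
    logits_adv = model(x_adv)
    prob_clean = F.softmax(logits_clean, dim=1)
    prob_adv = F.softmax(logits_adv, dim=1)

    # Boosted CE on adversarial examples: CE plus a margin term that pushes down
    # the largest wrong-class probability
    top2 = torch.argsort(prob_adv, dim=1)[:, -2:]
    wrong_y = torch.where(top2[:, -1] == y, top2[:, -2], top2[:, -1])
    bce_loss = F.cross_entropy(logits_adv, y) + \
        F.nll_loss(torch.log(1.0001 - prob_adv + 1e-12), wrong_y)

    # Per-example KL(p(x) || p(x_adv)), weighted by (1 - p_y(x))
    kl = F.kl_div(torch.log(prob_adv + 1e-12), prob_clean, reduction='none').sum(dim=1)
    true_prob = prob_clean.gather(1, y.unsqueeze(1)).squeeze(1)
    kl_loss = (kl * (1.0000001 - true_prob)).mean()

    return bce_loss + beta * kl_loss
```
GRACE (Geometry-Aware Robust Fine-Tuning)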
GRACE is a geometry-aware adversarial fine-tuning method for vision-language models.[4]
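The GRACE description mentions optimizing loss curvature. If an actual estimate of the input-Hessian trace is wanted, one standard option is Hutchinson's estimator, tr(H) ≈ E_v[vᵀHv] with Rademacher vectors v. Each Hessian-vector product comes from double backpropagation. This is a swapped-in technique, not necessarily what the GRACE paper uses. The sketch below assumes a classifier trained with cross-entropy, sums the estimate over the batch, and uses the hypothetical name `hutchinson_input_hessian_trace`.

```python
import torch
import torch.nn.functional as F


def hutchinson_input_hessian_trace(model, x, y, num_samples=8):
    """Estimate tr(d^2 L / dx^2) of the summed cross-entropy w.r.t. the input x."""
    x = x.detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y, reduction='sum')
    # create_graph=True keeps the graph so the gradient can be differentiated again
    grad = torch.autograd.grad(loss, x, create_graph=True)[0]
    estimates = []
    for _ in range(num_samples):
        v = torch.randint_like(x, 2) * 2 - 1  # Rademacher vector: entries are +1 or -1
        hv = torch.autograd.grad(grad, x, grad_outputs=v, retain_graph=True)[0]  # H v
        estimates.append((v * hv).sum())       # v^T H v
    return torch.stack(estimates).mean()
```

When the Hessian is diagonal, every sample equals the trace exactly, because each v_i² = 1. Otherwise the off-diagonal entries add variance, and more samples are needed.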
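```python
import torch
import torch.nn.functional as F


def grace_loss(model, x, y, epsilon=8/255, lambda_geo=0.1):
    """
    GRACE loss (illustrative sketch).
    Geometric-failure awareness: penalize loss curvature and feature-manifold misalignment.
    Assumes model.encode(x) returns a (batch, dim) feature tensor.
    """
    # Base adversarial loss
    x_adv = pgd_attack(model, x, y, epsilon).detach().requires_grad_(True)
    adv_loss = F.cross_entropy(model(x_adv), y)

    # Geometric regularizer: squared input-gradient norm as a cheap curvature proxy.
    # create_graph=True makes the penalty differentiable w.r.t. the parameters.
    # This is a gradient-norm penalty, not an exact Hessian trace.
    grads = torch.autograd.grad(adv_loss, x_adv, create_graph=True)[0]
    grad_penalty = grads.pow(2).sum(dim=tuple(range(1, grads.dim()))).mean()

    # Feature alignment: pull adversarial features toward fixed clean features
    with torch.no_grad():
        feat_clean = model.encode(x)
    feat_adv = model.encode(x_adv)
    align_loss = (feat_clean - feat_adv).norm(dim=1).pow(2).mean()

    # Combined loss
    return adv_loss + lambda_geo * (grad_penalty + align_loss)


class GRACETrainer:
    """Wraps grace_loss in a full optimization step (zero_grad, backward, step)."""
    def __init__(self, model, optimizer, epsilon=8/255, lambda_geo=0.1):
        self.model = model
        self.optimizer = optimizer
        self.epsilon = epsilon
        self.lambda_geo = lambda_geo

    def train_step(self, images, labels):
        self.optimizer.zero_grad()
        loss = grace_loss(self.model, images, labels,
                          self.epsilon, self.lambda_geo)
        loss.backward()
        self.optimizer.step()
        return loss.item()
```
CURE (Unified Certified Robust Training)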
CURE proposes unified certified robust training across multiple norms.[5]
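The `compute_ibp_bound` helper in the CURE code is never implemented. For the L∞ case, interval bound propagation (IBP) is the standard way to obtain such a bound. It pushes the box [x − ε, x + ε] through each layer, using the center/radius form for linear layers and monotonicity for ReLU. A sample is certified when the lower bound of its true logit exceeds the upper bound of every other logit. The sketch below handles `Linear` and `ReLU` layers and assumes inputs in [0, 1]. The names `ibp_bounds` and `ibp_certified` are hypothetical, and this is not claimed to be CURE's own bounding procedure.

```python
import torch
import torch.nn as nn


def ibp_bounds(layers, x, epsilon):
    """Propagate the L-inf box around x through Linear/ReLU layers; return (lower, upper)."""
    lower = torch.clamp(x - epsilon, 0, 1)
    upper = torch.clamp(x + epsilon, 0, 1)
    for layer in layers:
        if isinstance(layer, nn.Linear):
            center, radius = (upper + lower) / 2, (upper - lower) / 2
            center = layer(center)                    # W c + b
            radius = radius @ layer.weight.abs().t()  # |W| r
            lower, upper = center - radius, center + radius
        elif isinstance(layer, nn.ReLU):
            lower, upper = torch.relu(lower), torch.relu(upper)  # ReLU is monotone
        else:
            raise NotImplementedError(f"IBP not implemented for {type(layer).__name__}")
    return lower, upper


def ibp_certified(layers, x, y, epsilon):
    """True where the worst-case true-class logit beats every other class's best case."""
    lower, upper = ibp_bounds(layers, x, epsilon)
    true_lower = lower.gather(1, y.unsqueeze(1)).squeeze(1)
    # Exclude the true class from the competitor maximum
    upper_others = upper.scatter(1, y.unsqueeze(1), float('-inf'))
    return true_lower > upper_others.max(dim=1).values
```

Comparing logit bounds independently is looser than bounding the logit differences directly, but it keeps the sketch short and still sound.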
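```python
import torch.nn.functional as F


class CURETraining:
    """
    Certified Unified Robust Training (CURE)
    Jointly optimizes robustness under L1, L2 and Linf norm constraints.
    The perturbation and bound helpers are left abstract.
    """
    def __init__(self, model, epsilon_dict=None):
        self.model = model
        # Avoid a mutable default argument
        if epsilon_dict is None:
            epsilon_dict = {'linf': 8/255, 'l2': 128/255, 'l1': 12}
        self.epsilon_dict = epsilon_dict

    def generate_robust_perturbation(self, x, norm, eps):
        raise NotImplementedError  # e.g. a PGD attack under the given norm

    def compute_ibp_bound(self, x, eps):
        raise NotImplementedError  # Linf: interval bound propagation

    def compute_randomized_bound(self, x, eps):
        raise NotImplementedError  # L2: randomized-smoothing certificate

    def compute_l1_bound(self, x, eps):
        raise NotImplementedError  # L1 certificate

    def compute_certified_bounds(self, x, y):
        """Compute a certified bound for each of the three norms"""
        bounds = {}
        for norm, eps in self.epsilon_dict.items():
            if norm == 'linf':
                bound = self.compute_ibp_bound(x, eps)
            elif norm == 'l2':
                bound = self.compute_randomized_bound(x, eps)
            else:
                bound = self.compute_l1_bound(x, eps)
            bounds[norm] = bound
        return bounds

    def cure_loss(self, x, y):
        """CURE combined loss"""
        bounds = self.compute_certified_bounds(x, y)
        total_loss = 0
        for norm, bound in bounds.items():
            # Weight each norm's loss by the inverse of its certified bound
            weight = 1.0 / (bound + 1e-6)
            x_adv = self.generate_robust_perturbation(x, norm,
                                                      self.epsilon_dict[norm])
            total_loss = total_loss + weight * F.cross_entropy(self.model(x_adv), y)
        return total_loss
```
S2O (Second-Order Statistics Enhanced AT)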
S2O strengthens adversarial training by exploiting second-order statistics of the weights.[6]
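```python
import torch
import torch.nn.functional as F


class S2OTrainer:
    """
    S2O: Second-Order Statistics Enhanced Adversarial Training (sketch).
    The helpers are methods so they can share the model and optimizer.
    """
    def __init__(self, model, epsilon=8/255, reg_weight=0.01, lr=0.1, device="cuda"):
        self.model = model
        self.epsilon = epsilon
        self.reg_weight = reg_weight
        self.device = device
        # Create the optimizer once so momentum persists across epochs
        self.optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    def compute_weight_statistics(self):
        """Compute second-order statistics of the weights"""
        stats = {}
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                # Gradient variance from the previous backward pass (treated as a constant)
                if param.grad is not None:
                    stats[f'{name}_grad_var'] = param.grad.detach().var()
                # Weight variance; `param` rather than `param.data` keeps it differentiable
                stats[f'{name}_weight_var'] = param.var()
        return stats

    def compute_s2o_regularizer(self, weight_stats):
        # The exact S2O regularizer is not specified here and is left abstract
        raise NotImplementedError

    def train_epoch(self, dataloader):
        self.model.train()
        total_loss = 0.0
        for images, labels in dataloader:
            images, labels = images.to(self.device), labels.to(self.device)
            # Generate adversarial examples
            x_adv = pgd_attack(self.model, images, labels, self.epsilon)
            # Weight statistics, read before zero_grad clears the gradients
            weight_stats = self.compute_weight_statistics()
            self.optimizer.zero_grad()
            # Standard adversarial loss
            adv_loss = F.cross_entropy(self.model(x_adv), labels)
            # Second-order-statistics regularizer
            reg_loss = self.compute_s2o_regularizer(weight_stats)
            loss = adv_loss + self.reg_weight * reg_loss
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        return total_loss / len(dataloader)
```
Training Tips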
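Learning-Rate Schedule
```python
def adversarial_training_scheduler(optimizer, epoch):
    """LR schedule for adversarial training: 0.1, then 0.01 from epoch 100, 0.001 from epoch 150"""
    if epoch < 100:
        lr = 0.1
    elif epoch < 150:
        lr = 0.01
    else:
        lr = 0.001
    # Apply the rate to every parameter group and return it for logging
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
```
Early Stopping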
```python
import torch


def evaluate_robustness(model, test_loader, epsilon=8/255, device="cuda"):
    """Evaluate robust accuracy under a PGD attack"""
    model.eval()
    robust_correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # The attack needs input gradients, so it must run outside torch.no_grad()
        x_adv = pgd_attack(model, images, labels, epsilon)
        with torch.no_grad():
            preds = model(x_adv).argmax(dim=1)
        robust_correct += preds.eq(labels).sum().item()
        total += labels.size(0)
    return robust_correct / total
```
Model Selection
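- Do not select models based on clean accuracy
- Use adversarial (robust) accuracy or certified accuracy as the selection metric
- Choose the best hyperparameters on a validation set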
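Taken together, these rules amount to keeping the checkpoint with the best validation robust accuracy and stopping once it stops improving. Robust accuracy on held-out data often peaks well before training ends, a phenomenon known as robust overfitting. The sketch below is framework-agnostic. `train_one_epoch` and `eval_robust_acc` are hypothetical callables supplied by the caller, for instance a training epoch and `evaluate_robustness` on a validation loader. The `patience` stopping rule is an assumption.

```python
def select_best_checkpoint(train_one_epoch, eval_robust_acc, epochs, patience=10):
    """Train, track validation robust accuracy, and stop after `patience` epochs without gain.

    Returns (best_epoch, best_robust_acc). Saving the weights (e.g. a copy of
    model.state_dict()) would happen where the best result is updated.
    """
    best_epoch, best_acc = -1, float('-inf')
    for epoch in range(epochs):
        train_one_epoch(epoch)
        acc = eval_robust_acc()  # robust accuracy, not clean accuracy
        if acc > best_acc:
            best_epoch, best_acc = epoch, acc
            # save a checkpoint here
        elif epoch - best_epoch >= patience:
            break  # no improvement for `patience` epochs
    return best_epoch, best_acc
```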
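Comparison of Adversarial Training Methods
| Method | Clean accuracy | Robust accuracy | Computational cost |
|---|---|---|---|
| Standard (clean) training | High | Low | Low |
| PGD-AT | Medium | High | Medium |
| TRADES | Medium-high | High | Medium |
| MART | Medium | High | Medium |
| GRACE | Medium | Very high | High |
| CURE | Medium | High | High |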
Practical Guide
Recommended Hyperparameters
```python
# MNIST
MNIST_CONFIG = {
    'epsilon': 0.3,
    'alpha': 0.01,
    'num_iter': 40,
    'epochs': 200,
    'lr': 0.1
}

# CIFAR-10
CIFAR10_CONFIG = {
    'epsilon': 8/255,
    'alpha': 2/255,
    'num_iter': 10,
    'epochs': 200,
    'lr': 0.1,
    'weight_decay': 5e-4
}

# ImageNet
IMAGENET_CONFIG = {
    'epsilon': 4/255,
    'alpha': 1/255,
    'num_iter': 7,
    'epochs': 90,
    'lr': 0.1,
    'weight_decay': 1e-4
}
```
Common Problems and Solutions
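| Problem | Cause | Solution |
|---|---|---|
| Slow convergence | Learning rate too high | Lower the learning rate; use warm-up |
| Clean accuracy drops | Side effect of adversarial training | Use TRADES to balance accuracy and robustness |
| Overfitting | Insufficient data | Add regularization and data augmentation |
| Unstable training | Exploding gradients | Gradient clipping; freeze BN |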
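The remedies in the last row can be combined in one training step. The sketch below assumes that "freezing BN" means putting BatchNorm layers in eval mode, so their running statistics stop updating while their affine parameters are still trained. It uses PyTorch's `torch.nn.utils.clip_grad_norm_` with an assumed `max_norm` of 1.0. The names `freeze_batchnorm` and `clipped_step` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def freeze_batchnorm(model):
    """Put every BatchNorm layer in eval mode so its running mean/var stay fixed."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()


def clipped_step(model, optimizer, x_adv, labels, max_norm=1.0):
    """One update on an adversarial batch with gradient-norm clipping."""
    model.train()
    freeze_batchnorm(model)  # re-apply: model.train() switches BN back to training mode
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x_adv), labels)
    loss.backward()
    # Rescale all gradients so that their global L2 norm is at most max_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss.item()
```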
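Related Topics
- adversarial-robustness-fundamentals: fundamentals of adversarial robustness
- adversarial-attack-methods: adversarial attack methods
- certified-robustness-theory: theory of certified robustness
- viability-adversarial-robustness-vit: adversarial robustness of ViTs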
References
1. Madry, A., et al. (2018). Towards Deep Learning Models Resistant to Adversarial Attacks. ICLR 2018. https://arxiv.org/abs/1706.06083
2. Zhang, H., et al. (2019). Theoretically Principled Trade-off between Robustness and Accuracy. ICML 2019. https://arxiv.org/abs/1901.10538
3. Wang, Y., et al. (2020). Improving Adversarial Robustness Requires Revisiting Misclassified Examples. ICLR 2020. https://arxiv.org/abs/1904.08554
4. Chen, Y., et al. (2026). GRACE: Geometric Failures-aware Robust Finetuning of Vision-Language Models. arXiv:2603.27139. https://arxiv.org/abs/2603.27139
5. Jiang, H., et al. (2025). CURE: Certified Robustness under Multiple Norms. arXiv:2410.03000. https://arxiv.org/abs/2410.03000
6. Chen, Y., et al. (2026). S2O: Enhancing Adversarial Training with Second-Order Statistics of Weights. arXiv:2603.01264. https://arxiv.org/abs/2603.01264