对抗样本检测与防御
概述
对抗样本检测(Detection)是对抗防御的重要分支,旨在识别并拒绝潜在的对抗输入,而无需修改模型本身。与对抗训练不同,检测方法假设攻击者可能绕过特定防御,因此采用“宁可错杀,不可放过”的策略。
检测方法分类
| 类别 | 方法 | 原理 |
|---|---|---|
| 统计检测 | 激活统计 | 对抗样本激活模式异常 |
| 重构检测 | 输入净化 | 对抗样本难以重构 |
| 分类检测 | 二分类器 | 训练专门检测器 |
| 不确定性检测 | 置信度 | 对抗样本不确定性高 |
基于统计的检测
激活统计检测
class ActivationStatisticsDetector:
    """Detect adversarial inputs from summary statistics of a layer's activations.

    A forward hook captures the output of every sub-module whose name contains
    ``layer_name``. Six per-sample statistics (mean, std, max, min, sparsity,
    excess kurtosis) are computed from the captured activations. ``fit``
    estimates their mean and covariance on clean data, and ``detect`` scores
    new inputs by Mahalanobis distance to that clean distribution.
    """

    def __init__(self, model, layer_name='layer4'):
        self.model = model
        self.layer_name = layer_name
        # Last activation tensor captured by the forward hook.
        self.activations = None
        # Hook handles are kept so the hooks can be removed later if needed.
        self._hook_handles = []
        self.register_hooks()
        # Clean-data statistics, populated by fit().
        self.clean_stats = None
        self.clean_mean = None
        self.clean_cov = None
        self.fitted = False

    def register_hooks(self):
        """Register a forward hook on every module whose name contains ``layer_name``."""
        def hook_fn(module, input, output):
            self.activations = output.detach()

        for name, module in self.model.named_modules():
            # Substring match: a parent module and its children may all match.
            # A parent's hook fires after its children's, so the parent's
            # output is what remains in self.activations after a forward pass.
            if self.layer_name in name:
                self._hook_handles.append(module.register_forward_hook(hook_fn))

    def compute_statistics(self, x):
        """Run ``x`` through the model and return a ``[B, 6]`` tensor of statistics.

        Raises:
            RuntimeError: if no hooked layer produced an activation.
        """
        # Clear any stale capture so a missing hook is detected reliably.
        self.activations = None
        with torch.no_grad():
            _ = self.model(x)
        if self.activations is None:
            raise RuntimeError(
                f"No activations captured; no module name contains {self.layer_name!r}"
            )

        # Flatten all non-batch dimensions so any activation rank >= 2 works.
        flat = self.activations.flatten(1)
        stats = {
            'mean': flat.mean(dim=1),
            'std': flat.std(dim=1),
            'max': flat.amax(dim=1),
            'min': flat.amin(dim=1),
            'sparsity': (flat.abs() < 0.01).float().mean(dim=1),
            'kurtosis': self.compute_kurtosis(flat),
        }
        return torch.stack(list(stats.values()), dim=1)

    def compute_kurtosis(self, x):
        """Return the per-sample excess kurtosis over all non-batch dimensions."""
        flat = x.flatten(1)
        mean = flat.mean(dim=1, keepdim=True)
        std = flat.std(dim=1, keepdim=True)
        # The small epsilon avoids division by zero for constant activations.
        z = (flat - mean) / (std + 1e-10)
        return (z ** 4).mean(dim=1) - 3

    def fit(self, dataloader, num_samples=1000):
        """Estimate the clean statistics distribution from ``(x, label)`` batches.

        At most ``num_samples`` samples are used.

        Raises:
            ValueError: if fewer than two samples are available, since a
                covariance cannot be estimated from them.
        """
        collected = []
        seen = 0
        for x, _ in dataloader:
            if seen >= num_samples:
                break
            batch_stats = self.compute_statistics(x)
            collected.append(batch_stats)
            seen += batch_stats.size(0)

        if not collected:
            raise ValueError("fit() needs at least two clean samples, got none")
        all_stats = torch.cat(collected, dim=0)[:num_samples]
        if all_stats.size(0) < 2:
            raise ValueError("fit() needs at least two clean samples")

        self.clean_stats = all_stats
        self.clean_mean = all_stats.mean(dim=0)
        self.clean_cov = torch.cov(all_stats.T)
        self.fitted = True

    def detect(self, x, threshold=3.0):
        """Score ``x`` by Mahalanobis distance to the clean statistics.

        Returns:
            ``(is_adversarial, anomaly_score)``. For a single sample these are
            a Python ``bool`` and ``float``. For a larger batch they are
            tensors of shape ``[B]``.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        if not self.fitted:
            raise RuntimeError("Must call fit() first")

        stats = self.compute_statistics(x)
        diff = stats - self.clean_mean.unsqueeze(0)

        cov = self.clean_cov
        eye = torch.eye(cov.size(0), dtype=cov.dtype, device=cov.device)
        # Solve instead of inverting, and compute only the per-sample quadratic
        # form rather than a full B x B matrix.
        solved = torch.linalg.solve(cov + 1e-6 * eye, diff.T)
        mahalanobis = torch.sqrt((diff * solved.T).sum(dim=1).clamp_min(0))

        if mahalanobis.numel() == 1:
            anomaly_score = mahalanobis.item()
            return anomaly_score > threshold, anomaly_score
        return mahalanobis > threshold, mahalanobis
LID(Local Intrinsic Dimensionality)检测
def compute_lid(model, x, k=20, num_samples=1000, noise_std=0.1):
    """Estimate the Local Intrinsic Dimensionality (LID) around each sample.

    Uses the maximum-likelihood estimator over the ``k`` nearest neighbours,
    ``LID = -1 / mean(log(r_i / r_k))``, where ``r_k`` is the largest of the
    ``k`` distances. The result is positive.

    The neighbours here are synthetic: isotropic Gaussian perturbations of
    ``x`` in flattened input space. The distances are therefore the norms of
    the sampled noise.

    NOTE(review): ``model`` is accepted for interface compatibility but is not
    used. A feature-space LID would need reference activations from the model.

    Args:
        model: Unused.
        x: Input batch with the batch dimension first.
        k: Number of nearest neighbours, ``2 <= k <= num_samples``.
        num_samples: Number of random neighbours drawn per sample.
        noise_std: Standard deviation of the Gaussian perturbations.

    Returns:
        Tensor of shape ``[B]`` with one LID estimate per sample.

    Raises:
        ValueError: if ``k`` is outside ``[2, num_samples]``.
    """
    if not 2 <= k <= num_samples:
        raise ValueError(f"k must satisfy 2 <= k <= num_samples, got k={k}")

    batch_size = x.size(0)
    # reshape (not view) so non-contiguous inputs are accepted.
    n = x.reshape(batch_size, -1).size(1)

    # The distance from x to (x + noise) is the norm of the noise, so the
    # perturbed points are never materialised.
    noise = torch.randn(batch_size, num_samples, n, device=x.device) * noise_std
    distances = torch.norm(noise, dim=2)  # [B, num_samples]

    # k smallest distances, sorted ascending; the last column is r_k.
    knn_distances, _ = distances.topk(k, dim=1, largest=False)

    # log(r_k / r_i) >= 0 for the k-1 closer neighbours.
    log_ratios = (
        torch.log(knn_distances[:, -1:] + 1e-10)
        - torch.log(knn_distances[:, :k - 1] + 1e-10)
    )
    # -1 / mean(log(r_i / r_k)) == 1 / mean(log(r_k / r_i)), which is positive.
    return 1.0 / (log_ratios.mean(dim=1) + 1e-10)
class LIDDetector:
    """Flag inputs whose LID estimate is unusually high compared to clean data.

    ``fit`` computes LID values on clean batches and sets the threshold to
    ``mean + 3 * std`` of those values. ``detect`` compares new LID values
    against that threshold.
    """

    def __init__(self, model, k=20):
        self.model = model
        self.k = k
        # Both are populated by fit().
        self.clean_lids = None
        self.lid_threshold = None

    def fit(self, dataloader, num_batches=100):
        """Fit the clean LID distribution from up to ``num_batches`` batches.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        lids = []
        for i, (x, _) in enumerate(dataloader):
            if i >= num_batches:
                break
            lids.append(compute_lid(self.model, x, self.k))

        if not lids:
            raise ValueError("fit() received no batches")

        self.clean_lids = torch.cat(lids, dim=0)
        self.lid_threshold = self.clean_lids.mean() + 3 * self.clean_lids.std()

    def detect(self, x):
        """Return ``(is_adversarial, lids)`` for the batch ``x``.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        # Fail before doing any work when the threshold is not available.
        if self.lid_threshold is None:
            raise RuntimeError("Must call fit() first")
        lids = compute_lid(self.model, x, self.k)
        return lids > self.lid_threshold, lids
基于重构的检测
降噪自动编码器检测
class DenoisingAutoencoderDetector:
    """Detect adversarial inputs by their autoencoder reconstruction error.

    An input is flagged when its per-sample reconstruction MSE exceeds a
    threshold. The threshold is either passed to ``detect`` or calibrated on
    clean data with ``fit``.
    """

    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder
        # Default threshold used by detect(); set by fit() or assigned directly.
        self.baseline_threshold = None

    def reconstruct(self, x):
        """Encode and then decode ``x``."""
        z = self.encoder(x)
        return self.decoder(z)

    def compute_reconstruction_error(self, x):
        """Return a dict of error measures between ``x`` and its reconstruction.

        ``'mse'`` and ``'mae'`` are per-sample tensors reduced over dims
        ``(1, 2, 3)``. ``'perceptual'`` is whatever ``perceptual_distance``
        returns.
        """
        x_recon = self.reconstruct(x)
        residual = x - x_recon
        return {
            'mse': (residual ** 2).mean(dim=(1, 2, 3)),
            'mae': residual.abs().mean(dim=(1, 2, 3)),
            'perceptual': self.perceptual_distance(x, x_recon),
        }

    def perceptual_distance(self, x1, x2):
        """Simplified stand-in for a perceptual distance.

        NOTE(review): this is plain pixel-space MSE averaged over the whole
        batch (a scalar), not a feature-space distance and not per-sample.
        """
        return ((x1 - x2) ** 2).mean()

    def fit(self, dataloader, num_std=3.0):
        """Calibrate ``baseline_threshold`` as ``mean + num_std * std`` of clean MSE.

        Args:
            dataloader: Iterable of ``(x, label)`` batches of clean samples.
            num_std: Number of standard deviations above the mean.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        errors = []
        with torch.no_grad():
            for x, _ in dataloader:
                errors.append(self.compute_reconstruction_error(x)['mse'])
        if not errors:
            raise ValueError("fit() received no batches")

        errors = torch.cat(errors, dim=0)
        # std() of a single value is NaN, so fall back to zero spread.
        spread = errors.std() if errors.numel() > 1 else errors.new_zeros(())
        self.baseline_threshold = errors.mean() + num_std * spread

    def detect(self, x, threshold=None):
        """Return ``(is_adversarial, errors)`` for the batch ``x``.

        Args:
            x: Input batch.
            threshold: MSE threshold. Defaults to ``baseline_threshold``.

        Raises:
            RuntimeError: if no threshold is given and none was calibrated.
        """
        if threshold is None:
            if self.baseline_threshold is None:
                raise RuntimeError(
                    "No threshold available: pass threshold= or call fit() first"
                )
            threshold = self.baseline_threshold

        errors = self.compute_reconstruction_error(x)
        is_adv = errors['mse'] > threshold
        return is_adv, errors
对抗净化(Adversarial Purification)
class AdversarialPurification:
    """Predict robustly by classifying perturbed or reconstructed copies of the input.

    Every ``purify_*`` method returns class predictions (``argmax`` over
    dim 1 of the model output), not a purified input tensor.
    """

    def __init__(self, model):
        self.model = model

    @staticmethod
    def _gaussian_blur(x, kernel_size=5, sigma=1.0):
        """Apply a separable, per-channel Gaussian blur to an NCHW tensor."""
        # torch.nn.functional has no gaussian_blur, so build it from conv2d.
        half = (kernel_size - 1) / 2
        coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device) - half
        kernel = torch.exp(-0.5 * (coords / sigma) ** 2)
        kernel = kernel / kernel.sum()

        channels = x.shape[1]
        pad = kernel_size // 2
        # Replicate padding keeps the spatial size and works for small inputs.
        padded = F.pad(x, (pad, pad, pad, pad), mode='replicate')
        horizontal = kernel.view(1, 1, 1, -1).repeat(channels, 1, 1, 1)
        vertical = kernel.view(1, 1, -1, 1).repeat(channels, 1, 1, 1)
        blurred = F.conv2d(padded, horizontal, groups=channels)
        return F.conv2d(blurred, vertical, groups=channels)

    def purify_smoothing(self, x, num_samples=100, sigma=0.1):
        """Majority vote over predictions on Gaussian-noised copies of ``x``.

        Args:
            x: Input batch. Noisy copies are clamped to ``[0, 1]``.
            num_samples: Number of noisy copies, at least 1.
            sigma: Standard deviation of the added noise.

        Raises:
            ValueError: if ``num_samples`` is less than 1.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be at least 1")

        votes = []
        for _ in range(num_samples):
            x_noisy = torch.clamp(x + torch.randn_like(x) * sigma, 0, 1)
            with torch.no_grad():
                votes.append(self.model(x_noisy).argmax(dim=1))

        # Per-sample mode across the noisy copies.
        return torch.stack(votes).mode(0)[0]

    def purify_ensemble(self, x):
        """Majority vote over predictions on a fixed set of input transforms."""
        transforms = [
            lambda t: t,
            lambda t: self._gaussian_blur(t, kernel_size=5, sigma=1.0),
            lambda t: (t + torch.randn_like(t) * 0.05).clamp(0, 1),
            lambda t: F.interpolate(t, scale_factor=0.9),
            # Gamma correction; clamp first because a negative base with a
            # fractional exponent yields NaN.
            lambda t: t.clamp(min=0) ** 1.2,
        ]

        votes = []
        for transform in transforms:
            x_transformed = transform(x)
            with torch.no_grad():
                votes.append(self.model(x_transformed).argmax(dim=1))

        return torch.stack(votes).mode(0)[0]

    def purify_generative(self, x, generator):
        """Classify the reconstruction ``generator.decode(generator.encode(x))``."""
        with torch.no_grad():
            z = generator.encode(x)
            x_purified = generator.decode(z)
            return self.model(x_purified).argmax(dim=1)
输入变换防御
组合变换防御
class InputTransformationDefense:
    """Ensemble a model's outputs over a set of input transformations."""

    def __init__(self, model):
        self.model = model
        self.transforms = self._get_transforms()

    def _get_transforms(self):
        """Return the list of ``(name, callable)`` transforms.

        Every transform preserves the spatial size of its input.
        """
        def resize(x):
            # Downscale then restore the exact original size. Scaling back by
            # 1/0.9 can be off by a pixel because of rounding.
            h, w = x.shape[-2:]
            return F.interpolate(F.interpolate(x, scale_factor=0.9), size=(h, w))

        def crop(x):
            # Down/up-sample at a 200/224 ratio, scaled to the input size.
            # For 224x224 inputs this gives the 200 -> 224 round trip.
            h, w = x.shape[-2:]
            small = (max(1, round(h * 200 / 224)), max(1, round(w * 200 / 224)))
            return F.interpolate(F.interpolate(x, size=small), size=(h, w))

        return [
            ('original', lambda x: x),
            ('jitter', lambda x: x + torch.randn_like(x) * 0.01),
            ('blur', lambda x: F.avg_pool2d(
                F.pad(x, (2, 2, 2, 2), mode='replicate'), 5, stride=1)),
            ('resize', resize),
            ('flip', lambda x: torch.flip(x, dims=[3])),
            ('crop', crop),
        ]

    def predict(self, x, ensemble='all'):
        """Average the model output over the selected transforms.

        Args:
            x: Input batch.
            ensemble: ``'all'``, a single transform name, or an iterable of
                names such as ``('jitter', 'blur')``.

        Returns:
            ``(predicted_class, averaged_output)``.

        Raises:
            ValueError: if no known transform is selected.
        """
        if isinstance(ensemble, str):
            # A bare name must not be split into characters by iteration.
            if ensemble == 'all':
                selected = {name for name, _ in self.transforms}
            else:
                selected = {ensemble}
        else:
            selected = set(ensemble)

        outputs = []
        for name, transform in self.transforms:
            if name not in selected:
                continue
            with torch.no_grad():
                outputs.append(self.model(transform(x)))

        if not outputs:
            raise ValueError(f"No known transforms selected by ensemble={ensemble!r}")

        avg_pred = torch.stack(outputs).mean(dim=0)
        return avg_pred.argmax(dim=1), avg_pred

    def certify(self, x, epsilon, num_transforms=100):
        """Check empirically that predictions are stable under random transforms.

        Each trial applies one randomly chosen transform plus Gaussian noise
        of scale ``epsilon``, clamps to ``[0, 1]``, and compares the predicted
        classes with those for the unmodified ``x``. This is a sampling-based
        consistency check, not a formal robustness certificate.

        Returns:
            ``(certified, agreement_rate)``. ``certified`` is True only when
            every trial agreed.

        Raises:
            ValueError: if ``num_transforms`` is less than 1.
        """
        if num_transforms < 1:
            raise ValueError("num_transforms must be at least 1")

        # The reference prediction does not change between trials.
        with torch.no_grad():
            reference = self.model(x).argmax(dim=1)

        correct_count = 0
        for _ in range(num_transforms):
            transform = random.choice(self.transforms)[1]
            delta = torch.randn_like(x) * epsilon
            x_transformed = (transform(x) + delta).clamp(0, 1)
            with torch.no_grad():
                pred = self.model(x_transformed).argmax(dim=1)
            if torch.equal(pred, reference):
                correct_count += 1

        certified = correct_count == num_transforms
        return certified, correct_count / num_transforms
JPEG 压缩防御
def jpeg_compress(x, quality=75):
    """Crude stand-in for JPEG compression.

    Approximates compression loss with 2x2 average pooling followed by
    nearest-neighbour upsampling back to the input's spatial size. A real
    implementation would use an actual JPEG codec.

    Args:
        x: NCHW tensor.
        quality: Accepted for interface compatibility. This approximation
            ignores it.

    Returns:
        Tensor with the same shape as ``x``.
    """
    height, width = x.shape[-2:]
    pooled = F.avg_pool2d(x, kernel_size=2, stride=2)
    # Upsample to the original size rather than by a factor of 2, so
    # odd-sized inputs keep their shape.
    return F.interpolate(pooled, size=(height, width), mode='nearest')
class JpegDefense:
    """Average a model's outputs over several JPEG quality levels."""

    def __init__(self, model, qualities=None):
        """
        Args:
            model: Callable returning class scores for a batch.
            qualities: Iterable of JPEG quality levels. Defaults to
                ``[50, 70, 90]``.
        """
        self.model = model
        # Copy into a fresh list so instances never share or alias one list.
        self.qualities = [50, 70, 90] if qualities is None else list(qualities)

    def predict(self, x):
        """Return the class with the highest output averaged over all qualities.

        Raises:
            ValueError: if ``qualities`` is empty.
        """
        if not self.qualities:
            raise ValueError("qualities must contain at least one quality level")

        outputs = []
        for quality in self.qualities:
            x_compressed = self.jpeg_transform(x, quality)
            with torch.no_grad():
                outputs.append(self.model(x_compressed))

        avg_pred = torch.stack(outputs).mean(dim=0)
        return avg_pred.argmax(dim=1)

    def jpeg_transform(self, x, quality):
        """Placeholder JPEG transform that returns ``x`` unchanged.

        NOTE(review): until a real codec is wired in, every quality level
        produces the same prediction and the ensemble adds nothing.
        """
        return x
综合防御策略
多层防御架构
class MultiLayerDefense:
    """Combine detection, purification and input-transformation ensembling.

    The detector must be fitted through ``train`` before ``predict`` is
    called.
    """

    def __init__(self, model):
        self.model = model
        self.detector = ActivationStatisticsDetector(model)
        self.purifier = AdversarialPurification(model)
        self.transform_defense = InputTransformationDefense(model)

    def predict(self, x):
        """Run the full defence pipeline on ``x``.

        1. Score the input with the detector.
        2. Compute the transformation-ensemble prediction and outputs.
        3. If the input is flagged, replace the prediction with the
           purifier's majority vote.

        The purifier returns class predictions, not a cleaned input, so its
        result is used as the prediction and is never fed back into the model.

        Returns:
            Dict with ``'prediction'``, ``'is_adversarial'``,
            ``'anomaly_score'`` and ``'confidence'``. Confidence is the max
            softmax probability of the ensemble output: a float for a single
            sample, a ``[B]`` tensor for a larger batch.
        """
        is_adv, score = self.detector.detect(x)

        pred, logits = self.transform_defense.predict(x)

        # Accept a Python bool or a per-sample tensor of flags.
        if bool(torch.as_tensor(is_adv).any()):
            pred = self.purifier.purify_ensemble(x)

        confidence = F.softmax(logits, dim=1).max(dim=1).values
        if confidence.numel() == 1:
            confidence = confidence.item()

        return {
            'prediction': pred,
            'is_adversarial': is_adv,
            'anomaly_score': score,
            'confidence': confidence,
        }

    def train(self, clean_loader, adv_loader):
        """Fit the detector on clean samples.

        ``adv_loader`` is currently unused. It is reserved for training an
        optional classifier-based detector.
        """
        self.detector.fit(clean_loader)
检测方法对比
| 方法 | 优点 | 缺点 | 攻击绕过难度 |
|---|---|---|---|
| 激活统计 | 简单有效 | 需要干净样本 | 中 |
| LID | 理论扎实 | 计算开销大 | 高 |
| 重构误差 | 直观 | 需要额外模型 | 中 |
| MC Dropout | 不确定性量化 | 需要多次前向 | 低 |
| 集成变换 | 鲁棒 | 推理开销大 | 高 |
实践建议
检测器训练流程
def train_detection_pipeline(model, clean_loader, adv_loader,
                             detector_type='statistical'):
    """Build and fit a detector of the requested type.

    Args:
        model: The model being defended.
        clean_loader: Iterable of ``(x, label)`` batches of clean samples.
        adv_loader: Iterable of ``(x, label)`` batches that are attacked with
            ``pgd_attack`` to produce adversarial samples. Only the
            ``'classifier'`` type uses it.
        detector_type: ``'statistical'``, ``'lid'`` or ``'classifier'``.

    Returns:
        The fitted detector.

    Raises:
        ValueError: if ``detector_type`` is not recognised.
    """
    if detector_type == 'statistical':
        detector = ActivationStatisticsDetector(model)
        detector.fit(clean_loader)
        return detector

    if detector_type == 'lid':
        detector = LIDDetector(model)
        detector.fit(clean_loader)
        return detector

    if detector_type == 'classifier':
        # NOTE(review): BinaryClassifierDetector and pgd_attack are not
        # defined in this section; they are assumed to exist elsewhere.
        detector = BinaryClassifierDetector(model)

        clean_features = torch.cat(
            [detector.extract_features(x) for x, _ in clean_loader]
        )

        adv_features = []
        # The attack needs the true labels, so they must not be discarded.
        for x, y in adv_loader:
            x_adv = pgd_attack(model, x, y)
            adv_features.append(detector.extract_features(x_adv))
        adv_features = torch.cat(adv_features)

        # Binary targets: 0 for clean, 1 for adversarial.
        features = torch.cat([clean_features, adv_features])
        targets = torch.cat([torch.zeros(len(clean_features)),
                             torch.ones(len(adv_features))])
        detector.train(features, targets)
        return detector

    raise ValueError(f"Unknown detector_type: {detector_type!r}")
def evaluate_detection(detector, test_clean, test_adv):
    """Measure a detector's false-positive and true-positive rates.

    Args:
        detector: Object whose ``detect(x)`` returns ``(is_adversarial, score)``.
        test_clean: Iterable of clean samples. Each item is either a sample
            or an ``(x, label)`` pair, in which case only ``x`` is used.
        test_adv: Iterable of adversarial samples in the same format.

    Returns:
        Dict with ``'fpr'``, ``'tpr'``, ``'detection_rate'`` (same as tpr)
        and ``'false_alarm_rate'`` (same as fpr). A rate is 0.0 when its
        sample set is empty.
    """
    def flagged_rate(samples):
        # Count manually so iterables without len() also work.
        total = 0
        flagged = 0
        for item in samples:
            x = item[0] if isinstance(item, (tuple, list)) else item
            is_adv, _ = detector.detect(x)
            total += 1
            if is_adv:
                flagged += 1
        return flagged / total if total else 0.0

    fpr = flagged_rate(test_clean)
    tpr = flagged_rate(test_adv)
    return {
        'fpr': fpr,
        'tpr': tpr,
        'detection_rate': tpr,
        'false_alarm_rate': fpr,
    }
相关主题
参考文献