对抗样本检测与防御

概述

对抗样本检测(Detection)是对抗防御的重要分支,旨在识别并拒绝潜在的对抗输入,而无需修改模型本身。与对抗训练(通过重新训练提升模型自身鲁棒性)不同,检测方法不要求模型正确分类对抗输入,而是在推理前把可疑输入拦截下来。由于攻击者可能绕过任何单一的特定防御,检测器通常采用“宁可错杀,不可放过”的保守策略:以更高的误报率(把干净样本判为对抗样本)换取更低的漏报率。

检测方法分类

| 类别 | 方法 | 原理 |
|---|---|---|
| 统计检测 | 激活统计 | 对抗样本激活模式异常 |
| 重构检测 | 输入净化 | 对抗样本难以重构 |
| 分类检测 | 二分类器 | 训练专门检测器 |
| 不确定性检测 | 置信度 | 对抗样本不确定性高 |

基于统计的检测

激活统计检测

class ActivationStatisticsDetector:
    """
    Adversarial-example detector based on activation statistics.

    A forward hook captures the output of one intermediate layer. For every
    input sample six scalar statistics of that activation are computed
    (mean, std, max, min, sparsity, excess kurtosis). ``fit`` estimates the
    mean vector and covariance of these statistics on clean data; ``detect``
    flags inputs whose Mahalanobis distance from that clean distribution
    exceeds a threshold.
    """

    def __init__(self, model, layer_name='layer4'):
        """
        Args:
            model: module whose ``named_modules()`` are searched for
                ``layer_name``; it is called as ``model(x)``.
            layer_name: name of the module whose output is monitored. An exact
                name match is preferred; if there is none, every module whose
                name contains ``layer_name`` is hooked.

        Raises:
            ValueError: if no module name matches ``layer_name``.
        """
        self.model = model
        self.layer_name = layer_name
        # Last activation captured by the hook; None until a forward pass fires it.
        self.activations = None
        self._hook_handles = []
        self.register_hooks()

        # Clean-data statistics, populated by fit().
        self.clean_stats = None
        self.clean_mean = None
        self.clean_cov = None
        self.fitted = False

    def register_hooks(self):
        """Register a forward hook that stores the monitored layer's output."""
        def hook_fn(module, inputs, output):
            self.activations = output.detach()

        named = list(self.model.named_modules())
        # Exact match first: substring matching would also hook every child
        # ('layer4.0', 'layer4.0.conv1', ...), whose hooks overwrite each other.
        targets = [m for name, m in named if name == self.layer_name]
        if not targets:
            targets = [m for name, m in named if self.layer_name in name]
        if not targets:
            raise ValueError(f"no module matching {self.layer_name!r} found in model")
        for module in targets:
            self._hook_handles.append(module.register_forward_hook(hook_fn))

    def compute_statistics(self, x):
        """
        Run ``x`` through the model and summarise the hooked activation.

        Returns:
            Tensor of shape ``[batch, 6]`` with columns
            (mean, std, max, min, sparsity, kurtosis), each computed per sample
            over all non-batch dimensions of the activation.

        Raises:
            RuntimeError: if the hooked layer did not run during the forward pass.
        """
        self.activations = None
        with torch.no_grad():
            self.model(x)
        if self.activations is None:
            raise RuntimeError(f"layer {self.layer_name!r} produced no activation")

        flat = self.activations.flatten(1)
        stats = [
            flat.mean(dim=1),
            flat.std(dim=1),
            flat.amax(dim=1),
            flat.amin(dim=1),
            # Fraction of near-zero units.
            (flat.abs() < 0.01).float().mean(dim=1),
            self.compute_kurtosis(self.activations),
        ]
        return torch.stack(stats, dim=1)

    def compute_kurtosis(self, x):
        """Per-sample excess kurtosis (fourth standardized moment minus 3)."""
        flat = x.flatten(1)
        mean = flat.mean(dim=1, keepdim=True)
        std = flat.std(dim=1, keepdim=True)
        z = (flat - mean) / (std + 1e-10)
        return (z ** 4).mean(dim=1) - 3

    def fit(self, dataloader, num_samples=1000):
        """
        Estimate the clean-data mean and covariance of the activation statistics.

        Args:
            dataloader: iterable of ``(x, label)`` batches of clean inputs.
            num_samples: number of clean samples to use. The last batch is
                truncated so exactly ``min(num_samples, available)`` are used.

        Raises:
            ValueError: if fewer than two samples are collected (a covariance
                cannot be estimated from fewer).
        """
        collected = []
        count = 0
        for x, _ in dataloader:
            if count >= num_samples:
                break
            batch_stats = self.compute_statistics(x)[: num_samples - count]
            collected.append(batch_stats)
            count += batch_stats.size(0)

        if count < 2:
            raise ValueError("fit() needs at least 2 clean samples")

        all_stats = torch.cat(collected, dim=0)
        self.clean_stats = all_stats
        self.clean_mean = all_stats.mean(dim=0)
        self.clean_cov = torch.cov(all_stats.T)
        self.fitted = True

    def detect(self, x, threshold=3.0):
        """
        Flag adversarial inputs by Mahalanobis distance to the clean statistics.

        Args:
            x: input batch.
            threshold: distance above which a sample is flagged.

        Returns:
            ``(is_adversarial, anomaly_score)``. For a single-sample batch these
            are a Python ``bool`` and ``float``; for larger batches they are a
            bool tensor and a float tensor of shape ``[batch]``.

        Raises:
            RuntimeError: if ``fit()`` has not been called.
        """
        if not self.fitted:
            raise RuntimeError("Must call fit() first")

        stats = self.compute_statistics(x)
        mean = self.clean_mean.to(stats.device)
        cov = self.clean_cov.to(stats.device)
        # Small ridge keeps the covariance invertible when a statistic is ~constant.
        ridge = 1e-6 * torch.eye(cov.size(0), dtype=cov.dtype, device=cov.device)
        cov_inv = torch.linalg.inv(cov + ridge)

        diff = stats - mean.unsqueeze(0)
        # Per-sample quadratic form d_i^T S^-1 d_i without forming a B x B matrix.
        squared = ((diff @ cov_inv) * diff).sum(dim=1)
        scores = squared.clamp_min(0).sqrt()  # clamp guards tiny negative round-off
        flags = scores > threshold

        if scores.numel() == 1:
            return bool(flags.item()), scores.item()
        return flags, scores

LID(Local Intrinsic Dimensionality)检测

def compute_lid(model, x, k=20, reference=None):
    """
    Estimate Local Intrinsic Dimensionality (LID) per sample.

    Uses the maximum-likelihood estimator of Ma et al. (2018):

        LID(x) = -( (1/k) * sum_{i=1..k} log(r_i(x) / r_k(x)) )^{-1}

    where r_i(x) is the distance from x to its i-th nearest neighbour in a
    reference set. LID measures how fast the neighbourhood of a sample grows;
    adversarial examples are reported to have abnormal (typically higher) LID.

    Args:
        model: callable mapping a batch to a representation (e.g. logits or a
            feature extractor); LID is computed on ``model(x)`` flattened per
            sample. Pass ``None`` to compute it on the raw inputs.
        x: batch of inputs; the first dimension is the batch.
        k: number of nearest neighbours, clipped to the number of available
            reference points.
        reference: optional batch of (typically clean) reference inputs. When
            omitted, the neighbours of each sample are the other samples in
            ``x`` (a sample is never its own neighbour).

    Returns:
        1-D tensor of LID estimates, one per sample in ``x``.

    Raises:
        ValueError: if no reference neighbour is available (e.g. a batch of
            one sample without ``reference``).
    """
    def embed(batch):
        with torch.no_grad():
            out = batch if model is None else model(batch)
        out = out.flatten(1)
        return out if out.is_floating_point() else out.float()

    feats = embed(x)
    if reference is None:
        dists = torch.cdist(feats, feats)
        dists.fill_diagonal_(float('inf'))  # exclude each sample itself
        available = feats.size(0) - 1
    else:
        ref_feats = embed(reference).to(device=feats.device, dtype=feats.dtype)
        dists = torch.cdist(feats, ref_feats)
        available = ref_feats.size(0)

    k = min(k, available)
    if k < 1:
        raise ValueError("compute_lid needs at least one reference neighbour")

    knn, _ = dists.topk(k, dim=1, largest=False)  # ascending r_1..r_k
    knn = knn.clamp_min(1e-12)  # duplicate points would otherwise give log(0)
    log_ratios = torch.log(knn / knn[:, -1:])  # every term is <= 0
    # Clamp keeps the mean strictly negative so the LID stays finite and positive.
    mean_log = log_ratios.mean(dim=1).clamp(max=-1e-12)
    return -1.0 / mean_log
 
 
class LIDDetector:
    """
    LID-based adversarial-example detector.

    ``fit`` collects the per-sample LID values returned by ``compute_lid`` on
    clean batches and sets the threshold to ``mean + num_std * std`` of those
    values. ``detect`` flags samples whose LID exceeds that threshold.
    """

    def __init__(self, model, k=20):
        self.model = model
        self.k = k
        self.clean_lids = None
        # Set by fit(); detect() refuses to run while it is None.
        self.lid_threshold = None

    def fit(self, dataloader, num_batches=100, num_std=3.0):
        """
        Fit the LID distribution of clean samples.

        Args:
            dataloader: iterable of ``(x, label)`` clean batches.
            num_batches: maximum number of batches to consume.
            num_std: number of standard deviations above the clean mean used
                as the detection threshold.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        lids = []
        for i, (x, _) in enumerate(dataloader):
            if i >= num_batches:
                break
            lids.append(compute_lid(self.model, x, self.k))

        if not lids:
            raise ValueError("dataloader yielded no batches")

        self.clean_lids = torch.cat(lids, dim=0)
        # std() of a single value is NaN; treat the spread as zero instead.
        if self.clean_lids.numel() > 1:
            spread = self.clean_lids.std()
        else:
            spread = self.clean_lids.new_zeros(())
        self.lid_threshold = self.clean_lids.mean() + num_std * spread

    def detect(self, x):
        """
        Detect adversarial samples in ``x``.

        Returns:
            ``(flags, lids)``: boolean tensor of detections and the LID values.

        Raises:
            RuntimeError: if ``fit()`` has not been called.
        """
        if self.lid_threshold is None:
            raise RuntimeError("Must call fit() first")
        lids = compute_lid(self.model, x, self.k)
        return lids > self.lid_threshold, lids

基于重构的检测

降噪自动编码器检测

class DenoisingAutoencoderDetector:
    """
    Detect adversarial inputs by their reconstruction error under a
    (denoising) autoencoder: inputs whose per-sample MSE reconstruction error
    exceeds a threshold are flagged.
    """

    def __init__(self, encoder, decoder, baseline_threshold=None):
        """
        Args:
            encoder: callable mapping inputs to latent codes.
            decoder: callable mapping latent codes back to input space.
            baseline_threshold: default MSE threshold for ``detect``; normally
                set from clean data with ``fit``.
        """
        self.encoder = encoder
        self.decoder = decoder
        self.baseline_threshold = baseline_threshold

    def reconstruct(self, x):
        """Return ``decoder(encoder(x))``."""
        z = self.encoder(x)
        return self.decoder(z)

    def compute_reconstruction_error(self, x):
        """
        Per-sample reconstruction errors.

        Returns:
            dict with keys ``'mse'``, ``'mae'`` and ``'perceptual'``; each value
            has one entry per sample (reduced over all non-batch dimensions).
        """
        x_recon = self.reconstruct(x)
        residual = (x - x_recon).flatten(1)
        return {
            'mse': (residual ** 2).mean(dim=1),
            'mae': residual.abs().mean(dim=1),
            'perceptual': self.perceptual_distance(x, x_recon),
        }

    def perceptual_distance(self, x1, x2):
        """
        Per-sample perceptual distance (simplified).

        NOTE: a true perceptual distance would compare features of a pretrained
        network such as VGG; this placeholder uses pixel-space MSE, reduced per
        sample so it lines up with the other error terms.
        """
        return ((x1 - x2) ** 2).flatten(1).mean(dim=1)

    def fit(self, dataloader, num_std=3.0):
        """
        Set ``baseline_threshold`` from clean data.

        The threshold is ``mean + num_std * std`` of the clean MSE errors.

        Args:
            dataloader: iterable of ``(x, label)`` clean batches.
            num_std: number of standard deviations above the clean mean.

        Returns:
            The new threshold as a float.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        errors = []
        with torch.no_grad():
            for x, _ in dataloader:
                errors.append(self.compute_reconstruction_error(x)['mse'])
        if not errors:
            raise ValueError("dataloader yielded no batches")

        errors = torch.cat(errors)
        spread = errors.std() if errors.numel() > 1 else errors.new_zeros(())
        self.baseline_threshold = (errors.mean() + num_std * spread).item()
        return self.baseline_threshold

    def detect(self, x, threshold=None):
        """
        Flag inputs whose MSE reconstruction error exceeds ``threshold``.

        Args:
            x: input batch.
            threshold: MSE threshold; defaults to ``baseline_threshold``.

        Returns:
            ``(is_adv, errors)``: boolean tensor per sample and the error dict.

        Raises:
            RuntimeError: if no threshold is given and none has been fitted.
        """
        if threshold is None:
            if self.baseline_threshold is None:
                raise RuntimeError("Pass a threshold or call fit() first")
            threshold = self.baseline_threshold

        errors = self.compute_reconstruction_error(x)
        is_adv = errors['mse'] > threshold
        return is_adv, errors

对抗净化(Adversarial Purification)

class AdversarialPurification:
    """
    Prediction-time purification defenses.

    Every ``purify_*`` method returns the class labels (argmax predictions of
    ``self.model``) after purification, not purified images.
    """

    def __init__(self, model):
        self.model = model

    @staticmethod
    def _gaussian_blur(x, kernel_size=5, sigma=1.0):
        """
        Depthwise Gaussian blur of an NCHW tensor.

        ``torch.nn.functional`` has no ``gaussian_blur``, so the kernel is
        built here. Borders use replicate padding. With an odd
        ``kernel_size`` the output has the same shape as ``x``.
        """
        coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device)
        coords = coords - (kernel_size - 1) / 2
        g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
        g = g / g.sum()
        channels = x.size(1)
        kernel = (g[:, None] * g[None, :]).reshape(1, 1, kernel_size, kernel_size)
        kernel = kernel.repeat(channels, 1, 1, 1)
        pad = kernel_size // 2
        padded = F.pad(x, (pad, pad, pad, pad), mode='replicate')
        return F.conv2d(padded, kernel, groups=channels)

    def purify_smoothing(self, x, num_samples=100, sigma=0.1):
        """
        Randomized-smoothing purification.

        Classifies ``num_samples`` Gaussian-noised copies of ``x`` (std
        ``sigma``, clamped to [0, 1]) and returns the per-sample majority vote.

        Raises:
            ValueError: if ``num_samples`` < 1.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be >= 1")
        votes = []
        with torch.no_grad():
            for _ in range(num_samples):
                noisy = (x + torch.randn_like(x) * sigma).clamp(0, 1)
                votes.append(self.model(noisy).argmax(dim=1))
        return torch.stack(votes).mode(dim=0).values

    def purify_ensemble(self, x):
        """
        Transformation-ensemble purification.

        Classifies ``x`` under several size-preserving transforms and returns
        the per-sample majority vote.
        """
        size = tuple(x.shape[-2:])
        transforms = [
            lambda t: t,
            lambda t: self._gaussian_blur(t, kernel_size=5, sigma=1.0),
            lambda t: (t + torch.randn_like(t) * 0.05).clamp(0, 1),
            # Down-scale, then restore the original resolution so the model
            # always receives its expected input size.
            lambda t: F.interpolate(F.interpolate(t, scale_factor=0.9), size=size),
            # Gamma correction; clamp first so fractional powers of negative
            # values cannot produce NaN.
            lambda t: t.clamp(0, 1) ** 1.2,
        ]
        with torch.no_grad():
            votes = [self.model(transform(x)).argmax(dim=1) for transform in transforms]
        return torch.stack(votes).mode(dim=0).values

    def purify_generative(self, x, generator):
        """
        Generative purification.

        Passes ``x`` through ``generator.decode(generator.encode(x))`` and
        classifies the reconstruction.
        """
        with torch.no_grad():
            x_purified = generator.decode(generator.encode(x))
            return self.model(x_purified).argmax(dim=1)

输入变换防御

组合变换防御

class InputTransformationDefense:
    """
    Input-transformation defense.

    Averages the model's outputs over a set of shape-preserving input
    transformations.
    """

    def __init__(self, model):
        self.model = model
        self.transforms = self._get_transforms()

    @staticmethod
    def _rescale(x, scale=0.9):
        """Down-scale by ``scale`` and resize back to the exact input size."""
        # A scale_factor round-trip is not size-preserving (224 -> 201 -> 223).
        size = tuple(x.shape[-2:])
        return F.interpolate(F.interpolate(x, scale_factor=scale), size=size)

    @staticmethod
    def _center_crop(x, ratio=200 / 224):
        """Crop the central ``ratio`` of each side and resize back to the input size."""
        h, w = x.shape[-2:]
        ch, cw = max(1, round(h * ratio)), max(1, round(w * ratio))
        top, left = (h - ch) // 2, (w - cw) // 2
        return F.interpolate(x[..., top:top + ch, left:left + cw], size=(h, w))

    def _get_transforms(self):
        """Return the ``(name, transform)`` pairs; every transform preserves shape."""
        return [
            ('original', lambda x: x),
            ('jitter', lambda x: x + torch.randn_like(x) * 0.01),
            ('blur', lambda x: F.avg_pool2d(F.pad(x, (2, 2, 2, 2), mode='replicate'), 5, stride=1)),
            ('resize', self._rescale),
            ('flip', lambda x: torch.flip(x, dims=[3])),
            ('crop', self._center_crop),
        ]

    def predict(self, x, ensemble='all'):
        """
        Ensemble prediction.

        Args:
            x: input batch.
            ensemble: ``'all'``, a single transform name, or an iterable of
                names such as ``('jitter', 'blur')``.

        Returns:
            ``(labels, avg_output)``: argmax of the averaged model outputs and
            the averaged outputs themselves.

        Raises:
            ValueError: if ``ensemble`` selects no known transform.
        """
        if ensemble == 'all':
            selected = {name for name, _ in self.transforms}
        elif isinstance(ensemble, str):
            selected = {ensemble}  # avoid splitting a single name into characters
        else:
            selected = set(ensemble)

        outputs = []
        with torch.no_grad():
            for name, transform in self.transforms:
                if name in selected:
                    outputs.append(self.model(transform(x)))
        if not outputs:
            raise ValueError(f"no known transform selected by ensemble={ensemble!r}")

        avg_pred = torch.stack(outputs).mean(dim=0)
        return avg_pred.argmax(dim=1), avg_pred

    def certify(self, x, epsilon, num_transforms=100):
        """
        Empirical prediction-consistency check.

        This is NOT a formal robustness certificate. ``num_transforms`` times,
        a random transform is applied and Gaussian noise with std ``epsilon``
        is added (clamped to [0, 1]). The result is compared with the model's
        prediction on the unmodified ``x``.

        Returns:
            ``(certified, agreement)``. ``certified`` is True when every trial
            agreed; ``agreement`` is the fraction of agreeing trials. For a batch
            of one these are a Python ``bool`` and ``float``; otherwise they are
            per-sample tensors.

        Raises:
            ValueError: if ``num_transforms`` < 1.
        """
        if num_transforms < 1:
            raise ValueError("num_transforms must be >= 1")

        with torch.no_grad():
            # Reference prediction is loop-invariant; compute it once, per sample.
            reference = self.model(x).argmax(dim=1)
            agree = torch.zeros_like(reference)
            for _ in range(num_transforms):
                transform = random.choice(self.transforms)[1]
                x_transformed = transform(x)
                delta = torch.randn_like(x) * epsilon
                x_transformed = (x_transformed + delta).clamp(0, 1)
                agree += (self.model(x_transformed).argmax(dim=1) == reference).long()

        fraction = agree.float() / num_transforms
        certified = agree == num_transforms
        if reference.numel() == 1:
            return bool(certified.item()), fraction.item()
        return certified, fraction

JPEG 压缩防御

def jpeg_compress(x, quality=75):
    """
    Approximate JPEG-compression defense.

    This simplified version does not run a JPEG codec. It discards high spatial
    frequencies by 2x2 average pooling, then upsamples with nearest neighbour
    back to the input's exact spatial size. A faithful implementation would
    encode and decode with libjpeg or Pillow at ``quality``.

    Args:
        x: NCHW image batch.
        quality: JPEG quality; accepted for API compatibility but ignored by
            this approximation.

    Returns:
        Tensor with the same shape as ``x``. Inputs with a spatial side
        shorter than 2 cannot be pooled and are returned as a copy.
    """
    size = tuple(x.shape[-2:])
    if min(size) < 2:
        return x.clone()
    pooled = F.avg_pool2d(x, kernel_size=2, stride=2)
    # Resize to the explicit original size: a x2 scale_factor loses a row or
    # column for odd sizes.
    return F.interpolate(pooled, size=size, mode='nearest')
 
 
class JpegDefense:
    """
    Multi-quality JPEG-compression defense.

    Averages the model's outputs over the input compressed at several JPEG
    qualities.

    NOTE(review): ``jpeg_transform`` is still a placeholder that returns its
    input unchanged. Until a real codec is wired in, every quality yields the
    same output and the defense is effectively a no-op.
    """

    def __init__(self, model, qualities=(50, 70, 90)):
        """
        Args:
            model: classifier called as ``model(x)``.
            qualities: JPEG qualities to ensemble over.

        Raises:
            ValueError: if ``qualities`` is empty.
        """
        if not qualities:
            raise ValueError("at least one JPEG quality is required")
        self.model = model
        # Copy, so instances never share (or mutate) a caller's or the default sequence.
        self.qualities = list(qualities)

    def predict(self, x):
        """Return the argmax of the model outputs averaged over all qualities."""
        with torch.no_grad():
            outputs = [self.model(self.jpeg_transform(x, q)) for q in self.qualities]
        return torch.stack(outputs).mean(dim=0).argmax(dim=1)

    def jpeg_transform(self, x, quality):
        """JPEG transform at ``quality``; placeholder that returns ``x`` unchanged."""
        # TODO: encode/decode with libjpeg (or Pillow) at `quality`.
        return x

综合防御策略

多层防御架构

class MultiLayerDefense:
    """
    Layered defense architecture.

    The layers are: activation-statistics detection, then purification of
    inputs flagged as adversarial, then a transformation-ensemble prediction.
    """

    def __init__(self, model):
        self.model = model
        self.detector = ActivationStatisticsDetector(model)
        self.purifier = AdversarialPurification(model)
        self.transform_defense = InputTransformationDefense(model)
        # The detector must be fitted on clean data (see train()) before predict().

    def predict(self, x):
        """
        Combined prediction pipeline.

        1. Detect whether ``x`` is adversarial.
        2. Compute the transformation-ensemble prediction.
        3. For flagged samples, replace that prediction with the purifier's
           ensemble vote.

        Returns:
            dict with keys:
              'prediction': class labels.
              'is_adversarial': the detector's flag(s).
              'anomaly_score': the detector's score(s).
              'confidence': the transformation-ensemble softmax probability of
                  the returned label. This is a float for a single sample and a
                  tensor otherwise.
        """
        # Step 1: detection
        is_adv, score = self.detector.detect(x)

        # Step 2: input-transformation ensemble
        pred, logits = self.transform_defense.predict(x)

        # Step 3: purification. purify_ensemble returns class labels, not a
        # purified image, so its vote replaces the prediction of flagged
        # samples instead of being fed back into the model as an input.
        flagged = torch.as_tensor(is_adv, device=pred.device).reshape(-1)
        if bool(flagged.any()):
            pred = torch.where(flagged, self.purifier.purify_ensemble(x), pred)

        probs = F.softmax(logits, dim=1)
        confidence = probs.gather(1, pred.unsqueeze(1)).squeeze(1)

        return {
            'prediction': pred,
            'is_adversarial': is_adv,
            'anomaly_score': score,
            'confidence': confidence.item() if confidence.numel() == 1 else confidence,
        }

    def train(self, clean_loader, adv_loader):
        """
        Fit the detector on clean samples.

        ``adv_loader`` is currently unused; it is reserved for training an
        additional classification detector.
        """
        self.detector.fit(clean_loader)

检测方法对比

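| 方法 | 优点 | 缺点 | 攻击绕过难度 |
|---|---|---|---|
| 激活统计 | 简单有效 | 需要干净样本 | — |
| LID | 理论扎实 | 计算开销大 | — |
| 重构误差 | 直观 | 需要额外模型 | — |
| MC Dropout | 不确定性量化 | 需要多次前向 | — |
| 集成变换 | 鲁棒 | 推理开销大 | — |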

实践建议

检测器训练流程

def train_detection_pipeline(model, clean_loader, adv_loader,
                             detector_type='statistical'):
    """
    End-to-end detector training pipeline.

    Args:
        model: the protected classifier.
        clean_loader: iterable of ``(x, y)`` clean batches.
        adv_loader: iterable of ``(x, y)`` batches that are attacked with
            ``pgd_attack``. Only used by the ``'classifier'`` type.
        detector_type: ``'statistical'``, ``'lid'`` or ``'classifier'``.

    Returns:
        The fitted detector.

    Raises:
        ValueError: for an unknown ``detector_type``.
    """
    if detector_type == 'statistical':
        detector = ActivationStatisticsDetector(model)
        detector.fit(clean_loader)
        return detector

    if detector_type == 'lid':
        detector = LIDDetector(model)
        detector.fit(clean_loader)
        return detector

    if detector_type == 'classifier':
        # NOTE(review): BinaryClassifierDetector and pgd_attack are not defined
        # in this section; presumably provided elsewhere.
        detector = BinaryClassifierDetector(model)

        clean_features = torch.cat(
            [detector.extract_features(x) for x, _ in clean_loader])
        # The true labels y are needed by the attack, so unpack them here
        # (previously `y` was referenced before it was assigned).
        adv_features = torch.cat(
            [detector.extract_features(pgd_attack(model, x, y)) for x, y in adv_loader])

        features = torch.cat([clean_features, adv_features])
        labels = torch.cat([torch.zeros(len(clean_features)),
                            torch.ones(len(adv_features))])
        detector.train(features, labels)
        return detector

    raise ValueError(f"unknown detector_type: {detector_type!r}")
 
 
def _count_detections(detector, samples):
    """Return ``(flagged, total)`` over the detector's outputs on ``samples``."""
    flagged = 0
    total = 0
    for x in samples:
        is_adv, _ = detector.detect(x)
        # Works for a scalar bool (one sample) and for a per-sample bool tensor.
        flags = torch.as_tensor(is_adv)
        flagged += int(flags.sum())
        total += flags.numel()
    return flagged, total


def evaluate_detection(detector, test_clean, test_adv):
    """
    Evaluate detector performance.

    Args:
        detector: object whose ``detect(x)`` returns ``(is_adv, score)``, where
            ``is_adv`` is a bool or a per-sample bool tensor.
        test_clean: iterable of clean inputs (single samples or batches).
        test_adv: iterable of adversarial inputs (single samples or batches).

    Returns:
        dict with 'fpr' (fraction of clean samples flagged), 'tpr' (fraction of
        adversarial samples flagged) and the aliases 'detection_rate' and
        'false_alarm_rate'. A rate is NaN when its input set is empty.
    """
    false_pos, n_clean = _count_detections(detector, test_clean)
    true_pos, n_adv = _count_detections(detector, test_adv)

    fpr = false_pos / n_clean if n_clean else float('nan')
    tpr = true_pos / n_adv if n_adv else float('nan')

    return {
        'fpr': fpr,
        'tpr': tpr,
        'detection_rate': tpr,
        'false_alarm_rate': fpr
    }

相关主题


参考文献