
Overview

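Adversarial training is one of the core methods for making diffusion models more robust. Unlike adversarial training for discriminative models, adversarial training for diffusion models has to account for the multi-step nature and the stochasticity of the generative process. This chapter covers the theoretical foundations and practical methods of adversarial training for diffusion models.[1]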


1. Theoretical Foundations of Adversarial Training

1.1 Classical Adversarial Training: A Review

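For a classifier, adversarial training is formulated as a min-max optimization problem:

$$
\min_{\theta} \; \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \max_{\|\delta\|_p \le \epsilon} \mathcal{L}\big(f_\theta(x + \delta), y\big) \right]
$$

where:

  • $\theta$ denotes the parameters of the model $f_\theta$
  • $x$ is the input and $y$ its label, drawn from the data distribution $\mathcal{D}$
  • $\mathcal{L}$ is the loss function
  • $\epsilon$ is the upper bound on the perturbation $\delta$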
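To make the min-max structure concrete, here is a minimal sketch on a toy problem (an illustrative assumption, not part of any method in this chapter): a 2-D linear classifier with logistic loss. The inner maximization is approximated by L∞ PGD (sign-gradient ascent on δ, then projection onto the ε-ball), and the outer minimization is plain gradient descent on θ = (w, b). The dataset, step sizes and epoch count are arbitrary choices.

```python
import math

# Toy data (hypothetical): 2-D points with labels in {-1, +1}
DATA = [([1.0, 1.0], 1), ([2.0, 0.5], 1), ([-1.0, -1.0], -1), ([-0.5, -2.0], -1)]


def margin(w, b, x, y):
    """y * f(x) for the linear model f(x) = w.x + b."""
    return y * (sum(wi * xi for wi, xi in zip(w, x)) + b)


def logistic_loss(w, b, x, y):
    """L(f(x), y) = log(1 + exp(-y f(x)))."""
    return math.log1p(math.exp(-margin(w, b, x, y)))


def pgd_linf(w, b, x, y, eps, alpha, steps):
    """Inner max: sign-gradient ascent on delta, projected onto ||delta||_inf <= eps."""
    delta = [0.0] * len(x)
    for _ in range(steps):
        x_adv = [xi + di for xi, di in zip(x, delta)]
        # dL/dx = -y * w / (1 + exp(y f(x)))
        coeff = -y / (1.0 + math.exp(margin(w, b, x_adv, y)))
        for i, wi in enumerate(w):
            g = coeff * wi
            step = alpha if g > 0 else -alpha if g < 0 else 0.0
            delta[i] = max(-eps, min(eps, delta[i] + step))
    return delta


def adversarial_training(data, eps=0.3, alpha=0.1, steps=5, lr=0.5, epochs=50):
    """Outer min: gradient descent on (w, b) evaluated at the PGD-perturbed inputs."""
    w, b = [0.0, 0.0], 0.0
    for _ in range(epochs):
        gw, gb = [0.0, 0.0], 0.0
        for x, y in data:
            delta = pgd_linf(w, b, x, y, eps, alpha, steps)
            x_adv = [xi + di for xi, di in zip(x, delta)]
            coeff = -y / (1.0 + math.exp(margin(w, b, x_adv, y)))  # dL/df
            gw = [g + coeff * xi for g, xi in zip(gw, x_adv)]
            gb += coeff
        w = [wi - lr * g / len(data) for wi, g in zip(w, gw)]
        b -= lr * gb / len(data)
    return w, b


w, b = adversarial_training(DATA)
for x, y in DATA:
    delta = pgd_linf(w, b, x, y, eps=0.3, alpha=0.1, steps=5)
    x_adv = [xi + di for xi, di in zip(x, delta)]
    print(x, round(logistic_loss(w, b, x, y), 4), round(logistic_loss(w, b, x_adv, y), 4))
```

Each printed line shows the clean loss and the worst-case loss found by the inner PGD; training drives the worst case down, not just the clean loss.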

1.2 Challenges of Adversarial Training for Diffusion Models

Extending adversarial training to diffusion models raises the following challenges:

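| Challenge | Description | Impact |
|---|---|---|
| Generation objective | Unlike classification, there is no explicit label | The attack objective must be redefined |
| Multi-step process | Involves T denoising steps | Gradients accumulate and risk exploding |
| Stochasticity | Generated outputs are random | Hard to judge whether an attack succeeded |
| Computational cost | Each attack needs many forward passes | Training cost is extremely high |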
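The multi-step and cost rows compound each other. As a rough back-of-the-envelope sketch (the accounting is an assumption: each attack iteration costs one forward+backward pass through the denoiser per unrolled denoising step, and the outer update costs one pass per loss term), the number of denoiser passes per training step is:

```python
def denoiser_passes_per_step(attack_steps: int, unrolled_steps: int = 1,
                             clean_term: bool = True) -> int:
    """Rough count of denoiser forward+backward passes for one adversarial training step.

    attack_steps:   PGD iterations in the inner maximization
    unrolled_steps: denoising steps backpropagated through per attack iteration
                    (1 when the attack uses a single random timestep; up to T when
                    it differentiates through the whole sampling chain)
    clean_term:     whether the outer loss also includes a clean-sample term
    """
    inner = attack_steps * unrolled_steps
    outer = 1 + (1 if clean_term else 0)
    return inner + outer


print(denoiser_passes_per_step(5))                     # 7 (plain training needs 1)
print(denoiser_passes_per_step(5, unrolled_steps=50))  # 252
```

Attacking a single randomly sampled timestep per PGD iteration, as the code in the following sections does, keeps `unrolled_steps` at 1.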

1.3 The Adversarial Training Objective for Diffusion Models

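For diffusion models, the adversarial training objective is:

$$
\min_{\theta} \; \mathbb{E}_{(x, c) \sim \mathcal{D}} \left[ \max_{\|\delta\|_p \le \epsilon} \mathcal{L}_{\text{adv}}\big(\theta; x, x + \delta, c\big) \right]
$$

where $c$ is the condition (e.g. a text prompt or class label) and $\mathcal{L}_{\text{adv}}$ is the adversarial loss, which has to be defined according to the attack objective:

  • Semantic preservation: $x$ and $x + \delta$ should have similar semantics
  • Condition consistency: the generated output should match the condition $c$
  • Generation quality: the generated output should remain high-quality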
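One way to turn the three goals into a single scalar is a weighted sum of violation terms, which the attacker maximizes over δ. The sketch below is a hypothetical instantiation, not a definition taken from the original: it assumes some feature extractor has already mapped the outputs for x and x + δ, and the condition c, into one embedding space (as CLIP does), and that a quality score in [0, 1] is available. The weights are free hyperparameters.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def adversarial_objective(feat_clean, feat_adv, feat_cond, quality,
                          w_sem=1.0, w_cond=1.0, w_qual=1.0):
    """Hypothetical L_adv: larger values mean the attack violates the goals more."""
    semantic_drift = 1.0 - cosine(feat_clean, feat_adv)  # semantic preservation
    condition_gap = 1.0 - cosine(feat_adv, feat_cond)    # condition consistency
    quality_drop = 1.0 - quality                         # generation quality
    return w_sem * semantic_drift + w_cond * condition_gap + w_qual * quality_drop


print(adversarial_objective([1, 0], [1, 0], [1, 0], quality=1.0))  # 0.0: nothing violated
print(adversarial_objective([1, 0], [0, 1], [1, 0], quality=0.5))  # 2.5
```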

2. Adversarial Training Frameworks for Diffusion Models

2.1 AdvDM: The First Adversarial Training Method for Diffusion Models

AdvDM (Adversarial Diffusion Models) is the first adversarial training method designed specifically for diffusion models.

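Core idea: exploit the noise-prediction objective of diffusion models and optimize the adversarial perturbation and the model parameters jointly during training.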

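import torch
import torch.nn.functional as F


class AdvDM:
    """
    AdvDM: adversarial training framework for diffusion models.

    The wrapped `diffusion_model` is assumed to expose num_timesteps,
    add_noise(x0, t, noise), noise_predictor(x_t, t, cond),
    extract_image_features(x), extract_text_features(cond) and parameters().
    """

    def __init__(self, diffusion_model, epsilon=8/255, alpha=1/255,
                 num_attack_steps=5, lr=1e-5):
        self.model = diffusion_model
        self.epsilon = epsilon                    # L_inf bound on the perturbation
        self.alpha = alpha                        # PGD step size
        self.num_attack_steps = num_attack_steps
        self.optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=lr)

    def train_step(self, clean_images, conditions):
        """
        One training step.

        Args:
            clean_images: clean images
            conditions: conditioning information (text, class, etc.)
        """
        # Step 1: craft adversarial images (inner maximization)
        adversarial_images = self.generate_adversarial_images(
            clean_images, conditions
        )

        # Step 2: compute the adversarial training loss (outer minimization)
        loss = self.compute_adversarial_loss(
            clean_images, adversarial_images, conditions
        )

        # Step 3: backpropagate and update the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def sample_timesteps(self, batch_size, device):
        """Uniformly random timesteps, one per sample."""
        return torch.randint(0, self.model.num_timesteps, (batch_size,), device=device)

    def generate_adversarial_images(self, images, conditions):
        """
        Craft adversarial images with sign-gradient PGD inside an L_inf ball.
        """
        # Random start inside the epsilon ball
        delta = torch.empty_like(images).uniform_(-self.epsilon, self.epsilon)
        delta.requires_grad_(True)

        for _ in range(self.num_attack_steps):
            # Random timestep and noise for each sample
            t = self.sample_timesteps(images.shape[0], images.device)
            noise = torch.randn_like(images)
            noisy_images = self.model.add_noise(images + delta, t, noise)

            # Predict the noise
            predicted_noise = self.model.noise_predictor(
                noisy_images, t, conditions
            )

            # Attack objective (maximized): increase the noise-prediction error,
            # reduce image-condition alignment, keep the perturbation small
            attack_loss = (
                F.mse_loss(predicted_noise, noise)
                - self.compute_similarity_loss(images + delta, conditions, self.model)
                - 0.01 * delta.norm(p=2)
            )

            # Gradient w.r.t. delta only (model parameter grads stay untouched)
            grad = torch.autograd.grad(attack_loss, delta)[0]

            with torch.no_grad():
                delta += self.alpha * grad.sign()          # ascent step
                delta.clamp_(-self.epsilon, self.epsilon)  # project onto the epsilon ball

        return (images + delta).detach()

    def compute_adversarial_loss(self, clean, adversarial, conditions):
        """
        Adversarial training loss.
        """
        # 1. Denoising (reconstruction) loss on clean samples
        clean_loss = self.compute_reconstruction_loss(clean, conditions)

        # 2. Denoising loss on adversarial samples
        adv_loss = self.compute_reconstruction_loss(adversarial, conditions)

        # 3. Semantic-preservation loss between adversarial and clean samples
        semantic_loss = self.compute_semantic_loss(adversarial, clean)

        # Total loss
        return clean_loss + adv_loss + 0.1 * semantic_loss

    def compute_reconstruction_loss(self, images, conditions):
        """Standard noise-prediction loss ||noise - noise_predictor(x_t, t, c)||^2."""
        t = self.sample_timesteps(images.shape[0], images.device)
        noise = torch.randn_like(images)
        noisy = self.model.add_noise(images, t, noise)
        return F.mse_loss(self.model.noise_predictor(noisy, t, conditions), noise)

    def compute_semantic_loss(self, adversarial, clean):
        """1 - cosine similarity between image features of adversarial and clean inputs."""
        adv_features = self.model.extract_image_features(adversarial)
        with torch.no_grad():
            clean_features = self.model.extract_image_features(clean)
        return 1 - F.cosine_similarity(adv_features, clean_features, dim=-1).mean()

    def compute_similarity_loss(self, images, conditions, model):
        """
        Image-condition similarity (CLIP-style cosine similarity).
        Gradients must flow back to `images`, so only the text branch is frozen.
        """
        image_features = model.extract_image_features(images)
        with torch.no_grad():
            text_features = model.extract_text_features(conditions)

        return F.cosine_similarity(image_features, text_features, dim=-1).mean()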
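The attack perturbs x₀ but evaluates the loss on the noised input returned by `add_noise`. Assuming `add_noise` implements the standard DDPM forward process x_t = √ᾱ_t · x₀ + √(1 − ᾱ_t) · noise (the usual convention; the original does not specify it), the perturbation reaches the denoiser scaled by √ᾱ_t: strong at small t, almost drowned out near t = T. A pure-Python sketch with the linear β schedule from the DDPM paper (1e-4 to 0.02 over 1000 steps):

```python
import math


def linear_alpha_bars(T=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    alpha_bars, prod = [], 1.0
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def add_noise(x0, t, noise, alpha_bars):
    """Forward process q(x_t | x_0), applied element-wise to a list of floats."""
    a = alpha_bars[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(x0, noise)]


alpha_bars = linear_alpha_bars()
x0, noise, delta = [0.5, -0.2], [0.3, -1.0], 8 / 255
for t in (0, 250, 500, 999):
    clean = add_noise(x0, t, noise, alpha_bars)
    adv = add_noise([x + delta for x in x0], t, noise, alpha_bars)
    # With the same noise, the perturbation survives as sqrt(alpha_bar_t) * delta
    print(t, round((adv[0] - clean[0]) / delta, 4), round(math.sqrt(alpha_bars[t]), 4))
```

The two numbers on each line agree, which is one reason sampling t uniformly inside the attack matters: perturbations tuned only for large t barely register.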

2.2 DiffATR: Adversarial Training for Conditional Diffusion Models

DiffATR (Diffusion Adversarial Training with Routing) is an adversarial training method for conditional diffusion models.

Core improvements:

  1. Condition-aware perturbation generation
  2. Adaptive attack strength
  3. Multi-condition consistency constraints
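import torch
import torch.nn.functional as F


class DiffATR:
    """
    DiffATR: condition-aware adversarial diffusion training.

    `model` is assumed to expose encode_text, encode_class, num_timesteps,
    add_noise, noise_predictor, extract_image_features, extract_text_features
    and classify.
    """

    def __init__(self, model, epsilon=8/255, attack_steps=10, attack_lr=0.01):
        self.model = model
        self.epsilon = epsilon
        self.attack_steps = attack_steps
        self.attack_lr = attack_lr

    def conditional_adversarial_training(self, images, text_conditions,
                                         class_conditions):
        """
        Condition-aware adversarial training: returns the loss to minimize.
        """
        # Craft adversarial images
        adv_images = self.generate_conditional_adversarial(
            images, text_conditions, class_conditions
        )

        # Multi-condition consistency loss (compute_conditional_loss is an
        # assumed helper, not defined here)
        loss = self.compute_conditional_loss(
            adv_images, text_conditions, class_conditions
        )

        return loss

    def generate_conditional_adversarial(self, images, text_c, class_c):
        """
        Condition-aware adversarial image generation.
        """
        # The condition embeddings do not depend on delta: compute them once
        with torch.no_grad():
            text_emb = self.model.encode_text(text_c)
            class_emb = self.model.encode_class(class_c)
            combined_emb = torch.cat([text_emb, class_emb], dim=-1)

        delta = torch.zeros_like(images).requires_grad_(True)
        optimizer = torch.optim.Adam([delta], lr=self.attack_lr)

        for _ in range(self.attack_steps):
            # Timestep and noise
            t = torch.randint(0, self.model.num_timesteps, (images.shape[0],),
                              device=images.device)
            noise = torch.randn_like(images)
            noisy = self.model.add_noise(images + delta, t, noise)

            # Noise prediction under the combined condition
            pred_noise = self.model.noise_predictor(noisy, t, combined_emb)

            # Condition-aware attack loss (minimized by Adam)
            loss = self.conditional_attack_loss(
                images + delta, pred_noise, noise, text_c, class_c
            )

            # Gradient w.r.t. delta only, so model parameter grads stay untouched
            delta.grad = torch.autograd.grad(loss, delta)[0]
            optimizer.step()

            with torch.no_grad():
                delta.clamp_(-self.epsilon, self.epsilon)

        return (images + delta).detach()

    def conditional_attack_loss(self, images, pred_noise, true_noise,
                                text_c, class_c):
        """
        Condition-aware attack loss; minimizing it strengthens the attack.
        """
        # Noise-prediction term: increase the prediction error
        noise_loss = -torch.norm(pred_noise - true_noise) ** 2

        # Text-condition term: minimizing alignment maximizes image-text mismatch
        image_features = self.model.extract_image_features(images)
        text_features = self.model.extract_text_features(text_c)
        text_loss = F.cosine_similarity(image_features, text_features, dim=-1).mean()

        # Class-condition term: increase the classification loss
        class_logits = self.model.classify(images)
        class_loss = -F.cross_entropy(class_logits, class_c)

        # Weighted total
        return noise_loss + 0.5 * text_loss + 0.3 * class_loss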
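The class above covers condition-aware perturbations and the multi-condition loss, but not the second listed improvement, adaptive attack strength. One simple way to realize it, offered purely as a hypothetical sketch (the rule, target ratio, multipliers and bounds are all assumptions), is a multiplicative controller that strengthens the attack while it barely raises the loss and weakens it once it reaches a target:

```python
def adapt_epsilon(eps, adv_loss, clean_loss, target_ratio=1.5, up=1.1, down=0.9,
                  eps_min=1 / 255, eps_max=16 / 255):
    """Hypothetical adaptive-strength rule for the next attack's epsilon.

    ratio = adv_loss / clean_loss measures how much the current attack hurts.
    Below target_ratio the attack is too weak, so epsilon grows by `up`;
    at or above it, epsilon shrinks by `down`. The result is clipped.
    """
    ratio = adv_loss / max(clean_loss, 1e-12)
    eps = eps * (up if ratio < target_ratio else down)
    return min(eps_max, max(eps_min, eps))


eps = 8 / 255
for adv_loss, clean_loss in [(1.0, 1.0), (1.2, 1.0), (2.0, 1.0)]:
    eps = adapt_epsilon(eps, adv_loss, clean_loss)
    print(f"{eps * 255:.2f}/255")
```

The clipping bounds keep the controller from drifting to a trivial attack or to perturbations far outside the intended threat model.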

3. Progressive Adversarial Training

3.1 Curriculum Adversarial Training

Core idea: go from easy to hard, gradually increasing the strength of the adversarial attack.

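import torch
import torch.nn.functional as F


class CurriculumAdversarialTraining:
    """
    Curriculum adversarial training.

    Stages (by fraction of total training steps):
    1. [0, 0.2):   clean training (no attack)
    2. [0.2, 0.5): weak attack (epsilon = 0.3 * max_epsilon)
    3. [0.5, 0.8): epsilon grows linearly from 0.3 to 0.7 * max_epsilon
    4. [0.8, 1.0]: full-strength attack (epsilon = max_epsilon)

    `model` is assumed to expose num_timesteps, add_noise and noise_predictor.
    """

    def __init__(self, model, optimizer, max_epsilon=8/255, total_steps=100000,
                 num_epochs=10, attack_steps=5):
        self.model = model
        self.optimizer = optimizer
        self.max_epsilon = max_epsilon
        self.total_steps = total_steps
        self.num_epochs = num_epochs
        self.attack_steps = attack_steps

    def get_current_epsilon(self, step):
        """
        Adjust epsilon according to training progress.
        """
        progress = step / self.total_steps

        if progress < 0.2:
            # Stage 1: clean training
            return 0
        elif progress < 0.5:
            # Stage 2: transition
            return self.max_epsilon * 0.3
        elif progress < 0.8:
            # Stage 3: steady growth
            return self.max_epsilon * (0.3 + 0.4 * (progress - 0.5) / 0.3)
        else:
            # Stage 4: full strength
            return self.max_epsilon

    def train(self, dataloader):
        """
        Curriculum training loop.
        """
        global_step = 0

        for epoch in range(self.num_epochs):
            for images, conditions in dataloader:
                # Current epsilon
                epsilon = self.get_current_epsilon(global_step)

                if epsilon == 0:
                    # Clean training
                    loss = self.diffusion_loss(images, conditions)
                else:
                    # Adversarial training
                    loss = self.adversarial_training_step(
                        images, conditions, epsilon
                    )

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                global_step += 1

    def diffusion_loss(self, images, conditions):
        """Standard noise-prediction (denoising) loss."""
        t = torch.randint(0, self.model.num_timesteps, (images.shape[0],),
                          device=images.device)
        noise = torch.randn_like(images)
        noisy = self.model.add_noise(images, t, noise)
        return F.mse_loss(self.model.noise_predictor(noisy, t, conditions), noise)

    def pgd_attack(self, images, conditions, epsilon):
        """Sign-gradient PGD that maximizes the denoising loss in an L_inf ball."""
        step_size = 2.5 * epsilon / self.attack_steps  # step-size heuristic (assumption)
        delta = torch.zeros_like(images, requires_grad=True)
        for _ in range(self.attack_steps):
            loss = self.diffusion_loss(images + delta, conditions)
            grad = torch.autograd.grad(loss, delta)[0]
            with torch.no_grad():
                delta += step_size * grad.sign()
                delta.clamp_(-epsilon, epsilon)
        return (images + delta).detach()

    def adversarial_training_step(self, images, conditions, epsilon):
        """
        Adversarial training step: returns the loss to minimize.
        """
        # Craft adversarial images
        adv_images = self.pgd_attack(images, conditions, epsilon)

        # Training loss on adversarial samples
        loss = self.diffusion_loss(adv_images, conditions)

        # Extra: denoising loss on clean samples
        clean_loss = self.diffusion_loss(images, conditions)

        return loss + 0.1 * clean_loss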
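Because `get_current_epsilon` is pure arithmetic, it can be checked in isolation. The free function below copies its piecewise schedule (defaults mirror the class: 100,000 total steps, ε_max = 8/255) and prints ε at a few milestones. Note that stage 3 ends at 0.7 · ε_max, so ε jumps to ε_max at the 80% mark; a continuous ramp would need a coefficient of 0.7 instead of 0.4 in stage 3 (an adjustment, not part of the original design).

```python
def curriculum_epsilon(step, total_steps=100_000, max_epsilon=8 / 255):
    """Piecewise schedule copied from get_current_epsilon."""
    progress = step / total_steps
    if progress < 0.2:
        return 0.0                                                 # stage 1: clean
    if progress < 0.5:
        return max_epsilon * 0.3                                   # stage 2: weak
    if progress < 0.8:
        return max_epsilon * (0.3 + 0.4 * (progress - 0.5) / 0.3)  # stage 3: ramp to 0.7
    return max_epsilon                                             # stage 4: full strength


for step in (0, 20_000, 50_000, 65_000, 79_999, 80_000):
    print(f"step {step:>6}: eps = {curriculum_epsilon(step) * 255:.2f}/255")
```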

3.2 Progressive Attack Strength

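class ProgressiveAdversarialTraining:
    """
    Progressive adversarial training.

    Gradually increases both the attack strength (epsilon) and the number of
    attack steps during training. `pgd_attack` and `compute_loss` are assumed
    helpers (e.g. sign-gradient PGD on the denoising loss, and the denoising loss).
    """

    def __init__(self, model):
        self.model = model

    def train_step(self, images, conditions, step):
        """
        One progressive training step: returns the loss to minimize.
        """
        # Scheduled attack parameters
        epsilon = self.schedule_epsilon(step)
        alpha = self.schedule_alpha(step)
        num_steps = self.schedule_num_steps(step)

        # Craft adversarial images (skip the attack while epsilon is still 0)
        if epsilon > 0:
            adv_images = self.pgd_attack(
                images, conditions, epsilon, alpha, num_steps
            )
        else:
            adv_images = images

        # Loss
        return self.compute_loss(adv_images, conditions)

    def schedule_epsilon(self, step):
        """
        Epsilon schedule: 0 during warmup, then linear growth to max_epsilon
        over 50,000 steps.
        """
        warmup_steps = 5000
        max_epsilon = 8/255

        if step < warmup_steps:
            return 0
        progress = min(1, (step - warmup_steps) / 50000)
        return max_epsilon * progress

    def schedule_alpha(self, step):
        """
        Step-size schedule (constant here).
        """
        return 1.0 / 255

    def schedule_num_steps(self, step):
        """
        Attack-step schedule: 1 step during warmup, then one more step every
        1,000 training steps, capped at 50.
        """
        warmup_steps = 5000

        if step < warmup_steps:
            return 1
        return min(50, 1 + (step - warmup_steps) // 1000)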
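The two schedules can be tabulated the same way. The sketch below lifts the hard-coded constants of `schedule_epsilon` and `schedule_num_steps` into keyword arguments (same defaults) so their interplay is visible: ε reaches its maximum at step 55,000, and the PGD step count reaches its cap of 50 at step 54,000.

```python
def schedule_epsilon(step, warmup_steps=5000, ramp_steps=50_000, max_epsilon=8 / 255):
    """0 during warmup, then linear growth to max_epsilon over ramp_steps."""
    if step < warmup_steps:
        return 0.0
    return max_epsilon * min(1.0, (step - warmup_steps) / ramp_steps)


def schedule_num_steps(step, warmup_steps=5000, max_steps=50, every=1000):
    """1 PGD step during warmup, then one more every `every` steps, capped."""
    if step < warmup_steps:
        return 1
    return min(max_steps, 1 + (step - warmup_steps) // every)


for step in (0, 5_000, 10_000, 30_000, 54_000, 55_000, 100_000):
    eps = schedule_epsilon(step)
    print(f"step {step:>7}: eps = {eps * 255:5.2f}/255, PGD steps = {schedule_num_steps(step)}")
```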

4. Latent-Space Adversarial Training

4.1 Latent-Space Attacks

For Latent Diffusion Models (LDMs), adversarial training can be performed in the latent space of the VAE:

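import torch
import torch.nn.functional as F


class LatentSpaceAdversarialTraining:
    """
    Latent-space adversarial training.

    Advantages:
    1. Lower computational cost
    2. Acts directly on the generative network
    3. Bypasses pixel-space filtering

    Written against diffusers-style components: `vae` is an AutoencoderKL,
    `unet` a UNet2DConditionModel and `noise_scheduler` e.g. a DDPMScheduler;
    `conditions` are text-encoder hidden states. `compute_clip_similarity`
    is an assumed helper.
    """

    def __init__(self, vae, unet, noise_scheduler, epsilon_latent=1.0,
                 attack_steps=20, attack_lr=0.1):
        self.vae = vae
        self.unet = unet
        self.noise_scheduler = noise_scheduler
        self.epsilon_latent = epsilon_latent
        self.attack_steps = attack_steps
        self.attack_lr = attack_lr
        self.scaling = vae.config.scaling_factor

    def predict_noise(self, latents, conditions):
        """Noise latents at random timesteps; return (predicted noise, true noise)."""
        t = torch.randint(0, self.noise_scheduler.config.num_train_timesteps,
                          (latents.shape[0],), device=latents.device)
        noise = torch.randn_like(latents)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, t)
        pred = self.unet(noisy_latents, t, encoder_hidden_states=conditions).sample
        return pred, noise

    def train_step(self, images, conditions):
        """
        One latent-space adversarial training step: returns the loss to minimize.
        """
        # Step 1: encode into latent space
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample() * self.scaling

        # Step 2: craft an adversarial perturbation in latent space
        delta_latent = self.generate_latent_adversarial(
            latents, conditions
        )

        # Step 3: decode the adversarial latents
        adv_latents = latents + delta_latent
        adv_images = self.vae.decode(adv_latents / self.scaling).sample

        # Step 4: compute the loss
        loss = self.compute_latent_loss(
            latents, delta_latent, adv_images, conditions
        )

        return loss

    def generate_latent_adversarial(self, latents, conditions):
        """
        Craft an adversarial perturbation in latent space.
        """
        delta = torch.zeros_like(latents).requires_grad_(True)
        optimizer = torch.optim.Adam([delta], lr=self.attack_lr)

        for _ in range(self.attack_steps):
            pred_noise, noise = self.predict_noise(latents + delta, conditions)

            # Attack loss (minimized): increase the noise-prediction error
            loss = -torch.norm(pred_noise - noise) ** 2

            # Regularize the size of the latent perturbation
            loss = loss + 0.01 * torch.norm(delta, p=2)

            # Gradient w.r.t. delta only, so UNet parameter grads stay untouched
            delta.grad = torch.autograd.grad(loss, delta)[0]
            optimizer.step()

            with torch.no_grad():
                delta.clamp_(-self.epsilon_latent, self.epsilon_latent)

        return delta.detach()

    def compute_latent_loss(self, latents, delta, adv_images, conditions):
        """
        Latent-space adversarial training loss.
        """
        # 1. UNet denoising loss on the adversarial latents (trains robustness)
        pred_noise, noise = self.predict_noise(latents + delta, conditions)
        denoise_loss = F.mse_loss(pred_noise, noise)

        # 2. Size of the latent perturbation
        latent_loss = torch.norm(delta, p=2)

        # 3. Reconstruction consistency between clean and adversarial decodings
        with torch.no_grad():
            clean_recon = self.vae.decode(latents / self.scaling).sample
        recon_loss = torch.norm(clean_recon - adv_images, p=2)

        # 4. CLIP alignment between adversarial decodings and the conditions
        clip_loss = -self.compute_clip_similarity(adv_images, conditions)

        return denoise_loss + latent_loss + recon_loss + 0.5 * clip_loss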
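A quick calculation shows where part of the cost saving comes from. Assuming a Stable-Diffusion-v1-style autoencoder (8× spatial downsampling, 4 latent channels) and 512×512 RGB inputs, the latent perturbation has 48× fewer entries than a pixel-space one; in addition, a latent-space attack does not backpropagate through the VAE encoder at every PGD step. It also means `epsilon_latent` lives on a different scale from a pixel-space ε such as 8/255, so the two budgets are not directly comparable.

```python
def tensor_size(channels, height, width):
    """Number of scalar entries in a C x H x W tensor."""
    return channels * height * width


def latent_shape(height, width, downsample=8, latent_channels=4):
    """Latent shape for an autoencoder with the given downsampling (assumed SD-v1-like)."""
    return latent_channels, height // downsample, width // downsample


pixel_entries = tensor_size(3, 512, 512)
latent_entries = tensor_size(*latent_shape(512, 512))
print(latent_shape(512, 512), pixel_entries, latent_entries, pixel_entries // latent_entries)
# (4, 64, 64) 786432 16384 48
```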

4.2 Cross-Attention-Space Training

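import torch


class CrossAttentionAdversarialTraining:
    """
    Adversarial training in cross-attention space.

    Assumes `model.unet.cross_attention_blocks` lists the cross-attention
    modules, and that `get_target_attention` and `compute_attention_loss`
    are implemented elsewhere.
    """

    def __init__(self, model, epsilon=0.1, attack_steps=10, attack_lr=0.01):
        self.model = model
        self.epsilon = epsilon
        self.attack_steps = attack_steps
        self.attack_lr = attack_lr

    def train_step(self, images, conditions):
        """
        Cross-attention adversarial training step.
        """
        # Collect attention maps
        attention_maps = self.get_attention_maps(images, conditions)

        # Craft adversarial perturbations of the attention maps
        delta_attention = self.generate_attention_adversarial(
            attention_maps, conditions
        )

        # Apply the perturbations and compute the loss
        loss = self.compute_attention_loss(
            attention_maps, delta_attention, conditions
        )

        return loss

    def get_attention_maps(self, images, conditions):
        """
        Capture cross-attention outputs with forward hooks.
        """
        attention_maps = []

        def hook_fn(module, inputs, output):
            # Some modules return a tuple; keep the first tensor
            out = output[0] if isinstance(output, tuple) else output
            attention_maps.append(out.detach())

        hooks = [block.register_forward_hook(hook_fn)
                 for block in self.model.unet.cross_attention_blocks]

        try:
            with torch.no_grad():
                self.model(images, conditions)
        finally:
            # Always remove the hooks, even if the forward pass fails
            for handle in hooks:
                handle.remove()

        return attention_maps

    def generate_attention_adversarial(self, attention_maps, conditions):
        """
        Craft adversarial perturbations of the attention maps.
        """
        # The target attention does not change during the attack
        target_attn = self.get_target_attention(conditions)
        delta_maps = []

        # Attack each attention map independently
        for attn_map in attention_maps:
            delta = torch.zeros_like(attn_map).requires_grad_(True)
            optimizer = torch.optim.Adam([delta], lr=self.attack_lr)

            for _ in range(self.attack_steps):
                optimizer.zero_grad()

                # Perturbed attention
                perturbed_attn = attn_map + delta

                # Attack loss: maximize the distance from the target attention
                loss = -torch.norm(perturbed_attn - target_attn) ** 2

                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    delta.clamp_(-self.epsilon, self.epsilon)

            delta_maps.append(delta.detach())

        return delta_maps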
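For reference, a cross-attention map has one row per image (query) token and one column per text (key) token, and each row is a softmax, so it sums to 1. The pure-Python sketch below computes A = softmax(QKᵀ/√d) for toy vectors. It also illustrates a design point about the code above: adding a clamped δ directly to the probabilities breaks the row normalization, whereas adding the perturbation to the pre-softmax scores (an alternative, not what the original does) keeps every row a valid distribution.

```python
import math


def softmax(scores):
    """Numerically stable softmax of a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention_map(queries, keys, logit_bias=None):
    """A[i][j] = softmax_j(q_i . k_j / sqrt(d)); optional additive bias on the scores."""
    d = len(keys[0])
    rows = []
    for i, q in enumerate(queries):
        scores = [sum(qa * ka for qa, ka in zip(q, k)) / math.sqrt(d) for k in keys]
        if logit_bias is not None:
            scores = [s + bias for s, bias in zip(scores, logit_bias[i])]
        rows.append(softmax(scores))
    return rows


Q = [[1.0, 0.0], [0.0, 1.0]]              # 2 image (query) tokens
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 text (key) tokens

A = cross_attention_map(Q, K)
probability_space = [[p + 0.1 for p in row] for row in A]  # like attn_map + delta
logit_space = cross_attention_map(Q, K, logit_bias=[[0.1, -0.1, 0.1], [-0.1, 0.1, 0.1]])
print([round(sum(r), 6) for r in probability_space])  # [1.3, 1.3]
print([round(sum(r), 6) for r in logit_space])        # [1.0, 1.0]
```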

5. Robustness Evaluation

5.1 Evaluation Metrics

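def evaluate_robustness(model, test_images, test_conditions, epsilon=8/255,
                        reference_images=None):
    """
    Evaluate model robustness.

    FID is a distribution-level metric, so it is computed once over all
    generations (against `reference_images`, defaulting to the test images);
    the per-sample metrics are averaged. pgd_attack, compute_fid, compute_asr,
    compute_lpips and compute_semantic_sim are assumed helpers.
    """
    reference = test_images if reference_images is None else reference_images
    clean_gens, adv_gens = [], []
    totals = {'asr': 0.0, 'lpips_distance': 0.0, 'semantic_preservation': 0.0}

    for image, condition in zip(test_images, test_conditions):
        # Generation from the clean input
        clean_gen = model.generate(image, condition)

        # Adversarial input and its generation
        adv_image = pgd_attack(model, image, condition, epsilon)
        adv_gen = model.generate(adv_image, condition)

        clean_gens.append(clean_gen)
        adv_gens.append(adv_gen)

        # Attack success rate
        totals['asr'] += compute_asr(clean_gen, adv_gen, condition)

        # Perceptual distance
        totals['lpips_distance'] += compute_lpips(clean_gen, adv_gen)

        # Semantic preservation
        totals['semantic_preservation'] += compute_semantic_sim(
            clean_gen, adv_gen
        )

    num_samples = len(test_images)
    metrics = {key: value / num_samples for key, value in totals.items()}

    # Set-level FID of clean and adversarial generations
    metrics['clean_fid'] = compute_fid(clean_gens, reference)
    metrics['adv_fid'] = compute_fid(adv_gens, reference)

    return metrics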
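Two of these metrics are easy to misuse. FID compares the statistics of two sets of features, so it should be computed over all generations at once rather than per image and then averaged. In one dimension the Fréchet distance between Gaussian fits reduces to (μ₁ − μ₂)² + (σ₁ − σ₂)², the scalar case of FID's ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^½); with a single sample per set both σ are 0 and it degenerates to a squared difference. The original does not define ASR, so the threshold-based version below is only a hypothetical example.

```python
import statistics


def frechet_distance_1d(xs, ys):
    """Fréchet distance between 1-D Gaussians fitted to two sets of scalar features."""
    mu_x, mu_y = statistics.fmean(xs), statistics.fmean(ys)
    sd_x, sd_y = statistics.pstdev(xs), statistics.pstdev(ys)
    return (mu_x - mu_y) ** 2 + (sd_x - sd_y) ** 2


def attack_success_rate(similarities, threshold=0.8):
    """Hypothetical ASR: share of samples whose clean-vs-adversarial semantic
    similarity drops below `threshold`."""
    return sum(s < threshold for s in similarities) / len(similarities)


clean_feats = [0.1, 0.4, 0.5, 0.9]
adv_feats = [0.3, 0.6, 0.7, 1.1]  # same spread, mean shifted by 0.2
print(frechet_distance_1d(clean_feats, adv_feats))
print(attack_success_rate([0.95, 0.85, 0.6, 0.4]))  # 0.5
```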

6. Implementation Summary

6.1 Full Training Loop

def adversarial_training_loop(model, dataloader, val_images, val_conditions,
                              num_epochs=10):
    """
    Full adversarial training loop.
    """
    trainer = AdvDM(model, epsilon=8/255)

    for epoch in range(num_epochs):
        epoch_loss = 0.0

        for batch_idx, (images, conditions) in enumerate(dataloader):
            loss = trainer.train_step(images, conditions)
            epoch_loss += loss

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss:.4f}")

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch} Average Loss: {avg_loss:.4f}")

        # Evaluate robustness periodically on a held-out set
        if (epoch + 1) % 5 == 0:
            metrics = evaluate_robustness(model, val_images, val_conditions)
            print(f"Robustness Metrics: {metrics}")
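The loop above delegates everything to AdvDM and needs a real diffusion model. To see the min-max dynamics end to end without one, here is a self-contained toy in which every component is an assumption for illustration: data x₀ ~ N(0, 1), a single noise level ᾱ = 0.5, a one-parameter noise predictor noise_hat = w · x_t, an inner sign-gradient PGD on δ with |δ| ≤ ε, and one outer gradient-descent step on w per epoch.

```python
import math
import random


def toy_adversarial_denoiser_training(epochs=200, eps=0.1, step_size=0.05,
                                      attack_steps=3, lr=0.05, alpha_bar=0.5, seed=0):
    """Min over w of the mean of max over |delta| <= eps of (w * x_t - n)^2,
    where x_t = sqrt(alpha_bar) * (x0 + delta) + sqrt(1 - alpha_bar) * n."""
    rng = random.Random(seed)
    data = [rng.gauss(0.0, 1.0) for _ in range(256)]
    a, s = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    w, history = 0.0, []
    for _epoch in range(epochs):
        grad_w, total = 0.0, 0.0
        for x0 in data:
            n = rng.gauss(0.0, 1.0)
            delta = 0.0
            for _ in range(attack_steps):               # inner max over delta
                err = w * (a * (x0 + delta) + s * n) - n
                g = 2.0 * err * w * a                   # d loss / d delta
                delta += step_size * (1 if g > 0 else -1 if g < 0 else 0)
                delta = max(-eps, min(eps, delta))      # project onto [-eps, eps]
            xt = a * (x0 + delta) + s * n
            err = w * xt - n
            total += err * err
            grad_w += 2.0 * err * xt                    # d loss / d w
        w -= lr * grad_w / len(data)                    # outer min over w
        history.append(total / len(data))
    return w, history


w, history = toy_adversarial_denoiser_training()
print(f"w = {w:.3f}, adversarial loss: first epoch {history[0]:.3f}, last epoch {history[-1]:.3f}")
```

At w = 0 the worst-case loss is roughly E[n²] ≈ 1; training pushes it down, while the attacker's perturbation, amplified by w, keeps the robust solution from simply matching the clean-data optimum.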

7. Summary

Comparison of adversarial training methods

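| Method | Strengths | Weaknesses | Best suited for |
|---|---|---|---|
| AdvDM | Simple and effective | High computational cost | General use |
| DiffATR | Condition-aware | Complex to implement | Conditional generation |
| Curriculum training | Stable convergence | Long training time | Large-scale training |
| Latent-space training | Computationally efficient | Limited attack surface | LDMs |
| Attention training | Fine-grained control | Hard to optimize | Specific attacks |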

Future directions

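  1. More efficient adversarial training: reduce the training cost
  2. Theoretical guarantees: establish a convergence theory for adversarial training
  3. Adaptive attacks: adapt attacks to the weaknesses of the model
  4. Cross-model transfer: study how adversarial examples transfer between models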

References

  1. [ICLR 2024] AdvDM: Adversarial Training for Diffusion Models