1. 引言

在真实的联邦学习系统中,参与的客户端可能存在恶意行为。拜占庭攻击(Byzantine Attack)是最具破坏性的攻击之一,恶意客户端可以上传任意错误的模型更新,从而破坏全局模型的学习。本节介绍拜占庭攻击的机制和防御策略。


2. 拜占庭攻击模型

2.1 攻击者能力假设

拜占庭攻击者(Byzantine Attacker)可以:

  • 完全控制部分客户端
  • 上传任意-crafted模型更新
  • 协调多个恶意客户端的攻击策略

2.2 攻击类型

攻击类型描述目标
标签翻转攻击修改数据标签为错误类别使模型产生错误预测
梯度反转攻击上传负梯度破坏模型学习
高斯攻击上传随机高斯噪声降低模型性能
定向攻击精心-crafted梯度针对特定输出
后门攻击植入隐藏的后门在特定输入上触发恶意行为

2.3 攻击数学模型

恶意客户端 上传的更新为:

攻击目标:使得聚合后的模型 偏离真实最优:


3. 标准聚合的风险

3.1 FedAvg的脆弱性

FedAvg的聚合规则:

当存在恶意客户端时:

  • 恶意更新会直接污染全局模型
  • 即使只有1个恶意客户端(),也可能导致10%的误差累积

3.2 攻击效果示例

def byzantine_attack_simulation():
    """
    模拟拜占庭攻击效果
    """
    # 假设有10个客户端,1个恶意
    honest_updates = [torch.randn(100) for _ in range(9)]
    malicious_update = torch.zeros(100)  # 上传零梯度
    
    # 正常聚合
    normal_aggregate = sum(honest_updates) / 10
    
    # 被攻击的聚合
    attacked_aggregate = (sum(honest_updates) + malicious_update) / 10
    
    # 恶意客户端上传相反方向的梯度
    malicious_update = -sum(honest_updates) / 9  # 完全抵消诚实更新
    worst_aggregate = (sum(honest_updates) + malicious_update) / 10
    
    return normal_aggregate, attacked_aggregate, worst_aggregate

4. 拜占庭鲁棒聚合算法

4.1 Krum聚合

核心思想:选择最接近其他更新的更新作为代表。

def krum_aggregate(updates, n_byzantine):
    """
    Krum聚合算法
    
    Args:
        updates: 客户端更新列表
        n_byzantine: 拜占庭客户端数量
    """
    n = len(updates)
    m = n - n_byzantine - 2  # 选择前m个
    
    # 计算距离矩阵
    distances = torch.zeros(n, n)
    for i in range(n):
        for j in range(i+1, n):
            d = torch.norm(updates[i] - updates[j])
            distances[i, j] = d
            distances[j, i] = d
    
    # 计算每个更新的分数(到其他更新的距离之和)
    scores = []
    for i in range(n):
        # 排除自己,选择最近的n-m个
        sorted_distances = sorted(distances[i])
        score = sum(sorted_distances[1:1+m])  # 排除自己
        scores.append(score)
    
    # 选择分数最小的更新
    selected_idx = np.argmin(scores)
    return updates[selected_idx]

4.2 Trimmed Mean聚合

核心思想:去掉最大和最小的更新,然后求平均。

def trimmed_mean_aggregate(updates, n_byzantine):
    """
    Trimmed Mean聚合算法
    
    每去掉 β = n_byzantine 个最大和最小的更新
    """
    n = len(updates)
    beta = n_byzantine
    
    # 对每个参数维度分别处理
    aggregated = []
    for dim in range(updates[0].shape[0]):
        # 收集所有客户端在该维度的值
        values = [u[dim] for u in updates]
        
        # 排序
        sorted_values = sorted(values)
        
        # 去掉beta个最大和最小的
        trimmed = sorted_values[beta: n-beta]
        
        # 求平均
        aggregated.append(sum(trimmed) / len(trimmed))
    
    return torch.stack(aggregated)

4.3 Median聚合

def median_aggregate(updates):
    """
    中位数聚合
    """
    n_dims = updates[0].shape[0]
    aggregated = []
    
    for dim in range(n_dims):
        values = [u[dim] for u in updates]
        aggregated.append(torch.median(torch.stack(values)))
    
    return torch.stack(aggregated)

4.4 分布式鲁棒优化(DRACO)

DRACO使用Flaml距离来检测异常:

def draco_aggregate(updates, n_byzantine):
    """
    DRACO: 基于分布式鲁棒优化的聚合
    """
    n = len(updates)
    
    # 计算中心
    center = sum(updates) / n
    
    # 计算各更新到中心的距离
    distances = [torch.norm(u - center) for u in updates]
    
    # 使用Huber损失识别异常值
    threshold = np.percentile(distances, 75)  # 75th percentile
    
    # 只聚合距离在阈值内的更新
    robust_updates = [
        u for u, d in zip(updates, distances) 
        if d <= threshold
    ]
    
    if len(robust_updates) > 0:
        return sum(robust_updates) / len(robust_updates)
    else:
        return center

5. 基于统计的防御方法

5.1 检测异常的统计量

统计量计算方法异常检测规则
范数偏离均值超过阈值
方向方向不一致
方差异常大
余弦相似度与多数更新负相关

5.2 范数检测

def norm_based_filter(updates, n_byzantine, sigma=3):
    """
    基于范数的异常检测
    """
    # 计算更新范数
    norms = [torch.norm(u) for u in updates]
    mean_norm = sum(norms) / len(norms)
    std_norm = np.std(norms)
    
    # 计算阈值
    threshold = mean_norm + sigma * std_norm
    
    # 过滤
    filtered = [
        u for u, n in zip(updates, norms)
        if n <= threshold
    ]
    
    return filtered if len(filtered) > n_byzantine else updates

5.3 余弦相似度检测

def cosine_similarity_filter(updates, n_byzantine):
    """
    基于余弦相似度的异常检测
    """
    n = len(updates)
    
    # 计算平均方向
    mean_update = sum(updates) / n
    
    # 计算每个更新与平均方向的余弦相似度
    similarities = []
    for u in updates:
        cos_sim = torch.sum(u * mean_update) / (
            torch.norm(u) * torch.norm(mean_update) + 1e-8
        )
        similarities.append(cos_sim)
    
    # 选择相似度最高的 (n - n_byzantine) 个更新
    sorted_indices = np.argsort(similarities)[::-1]  # 降序
    selected_indices = sorted_indices[:-n_byzantine]
    
    selected_updates = [updates[i] for i in selected_indices]
    
    return selected_updates

6. 基于学习的防御方法

6.1 自适应聚合

class AdaptiveAggregator:
    def __init__(self, n_byzantine):
        self.n_byzantine = n_byzantine
        self.history = []
    
    def aggregate(self, updates):
        """自适应聚合"""
        # Step 1: 检测异常
        scores = self.compute_reliability_scores(updates)
        
        # Step 2: 分配权重
        weights = self.compute_weights(scores)
        
        # Step 3: 加权聚合
        aggregated = sum(w * u for w, u in zip(weights, updates))
        
        return aggregated
    
    def compute_reliability_scores(self, updates):
        """计算可靠性分数"""
        n = len(updates)
        
        # 基于一致性的分数
        consistency_scores = []
        for i in range(n):
            other_updates = [updates[j] for j in range(n) if j != i]
            similarity = self.compute_avg_similarity(updates[i], other_updates)
            consistency_scores.append(similarity)
        
        # 基于范数的分数
        norm_scores = []
        norms = [torch.norm(u) for u in updates]
        mean_norm = sum(norms) / n
        for n_i in norms:
            score = 1 / (1 + abs(n_i - mean_norm))
            norm_scores.append(score)
        
        # 组合分数
        scores = [0.5 * c + 0.5 * n for c, n in zip(consistency_scores, norm_scores)]
        return scores
    
    def compute_avg_similarity(self, update, other_updates):
        """计算与其他更新的平均相似度"""
        similarities = []
        for other in other_updates:
            cos_sim = torch.sum(update * other) / (
                torch.norm(update) * torch.norm(other) + 1e-8
            )
            similarities.append(cos_sim)
        return sum(similarities) / len(similarities)
    
    def compute_weights(self, scores):
        """基于分数计算权重"""
        total = sum(scores)
        return [s / total for s in scores]

6.2 RFA(Robust Federated Aggregation)

RFA使用几何中位数的近似:

def rfa_aggregate(updates, iterations=10, lr=0.5):
    """
    RFA: 几何中位数近似
    """
    n = len(updates)
    
    # 初始化中心为均值
    center = sum(updates) / n
    
    # 迭代优化
    for _ in range(iterations):
        # 计算权重(到中心的距离的倒数)
        distances = [torch.norm(u - center) for u in updates]
        weights = [1 / (d + 1e-8) for d in distances]
        
        # 归一化权重
        total_weight = sum(weights)
        weights = [w / total_weight for w in weights]
        
        # 更新中心
        center = sum(w * u for w, u in zip(weights, updates))
    
    return center

7. 后门攻击与防御

7.1 后门攻击机制

后门攻击(Backdoor Attack)旨在在全局模型中植入一个隐藏的触发器

def backdoor_attack(model_update, trigger_pattern):
    """
    向模型更新中植入后门
    """
    # 在特定参数上注入后门信号
    poisoned_update = model_update.clone()
    
    # 添加触发器相关的信号
    for i, param in enumerate(poisoned_update):
        if i in trigger_pattern['params']:
            param += trigger_pattern['magnitude'] * torch.randn_like(param)
    
    return poisoned_update

7.2 后门防御方法

防御方法原理优缺点
特征压缩减少高频成分可能降低主任务性能
范数裁剪限制单次更新幅度对小幅度后门效果有限
知识蒸馏用干净数据蒸馏计算开销大
客户端验证验证客户端提交需要额外数据

7.3 ZENO防御

def zeno_defense(updates, client_data, n_byzantine):
    """
    ZENO: 基于梯度方向的防御
    """
    n = len(updates)
    
    # 评估每个更新的质量
    scores = []
    for update in updates:
        # 计算更新方向的一致性
        direction_score = compute_direction_score(update, client_data)
        scores.append(direction_score)
    
    # 选择分数最高的更新
    sorted_indices = np.argsort(scores)[::-1]
    selected = sorted_indices[:-n_byzantine]
    
    return sum(updates[i] for i in selected) / len(selected)

8. 组合隐私与鲁棒性

8.1 DP-BREM算法

同时实现差分隐私和拜占庭鲁棒性:

def dp_brem_aggregate(updates, n_byzantine, epsilon, delta):
    """
    DP-BREM: 差分隐私 + 拜占庭鲁棒
    
    来自 USENIX Security 2025
    """
    # Step 1: 拜占庭过滤
    filtered = byzantine_filter(updates, n_byzantine)
    
    # Step 2: 裁剪
    clipped = clip_updates(filtered, clip_norm)
    
    # Step 3: 添加噪声
    sigma = clip_norm * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
    noisy = add_gaussian_noise(clipped, sigma)
    
    return noisy

8.2 攻击防御权衡

目标方法权衡
仅隐私本地DP效用损失大
仅鲁棒Krum/Trimmed Mean对正常数据可能有偏差
两者兼顾DP-BREM需要更大的隐私预算

9. 参考文献


10. 相关主题