联邦学习拜占庭攻击与防御

拜占庭攻击是联邦学习面临的最严重安全威胁之一。恶意客户端可以发送伪造的模型更新,干扰全局模型的训练过程,甚至导致模型完全失效。本章深入探讨拜占庭攻击的类型、防御机制和最新研究成果。

1. 拜占庭攻击概述

1.1 威胁模型定义

拜占庭将军问题是分布式系统中的经典问题:在存在恶意节点的情况下,如何达成共识。联邦学习中的拜占庭攻击指恶意客户端故意发送错误的模型更新来破坏学习过程。

为客户端集合, 为恶意客户端集合。恶意客户端可以:

  • 发送任意伪造的梯度 而非真实梯度
  • 协调多个恶意客户端发送一致的错误更新
  • 窃听或篡改诚实客户端的通信

1.2 攻击分类

数据投毒攻击(Data Poisoning)

恶意客户端在本地训练数据中注入错误标签或构造对抗样本:

# 数据投毒示意
def data_poisoning_attack(dataset, poison_ratio=0.1, target_label=0):
    poisoned_data = []
    for x, y in dataset:
        if random.random() < poison_ratio:
            # 将样本标签翻转为目标标签
            poisoned_data.append((x, target_label))
        else:
            poisoned_data.append((x, y))
    return poisoned_data

影响:降低模型在特定类别上的准确率

模型投毒攻击(Model Poisoning)

恶意客户端直接伪造模型参数,而非修改训练数据:

# 模型投毒示意
def model_poisoning_attack(global_model, epsilon=10.0):
    """发送放大的恶意梯度"""
    poisoned_update = -epsilon * torch.sign(global_model.weight)
    return poisoned_update

影响:可以完全破坏模型收敛,甚至使模型在所有数据上失效

1.3 经典攻击方法

攻击方法描述威胁等级
Label Flipping将所有标签翻转为指定类别
Sign Flipping反转梯度符号
Gradient Boosting放大梯度幅值
Alternating Minimization交替优化训练损失和攻击目标极高

2. 经典鲁棒聚合算法

2.1 FedAvg的脆弱性

标准FedAvg对拜占庭攻击极为脆弱:

即使只有一个恶意客户端,也可以通过发送极大或极小的梯度来主导聚合结果。

2.2 Krum与Multi-Krum

Krum算法1通过选择与其它更新最接近的梯度:

def krum_update(gradients, n_attackers):
    """
    Krum聚合算法
    gradients: 客户端梯度列表
    n_attackers: 估计的恶意客户端数量
    """
    n = len(gradients)
    # 计算两两之间的欧氏距离
    distances = torch.cdist(torch.stack(gradients), torch.stack(gradients))
    
    # 对每个客户端,计算到最近n-f-2个客户端的距离和
    scores = []
    for i in range(n):
        sorted_dist = torch.sort(distances[i])[0]
        score = sorted_dist[:n - n_attackers - 2].sum()
        scores.append(score)
    
    # 选择得分最低的客户端
    selected_idx = torch.argmin(torch.tensor(scores))
    return gradients[selected_idx]

Multi-Krum选择多个”最诚实”的梯度进行平均:

其中 是选中的诚实客户端集合。

局限性

  • 计算复杂度
  • 只在恶意客户端比例 时有效
  • 对协调攻击(colluding attacks)脆弱

2.3 Trimmed Mean

Trimmed Mean对每个坐标分别裁剪极端值:

def trimmed_mean(gradients, trim_ratio=0.1):
    """
    Trimmed Mean聚合
    trim_ratio: 裁剪比例(如0.1表示每端裁剪10%)
    """
    n = len(gradients)
    d = len(gradients[0])
    trim_count = int(n * trim_ratio)
    
    aggregated = []
    for dim in range(d):
        # 收集该维度的所有值
        dim_values = [g[dim] for g in gradients]
        # 排序并裁剪
        sorted_values = sorted(dim_values)
        trimmed_values = sorted_values[trim_count:-trim_count]
        # 计算均值
        aggregated.append(sum(trimmed_values) / len(trimmed_values))
    
    return torch.tensor(aggregated)

优点:计算效率高,每个坐标独立处理
缺点:在非IID数据下可能失效

2.4 Coordinate-wise Median

对每个坐标单独取中位数:

时,坐标中位数能够抵御拜占庭攻击。

2.5 Geometric Median (RFA)

Geometric Median是最小化到所有点距离和的点:

def geometric_median(gradients, max_iter=100, tol=1e-6):
    """
    几何中位数聚合(Weiszfeld算法)
    """
    # 初始化为均值
    w = torch.mean(torch.stack(gradients), dim=0)
    
    for _ in range(max_iter):
        # 计算权重
        distances = torch.stack([torch.norm(w - g) for g in gradients])
        distances = torch.clamp(distances, min=1e-10)
        weights = 1.0 / distances
        
        # 更新w
        w_new = sum(w_i * d_i for w_i, d_i in zip(gradients, weights))
        w_new = w_new / weights.sum()
        
        if torch.norm(w_new - w) < tol:
            break
        w = w_new
    
    return w

3. 新型拜占庭鲁棒方法

3.1 FLTG:角度防御与非IID感知加权

FLTG2提出基于角度的防御机制:

def fltg_aggregation(global_model, gradients, server_dataset, 
                      clip_threshold=0.1, reference_weight=0.3):
    """
    FLTG: Byzantine-Robust Federated Learning via 
    Angle-Based Defense and Non-IID-Aware Weighting
    """
    # 步骤1:基于余弦相似度过滤
    global_vec = global_model.weight.flatten()
    valid_gradients = []
    
    for g in gradients:
        cos_sim = torch.nn.functional.cosine_similarity(
            g.flatten(), global_vec, dim=0
        )
        # ReLU裁剪的余弦相似度
        clipped_sim = torch.clamp(cos_sim, min=0)
        if clipped_sim > clip_threshold:
            valid_gradients.append(g)
    
    if len(valid_gradients) == 0:
        return torch.mean(torch.stack(gradients), dim=0)
    
    # 步骤2:动态参考选择
    # 选择与全局模型最接近的客户端作为参考
    similarities = [torch.norm(g - global_vec) for g in valid_gradients]
    reference_idx = torch.argmin(torch.tensor(similarities))
    reference = valid_gradients[reference_idx]
    
    # 步骤3:基于角度偏差加权
    weights = []
    for g in valid_gradients:
        angle = torch.acos(torch.clamp(
            torch.nn.functional.cosine_similarity(g, reference, dim=0),
            min=-1, max=1
        ))
        # 权重与角度偏差成反比
        weight = 1.0 / (1.0 + angle)
        weights.append(weight)
    
    weights = torch.tensor(weights) / sum(weights)
    
    # 步骤4:幅值归一化
    aggregated = sum(w * g for w, g in zip(weights, valid_gradients))
    aggregated = aggregated / torch.norm(aggregated) * torch.norm(reference)
    
    return aggregated

核心思想

  • 利用服务器端干净数据集排除错位的更新
  • 动态选择参考客户端以缓解非IID偏差
  • 幅值归一化抑制恶意缩放

3.2 FedGuard:成员推断增强防御

FedGuard3利用成员推断来检测投毒模型:

def fedguard_aggregation(gradients, server_minibatch, threshold=0.5):
    """
    FedGuard: 利用服务器指定数据的小批量检测投毒模型
    """
    client_scores = []
    
    for g in gradients:
        # 评估客户端模型在服务器小批量上的置信度
        model = reconstruct_model(g)
        confidence = evaluate_confidence(model, server_minibatch)
        
        # 投毒模型的置信度会显著下降
        if confidence < threshold:
            client_scores.append(0.0)  # 排除投毒客户端
        else:
            client_scores.append(confidence)
    
    # 加权聚合
    weights = torch.tensor(client_scores)
    weights = weights / weights.sum()
    aggregated = sum(w * g for w, g in zip(weights, gradients))
    
    return aggregated

关键发现:投毒模型在服务器指定数据上的置信度会显著下降

3.3 ProDiGy:双重评分系统

ProDiGy4提出邻近性和差异性双重评分:

  • 邻近性:评估客户端梯度与多数诚实客户端的相似程度
  • 差异性:检测可疑的一致性(攻击者可能发送相同或相似的投毒更新)
def prodigy_aggregation(gradients, alpha=0.5, beta=0.5, epsilon=0.1):
    """
    ProDiGy: Proximity- and Dissimilarity-Based 
    Byzantine-Robust Federated Learning
    """
    n = len(gradients)
    
    # 计算成对距离矩阵
    dist_matrix = torch.cdist(
        torch.stack([g.flatten() for g in gradients]),
        torch.stack([g.flatten() for g in gradients])
    )
    
    scores = []
    for i in range(n):
        # 邻近性得分:与其它客户端的平均距离
        proximity = dist_matrix[i].sum() / (n - 1)
        
        # 差异性得分:与其它客户端的差异程度
        # 攻击者可能发送相似的投毒更新
        similarity_to_others = []
        for j in range(n):
            if i != j:
                similarity = torch.exp(-dist_matrix[i, j])
                similarity_to_others.append(similarity)
        
        # 低的差异性意味着可能是协调攻击
        dissimilarity = 1 - torch.mean(torch.tensor(similarity_to_others))
        
        score = alpha * (1 / (proximity + epsilon)) + beta * dissimilarity
        scores.append(score)
    
    # 选择得分最高的客户端子集
    top_k = n - len([g for g in gradients if torch.norm(g) < epsilon])
    selected_indices = torch.argsort(torch.tensor(scores))[-top_k:]
    
    selected_gradients = [gradients[i] for i in selected_indices]
    return torch.mean(torch.stack(selected_gradients), dim=0)

3.4 FL-CLEANER:激活图误差聚类

FL-CLEANER5通过客户端激活图的重建误差检测攻击者:

def fl_cleaner_aggregation(gradients, trigger_set, cvae_model):
    """
    FL-CLEANER: 基于激活图重建误差的防御
    """
    client_scores = []
    
    for g in gradients:
        # 重构客户端模型的激活图
        model = reconstruct_model(g)
        reconstructed_activations = get_activations(model, trigger_set)
        
        # 使用CVAE计算重建误差
        with torch.no_grad():
            reconstructed = cvae_model(reconstructed_activations)
            error = torch.nn.functional.mse_loss(
                reconstructed, reconstructed_activations
            )
        
        client_scores.append(-error)  # 误差越小分数越高
    
    # 使用信任传播算法构建良性客户端聚类
    trust_scores = trust_propagation(client_scores, gradients)
    
    # 选择高信任分数的客户端
    valid_clients = [i for i, s in enumerate(trust_scores) if s > 0]
    
    return torch.mean(
        torch.stack([gradients[i] for i in valid_clients]), dim=0
    )

4. 分布式优化方法

4.1 PDMM:主从对偶方法

PDMM(Primal-Dual Method of Multipliers)通过共识机制天然抵御拜占庭攻击6

def pdmm_federated_update(local_grad, neighbor_grads, 
                          consensus_weight=1.0, byzantine_mask=None):
    """
    分布式优化:利用共识机制抵御拜占庭攻击
    """
    # 只与邻居(可能是诚实节点)通信
    valid_neighbors = [
        g for i, g in enumerate(neighbor_grads) 
        if byzantine_mask is None or byzantine_mask[i]
    ]
    
    if len(valid_neighbors) == 0:
        return local_grad
    
    # 共识项:推动与邻居达成一致
    consensus_term = sum(
        valid_neighbors[i] - local_grad 
        for i in range(len(valid_neighbors))
    ) / len(valid_neighbors)
    
    # 梯度更新 + 共识项
    updated_grad = local_grad + consensus_weight * consensus_term
    
    return updated_grad

优势

  • 不依赖中心化的聚合
  • 共识机制天然过滤异常值
  • 在恶意节点比例较高时仍能收敛

5. 认证鲁棒性理论

5.1 认证半径

认证鲁棒性保证在给定认证半径 内的扰动无法改变模型输出:

5.2 认证半径的估计

对于不同的聚合方法:

聚合方法认证半径
Krum
Trimmed Mean
Geometric Median
FLTrust依赖于参考数据集

5.3 WPCRA:全流程认证方法

WPCRA7提出ex-ante、ex-durante、ex-post三阶段认证:

def wpcr_aggregation(gradients, server_model, client_weights):
    """
    Whole-Process Certifiably Robust Aggregation
    """
    # Ex-ante:过滤明显异常的梯度
    ex_ante_grads = filter_by_magnitude(gradients, threshold=3.0)
    
    # Ex-durante:基于层级别的鲁棒性加权
    layer_weights = compute_layer_robustness(
        ex_ante_grads, server_model
    )
    ex_durante_grads = weighted_average(
        ex_ante_grads, layer_weights * client_weights
    )
    
    # Ex-post:最终安全检查
    final_grad = post_verification(ex_durante_grads, server_model)
    
    return final_grad

6. 非IID数据下的拜占庭鲁棒性

6.1 挑战

在非IID数据下,诚实客户端的梯度本身就可能差异很大,使得区分恶意和诚实更新变得困难。

6.2 BOBA方法

BOBA8专门处理标签偏斜下的拜占庭攻击:

def boba_aggregation(gradients, label_distribution, 
                      client_weights, f):
    """
    BOBA: Byzantine-Robust FL with Label Skewness
    两阶段方法
    """
    # 阶段1:基于标签分布调整权重
    adjusted_weights = []
    for i, g in enumerate(gradients):
        # 估计客户端的标签分布
        label_sim = estimate_label_similarity(
            label_distribution[i], 
            label_distribution[0]  # 参考分布
        )
        adjusted_weight = client_weights[i] * label_sim
        adjusted_weights.append(adjusted_weight)
    
    # 阶段2:过滤离群值
    weighted_grads = [
        adjusted_weights[i] * gradients[i] 
        for i in range(len(gradients))
    ]
    
    # 使用几何中位数估计
    geo_median = geometric_median(weighted_grads)
    
    # 计算到几何中位数的距离
    distances = [torch.norm(g - geo_median) for g in weighted_grads]
    
    # 选择距离最小的 n-f 个客户端
    sorted_indices = torch.argsort(torch.tensor(distances))
    selected_indices = sorted_indices[:len(gradients) - f]
    
    return torch.mean(
        torch.stack([gradients[i] for i in selected_indices]), dim=0
    )

7. 实践建议

7.1 防御策略选择

场景推荐方法
IID数据,少量恶意客户端Krum, Trimmed Mean
非IID数据FLTG, BOBA, ProDiGy
大量恶意客户端(>50%)FedGuard, PDMM
需要认证保证WPCRA

7.2 实现注意事项

# 实际部署建议
class ByzantineRobustAggregator:
    def __init__(self, method='auto', f_ratio=0.3):
        self.method = method
        self.f_ratio = f_ratio
        
        # 自动选择最佳方法
        if method == 'auto':
            self.aggregator = self._select_best_method()
    
    def _select_best_method(self):
        # 根据客户端数量、恶意比例等选择
        if self.f_ratio > 0.5:
            return FedGuardAggregator()
        elif self.f_ratio > 0.3:
            return ProDiGyAggregator()
        else:
            return TrimmedMeanAggregator()

7.3 性能基准

方法恶意比例CIFAR-10准确率通信复杂度
FedAvg0%85.2%O(nd)
FedAvg30%42.1%O(nd)
Krum30%78.3%O(n²d)
Trimmed Mean30%79.5%O(nd)
FLTG50%82.1%O(nd)
ProDiGy50%81.8%O(n²d)

8. 总结

拜占庭鲁棒性是联邦学习安全性的核心挑战。本章介绍了:

  1. 攻击类型:数据投毒和模型投毒
  2. 经典防御:Krum、Trimmed Mean、Geometric Median
  3. 新型方法:FLTG、FedGuard、ProDiGy、FL-CLEANER
  4. 分布式优化:PDMM共识机制
  5. 认证理论:WPCRA全流程认证
  6. 非IID挑战:BOBA等专用方法

未来方向包括:

  • 更高效的鲁棒聚合算法
  • 结合差分隐私的防御
  • 去中心化拜占庭鲁棒性

参考资料


相关主题federated-learning-fundamentalsfederated-learning-non-iid-strategiesfederated-learning-privacy-attacksadversarial-robustness-fundamentals

Footnotes

  1. Blanchard et al. “Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent” (NIPS 2017)

  2. FLTG: “Byzantine-Robust Federated Learning via Angle-Based Defense” (arXiv:2505.12851)

  3. FedGuard: “A Diverse-Byzantine-Robust Mechanism for Federated Learning with Major Malicious Clients” (arXiv:2508.00636)

  4. ProDiGy: “Proximity- and Dissimilarity-Based Byzantine-Robust Federated Learning” (arXiv:2509.09534)

  5. FL-CLEANER: “Byzantine and Backdoor Defense by Clustering Errors of Activation Maps in Non-IID FL” (arXiv:2501.12123)

  6. PDMM: “Byzantine-Resilient Federated Learning via Distributed Optimization” (arXiv:2503.10792)

  7. WPCRA: “A Whole-Process Certifiably Robust Aggregation Method Against Backdoor Attacks” (arXiv:2407.00719)

  8. BOBA: “Byzantine-Robust Federated Learning with Label Skewness” (ICML 2024)