1. 引言
在真实的联邦学习系统中,参与的客户端可能存在恶意行为。拜占庭攻击(Byzantine Attack)是最具破坏性的攻击之一,恶意客户端可以上传任意错误的模型更新,从而破坏全局模型的学习。本节介绍拜占庭攻击的机制和防御策略。
2. 拜占庭攻击模型
2.1 攻击者能力假设
拜占庭攻击者(Byzantine Attacker)可以:
- 完全控制部分客户端
- 上传任意-crafted模型更新
- 协调多个恶意客户端的攻击策略
2.2 攻击类型
| 攻击类型 | 描述 | 目标 |
|---|---|---|
| 标签翻转攻击 | 修改数据标签为错误类别 | 使模型产生错误预测 |
| 梯度反转攻击 | 上传负梯度 | 破坏模型学习 |
| 高斯攻击 | 上传随机高斯噪声 | 降低模型性能 |
| 定向攻击 | 精心-crafted梯度 | 针对特定输出 |
| 后门攻击 | 植入隐藏的后门 | 在特定输入上触发恶意行为 |
2.3 攻击数学模型
恶意客户端 上传的更新为:
攻击目标:使得聚合后的模型 偏离真实最优:
3. 标准聚合的风险
3.1 FedAvg的脆弱性
FedAvg的聚合规则:
当存在恶意客户端时:
- 恶意更新会直接污染全局模型
- 即使只有1个恶意客户端(),也可能导致10%的误差累积
3.2 攻击效果示例
def byzantine_attack_simulation():
"""
模拟拜占庭攻击效果
"""
# 假设有10个客户端,1个恶意
honest_updates = [torch.randn(100) for _ in range(9)]
malicious_update = torch.zeros(100) # 上传零梯度
# 正常聚合
normal_aggregate = sum(honest_updates) / 10
# 被攻击的聚合
attacked_aggregate = (sum(honest_updates) + malicious_update) / 10
# 恶意客户端上传相反方向的梯度
malicious_update = -sum(honest_updates) / 9 # 完全抵消诚实更新
worst_aggregate = (sum(honest_updates) + malicious_update) / 10
return normal_aggregate, attacked_aggregate, worst_aggregate4. 拜占庭鲁棒聚合算法
4.1 Krum聚合
核心思想:选择最接近其他更新的更新作为代表。
def krum_aggregate(updates, n_byzantine):
"""
Krum聚合算法
Args:
updates: 客户端更新列表
n_byzantine: 拜占庭客户端数量
"""
n = len(updates)
m = n - n_byzantine - 2 # 选择前m个
# 计算距离矩阵
distances = torch.zeros(n, n)
for i in range(n):
for j in range(i+1, n):
d = torch.norm(updates[i] - updates[j])
distances[i, j] = d
distances[j, i] = d
# 计算每个更新的分数(到其他更新的距离之和)
scores = []
for i in range(n):
# 排除自己,选择最近的n-m个
sorted_distances = sorted(distances[i])
score = sum(sorted_distances[1:1+m]) # 排除自己
scores.append(score)
# 选择分数最小的更新
selected_idx = np.argmin(scores)
return updates[selected_idx]4.2 Trimmed Mean聚合
核心思想:去掉最大和最小的更新,然后求平均。
def trimmed_mean_aggregate(updates, n_byzantine):
"""
Trimmed Mean聚合算法
每去掉 β = n_byzantine 个最大和最小的更新
"""
n = len(updates)
beta = n_byzantine
# 对每个参数维度分别处理
aggregated = []
for dim in range(updates[0].shape[0]):
# 收集所有客户端在该维度的值
values = [u[dim] for u in updates]
# 排序
sorted_values = sorted(values)
# 去掉beta个最大和最小的
trimmed = sorted_values[beta: n-beta]
# 求平均
aggregated.append(sum(trimmed) / len(trimmed))
return torch.stack(aggregated)4.3 Median聚合
def median_aggregate(updates):
"""
中位数聚合
"""
n_dims = updates[0].shape[0]
aggregated = []
for dim in range(n_dims):
values = [u[dim] for u in updates]
aggregated.append(torch.median(torch.stack(values)))
return torch.stack(aggregated)4.4 分布式鲁棒优化(DRACO)
DRACO使用Flaml距离来检测异常:
def draco_aggregate(updates, n_byzantine):
"""
DRACO: 基于分布式鲁棒优化的聚合
"""
n = len(updates)
# 计算中心
center = sum(updates) / n
# 计算各更新到中心的距离
distances = [torch.norm(u - center) for u in updates]
# 使用Huber损失识别异常值
threshold = np.percentile(distances, 75) # 75th percentile
# 只聚合距离在阈值内的更新
robust_updates = [
u for u, d in zip(updates, distances)
if d <= threshold
]
if len(robust_updates) > 0:
return sum(robust_updates) / len(robust_updates)
else:
return center5. 基于统计的防御方法
5.1 检测异常的统计量
| 统计量 | 计算方法 | 异常检测规则 |
|---|---|---|
| 范数 | 偏离均值超过阈值 | |
| 方向 | 方向不一致 | |
| 方差 | 异常大 | |
| 余弦相似度 | 与多数更新负相关 |
5.2 范数检测
def norm_based_filter(updates, n_byzantine, sigma=3):
"""
基于范数的异常检测
"""
# 计算更新范数
norms = [torch.norm(u) for u in updates]
mean_norm = sum(norms) / len(norms)
std_norm = np.std(norms)
# 计算阈值
threshold = mean_norm + sigma * std_norm
# 过滤
filtered = [
u for u, n in zip(updates, norms)
if n <= threshold
]
return filtered if len(filtered) > n_byzantine else updates5.3 余弦相似度检测
def cosine_similarity_filter(updates, n_byzantine):
"""
基于余弦相似度的异常检测
"""
n = len(updates)
# 计算平均方向
mean_update = sum(updates) / n
# 计算每个更新与平均方向的余弦相似度
similarities = []
for u in updates:
cos_sim = torch.sum(u * mean_update) / (
torch.norm(u) * torch.norm(mean_update) + 1e-8
)
similarities.append(cos_sim)
# 选择相似度最高的 (n - n_byzantine) 个更新
sorted_indices = np.argsort(similarities)[::-1] # 降序
selected_indices = sorted_indices[:-n_byzantine]
selected_updates = [updates[i] for i in selected_indices]
return selected_updates6. 基于学习的防御方法
6.1 自适应聚合
class AdaptiveAggregator:
def __init__(self, n_byzantine):
self.n_byzantine = n_byzantine
self.history = []
def aggregate(self, updates):
"""自适应聚合"""
# Step 1: 检测异常
scores = self.compute_reliability_scores(updates)
# Step 2: 分配权重
weights = self.compute_weights(scores)
# Step 3: 加权聚合
aggregated = sum(w * u for w, u in zip(weights, updates))
return aggregated
def compute_reliability_scores(self, updates):
"""计算可靠性分数"""
n = len(updates)
# 基于一致性的分数
consistency_scores = []
for i in range(n):
other_updates = [updates[j] for j in range(n) if j != i]
similarity = self.compute_avg_similarity(updates[i], other_updates)
consistency_scores.append(similarity)
# 基于范数的分数
norm_scores = []
norms = [torch.norm(u) for u in updates]
mean_norm = sum(norms) / n
for n_i in norms:
score = 1 / (1 + abs(n_i - mean_norm))
norm_scores.append(score)
# 组合分数
scores = [0.5 * c + 0.5 * n for c, n in zip(consistency_scores, norm_scores)]
return scores
def compute_avg_similarity(self, update, other_updates):
"""计算与其他更新的平均相似度"""
similarities = []
for other in other_updates:
cos_sim = torch.sum(update * other) / (
torch.norm(update) * torch.norm(other) + 1e-8
)
similarities.append(cos_sim)
return sum(similarities) / len(similarities)
def compute_weights(self, scores):
"""基于分数计算权重"""
total = sum(scores)
return [s / total for s in scores]6.2 RFA(Robust Federated Aggregation)
RFA使用几何中位数的近似:
def rfa_aggregate(updates, iterations=10, lr=0.5):
"""
RFA: 几何中位数近似
"""
n = len(updates)
# 初始化中心为均值
center = sum(updates) / n
# 迭代优化
for _ in range(iterations):
# 计算权重(到中心的距离的倒数)
distances = [torch.norm(u - center) for u in updates]
weights = [1 / (d + 1e-8) for d in distances]
# 归一化权重
total_weight = sum(weights)
weights = [w / total_weight for w in weights]
# 更新中心
center = sum(w * u for w, u in zip(weights, updates))
return center7. 后门攻击与防御
7.1 后门攻击机制
后门攻击(Backdoor Attack)旨在在全局模型中植入一个隐藏的触发器:
def backdoor_attack(model_update, trigger_pattern):
"""
向模型更新中植入后门
"""
# 在特定参数上注入后门信号
poisoned_update = model_update.clone()
# 添加触发器相关的信号
for i, param in enumerate(poisoned_update):
if i in trigger_pattern['params']:
param += trigger_pattern['magnitude'] * torch.randn_like(param)
return poisoned_update7.2 后门防御方法
| 防御方法 | 原理 | 优缺点 |
|---|---|---|
| 特征压缩 | 减少高频成分 | 可能降低主任务性能 |
| 范数裁剪 | 限制单次更新幅度 | 对小幅度后门效果有限 |
| 知识蒸馏 | 用干净数据蒸馏 | 计算开销大 |
| 客户端验证 | 验证客户端提交 | 需要额外数据 |
7.3 ZENO防御
def zeno_defense(updates, client_data, n_byzantine):
"""
ZENO: 基于梯度方向的防御
"""
n = len(updates)
# 评估每个更新的质量
scores = []
for update in updates:
# 计算更新方向的一致性
direction_score = compute_direction_score(update, client_data)
scores.append(direction_score)
# 选择分数最高的更新
sorted_indices = np.argsort(scores)[::-1]
selected = sorted_indices[:-n_byzantine]
return sum(updates[i] for i in selected) / len(selected)8. 组合隐私与鲁棒性
8.1 DP-BREM算法
同时实现差分隐私和拜占庭鲁棒性:
def dp_brem_aggregate(updates, n_byzantine, epsilon, delta):
"""
DP-BREM: 差分隐私 + 拜占庭鲁棒
来自 USENIX Security 2025
"""
# Step 1: 拜占庭过滤
filtered = byzantine_filter(updates, n_byzantine)
# Step 2: 裁剪
clipped = clip_updates(filtered, clip_norm)
# Step 3: 添加噪声
sigma = clip_norm * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
noisy = add_gaussian_noise(clipped, sigma)
return noisy8.2 攻击防御权衡
| 目标 | 方法 | 权衡 |
|---|---|---|
| 仅隐私 | 本地DP | 效用损失大 |
| 仅鲁棒 | Krum/Trimmed Mean | 对正常数据可能有偏差 |
| 两者兼顾 | DP-BREM | 需要更大的隐私预算 |
9. 参考文献
10. 相关主题
- federated-learning-fundamentals — 联邦学习基础
- federated-learning-privacy-dp — 差分隐私保护
- secure-aggregation — 安全聚合协议