1. 引言
联邦学习通过将数据保留在本地来保护隐私,但模型更新本身仍可能泄露信息。安全聚合(Secure Aggregation)旨在确保服务器只能看到聚合后的结果,而无法获知任何单个客户端的更新内容。
2. 安全聚合的威胁模型
2.1 威胁类型
| 威胁 | 攻击者能力 | 目标 |
|---|---|---|
| 诚实但好奇(Honest-but-Curious) | 遵守协议但试图从消息中推断信息 | 服务器 |
| 恶意服务器 | 偏离协议,收集额外信息 | 服务器 |
| 窃听者 | 截获通信内容 | 通信网络 |
| 共谋客户端 | 多个客户端联合推断其他客户端信息 | 客户端联盟 |
2.2 安全目标
- 机密性:服务器无法获知单个客户端的更新
- 完整性:服务器无法篡改聚合结果
- 正确性:聚合结果等于真实值的某种函数
3. 密码学基础
3.1 秘密分享(Secret Sharing)
Shamir秘密分享:将秘密 分成 份 ,满足:
- 任意 份可以恢复
- 任意少于 份无法获得 的任何信息
多项式构造:
其中 是随机系数。
3.2 不经意传输(Oblivious Transfer)
2选1不经意传输:
- 发送者有两个消息
- 接收者选择
- 接收者获得 ,但发送者不知道
- 接收者无法获得
3.3 同态加密(Homomorphic Encryption)
加法同态:给定 ,可以计算 而不解密。
常见方案:
- Paillier加密
- RSA加密
- 格基加密(更高效)
4. 安全聚合协议
4.1 基本安全聚合
def secure_aggregation_basic(updates, thresholds):
"""
基本安全聚合协议
Args:
updates: 客户端模型更新 {(client_id, update)}
thresholds: 每个客户端的阈值
"""
n = len(updates) # 客户端数量
p = 素数模数
# Step 1: 生成随机掩码
client_masks = {}
for i, (cid, _) in enumerate(updates):
# 客户端i生成随机掩码
mask =随机生成(len(update_bits))
client_masks[cid] = mask
# Step 2: 交换掩码(使用秘密分享)
received_masks = secure_mask_exchange(client_masks, n)
# Step 3: 添加掩码后的更新
masked_updates = {}
for cid, update in updates:
total_mask = sum(
received_masks[other_cid][cid]
for other_cid in received_masks
) % p
masked_updates[cid] = (update + total_mask) % p
# Step 4: 服务器聚合掩码更新
aggregated = sum(masked_updates.values()) % p
# Step 5: 客户端协作揭示掩码和
# 只有当至少threshold个客户端参与时才能揭示
if 参与客户端数 >= 阈值:
mask_sum = reveal_mask_sum(received_masks)
final_aggregation = (aggregated - mask_sum) % p
else:
final_aggregation = aggregated # 使用掩码聚合
return final_aggregation4.2 Bonawitz等人的安全聚合协议
Bonawitz等人(2017)提出的协议是实际系统中最常用的方案:
class BonawitzSecureAggregation:
def __init__(self, threshold, n_clients):
self.threshold = threshold
self.n_clients = n_clients
self.secret_shares = {}
def setup_phase(self, client_ids):
"""
设置阶段:建立通信拓扑
"""
# 为每个客户端创建秘密分享
for cid in client_ids:
# 生成随机密钥
key = generate_random_key()
# 将密钥分成n份
shares = shamir_split(key, self.n_clients, self.threshold)
self.secret_shares[cid] = shares
def masking_phase(self, client_id, model_update, peers):
"""
掩码阶段:添加随机掩码
"""
# 生成一次性掩码
mask = torch.randn_like(model_update)
# 与对等方交换掩码份额
received_shares = []
for peer_id in peers:
# 使用不经意传输交换掩码
share = oblivious_transfer(mask, peer_id)
received_shares.append(share)
# 计算总掩码
total_mask = sum(received_shares)
# 返回掩码后的更新
masked_update = model_update + total_mask
return masked_update, mask
def aggregation_phase(self, masked_updates):
"""
聚合阶段:服务器聚合
"""
return sum(masked_updates) / len(masked_updates)
def reveal_phase(self, client_ids, masks):
"""
揭示阶段:协作揭示掩码和
"""
# 收集有效客户端的掩码
total_mask = sum(masks[cid] for cid in client_ids)
# 服务器减去掩码和
return total_mask5. 高效安全聚合
5.1 通信效率优化
| 优化技术 | 效果 | 实现复杂度 |
|---|---|---|
| 预计算 | 减少在线通信 | 高 |
| 批处理 | 减少消息数量 | 中 |
| 压缩 | 减少传输数据量 | 低 |
| 稀疏化 | 只传输重要参数 | 中 |
5.2 预计算掩码
class PrecomputedSecureAggregation:
def __init__(self):
self.precomputed_masks = {}
def precompute_masks(self, client_id, n_peers, mask_size):
"""
离线阶段:预计算大量掩码
"""
masks = []
for _ in range(n_peers):
# 预计算随机掩码
mask = torch.randn(mask_size)
masks.append(mask)
self.precomputed_masks[client_id] = masks
def online_mixing(self, model_update, client_id):
"""
在线阶段:使用预计算的掩码
"""
masks = self.precomputed_masks[client_id]
# XOR或加法组合掩码
combined_mask = sum(masks) % (2**32)
return model_update + combined_mask5.3 星型拓扑优化
def star_topology_aggregation(updates, server_id, clients):
"""
星型拓扑:减少通信复杂度
复杂度从 O(n²) 降到 O(n)
"""
n = len(clients)
# Step 1: 客户端到服务器的掩码交换
for client in clients:
# 客户端生成随机数
client.random_value = generate_random()
# 客户端发送到服务器
send(client.random_value, server_id)
# Step 2: 服务器计算掩码和
mask_sum = sum(client.random_value for client in clients)
# Step 3: 服务器广播掩码和
broadcast(mask_sum, clients)
# Step 4: 各客户端计算最终掩码
for client in clients:
final_mask = mask_sum - client.random_value
client.masked_update = client.update + final_mask
send(client.masked_update, server_id)
# Step 5: 服务器聚合
return sum(client.masked_update for client in clients)6. 与差分隐私的结合
6.1 先聚合后扰动
def secure_aggregate_with_dp(updates, privacy_config):
"""
安全聚合 + 差分隐私
"""
# Step 1: 安全聚合得到真实和
true_sum = secure_aggregation(updates)
# Step 2: 添加DP噪声
sensitivity = 计算敏感性(updates)
sigma = sensitivity * np.sqrt(2 * np.log(1.25 / privacy_config['delta'])) / privacy_config['epsilon']
noisy_sum = true_sum + torch.randn_like(true_sum) * sigma
return noisy_sum6.2 先扰动后聚合
def dp_then_secure_aggregate(updates, privacy_config):
"""
差分隐私 + 安全聚合
每个客户端先扰动自己的更新,再进行安全聚合
"""
# Step 1: 各客户端添加本地DP噪声
noisy_updates = []
for update in updates:
sensitivity = 计算敏感性(update)
sigma = sensitivity * np.sqrt(2 * np.log(1.25 / privacy_config['delta'])) / privacy_config['epsilon']
noisy_update = update + torch.randn_like(update) * sigma
noisy_updates.append(noisy_update)
# Step 2: 安全聚合(不需要添加额外噪声)
return secure_aggregation(noisy_updates)7. 实际系统实现
7.1 Google的安全聚合实现
Google在TensorFlow Federated中实现了安全聚合:
import tensorflow_federated as tff
@tff.federated_computation(
tff.type_at_clients(tff.SequenceType(tf.float32)),
tff.type_at_server(tf.float32) # 阈值
)
def secure_mean(client_values, threshold):
"""
安全均值计算
"""
return tff.aggregators_secure_sum(
client_values,
noise_multiplier=0.1,
expected_clients_per_round=100,
bits_precision=16
)7.2 PySyft实现
import syft as sy
from syft.lib.python import Integer
class SecureFederatedLearning:
def __init__(self, threshold):
self.threshold = threshold
self.hook = sy.TorchHook(torch)
def secure_aggregate(self, workers, model_updates):
"""
使用PySyft进行安全聚合
"""
# Step 1: 指针化模型更新
pointers = []
for worker, update in zip(workers, model_updates):
ptr = update.send(worker)
pointers.append(ptr)
# Step 2: 使用安全聚合器
secure_sum = sy.FederatedSumMediator(
secure=True,
mechanism='secure_sum',
threshold=self.threshold
)
# Step 3: 执行安全聚合
aggregated = secure_sum.execute(pointers)
return aggregated.get()8. 安全与效率权衡
8.1 安全性等级
| 等级 | 威胁模型 | 实现要求 |
|---|---|---|
| 基础 | 诚实但好奇 | 加密通信 |
| 标准 | 恶意服务器 | 秘密分享 |
| 增强 | 共谋攻击 | 阈值密码学 |
| 最强 | 量子攻击 | 后量子密码学 |
8.2 效率对比
| 方案 | 通信复杂度 | 计算复杂度 | 安全性 |
|---|---|---|---|
| 无安全 | O(n) | O(n) | 无 |
| 朴素掩码 | O(n²) | O(n²) | 低 |
| 星型拓扑 | O(n) | O(n) | 中 |
| Bonawitz | O(n log n) | O(n log n) | 高 |
| HE方案 | O(n) | O(n log² n) | 高 |
9. 参考文献
10. 相关主题
- federated-learning-fundamentals — 联邦学习基础
- federated-learning-privacy-dp — 差分隐私保护
- byzantine-attacks-defense — 拜占庭攻击防御