1. 引言

联邦学习通过将数据保留在本地来保护隐私,但模型更新本身仍可能泄露信息。安全聚合(Secure Aggregation)旨在确保服务器只能看到聚合后的结果,而无法获知任何单个客户端的更新内容。


2. 安全聚合的威胁模型

2.1 威胁类型

威胁攻击者能力目标
诚实但好奇(Honest-but-Curious)遵守协议但试图从消息中推断信息服务器
恶意服务器偏离协议,收集额外信息服务器
窃听者截获通信内容通信网络
共谋客户端多个客户端联合推断其他客户端信息客户端联盟

2.2 安全目标

  1. 机密性:服务器无法获知单个客户端的更新
  2. 完整性:服务器无法篡改聚合结果
  3. 正确性:聚合结果等于真实值的某种函数

3. 密码学基础

3.1 秘密分享(Secret Sharing)

Shamir秘密分享:将秘密 分成 ,满足:

  • 任意 份可以恢复
  • 任意少于 份无法获得 的任何信息

多项式构造

其中 是随机系数。

3.2 不经意传输(Oblivious Transfer)

2选1不经意传输

  • 发送者有两个消息
  • 接收者选择
  • 接收者获得 ,但发送者不知道
  • 接收者无法获得

3.3 同态加密(Homomorphic Encryption)

加法同态:给定 ,可以计算 而不解密。

常见方案:

  • Paillier加密
  • RSA加密
  • 格基加密(更高效)

4. 安全聚合协议

4.1 基本安全聚合

def secure_aggregation_basic(updates, thresholds):
    """
    基本安全聚合协议
    
    Args:
        updates: 客户端模型更新 {(client_id, update)}
        thresholds: 每个客户端的阈值
    """
    n = len(updates)  # 客户端数量
    p = 素数模数
    
    # Step 1: 生成随机掩码
    client_masks = {}
    for i, (cid, _) in enumerate(updates):
        # 客户端i生成随机掩码
        mask =随机生成(len(update_bits))
        client_masks[cid] = mask
    
    # Step 2: 交换掩码(使用秘密分享)
    received_masks = secure_mask_exchange(client_masks, n)
    
    # Step 3: 添加掩码后的更新
    masked_updates = {}
    for cid, update in updates:
        total_mask = sum(
            received_masks[other_cid][cid] 
            for other_cid in received_masks
        ) % p
        masked_updates[cid] = (update + total_mask) % p
    
    # Step 4: 服务器聚合掩码更新
    aggregated = sum(masked_updates.values()) % p
    
    # Step 5: 客户端协作揭示掩码和
    # 只有当至少threshold个客户端参与时才能揭示
    if 参与客户端数 >= 阈值:
        mask_sum = reveal_mask_sum(received_masks)
        final_aggregation = (aggregated - mask_sum) % p
    else:
        final_aggregation = aggregated  # 使用掩码聚合
    
    return final_aggregation

4.2 Bonawitz等人的安全聚合协议

Bonawitz等人(2017)提出的协议是实际系统中最常用的方案:

class BonawitzSecureAggregation:
    def __init__(self, threshold, n_clients):
        self.threshold = threshold
        self.n_clients = n_clients
        self.secret_shares = {}
    
    def setup_phase(self, client_ids):
        """
        设置阶段:建立通信拓扑
        """
        # 为每个客户端创建秘密分享
        for cid in client_ids:
            # 生成随机密钥
            key = generate_random_key()
            
            # 将密钥分成n份
            shares = shamir_split(key, self.n_clients, self.threshold)
            self.secret_shares[cid] = shares
    
    def masking_phase(self, client_id, model_update, peers):
        """
        掩码阶段:添加随机掩码
        """
        # 生成一次性掩码
        mask = torch.randn_like(model_update)
        
        # 与对等方交换掩码份额
        received_shares = []
        for peer_id in peers:
            # 使用不经意传输交换掩码
            share = oblivious_transfer(mask, peer_id)
            received_shares.append(share)
        
        # 计算总掩码
        total_mask = sum(received_shares)
        
        # 返回掩码后的更新
        masked_update = model_update + total_mask
        
        return masked_update, mask
    
    def aggregation_phase(self, masked_updates):
        """
        聚合阶段:服务器聚合
        """
        return sum(masked_updates) / len(masked_updates)
    
    def reveal_phase(self, client_ids, masks):
        """
        揭示阶段:协作揭示掩码和
        """
        # 收集有效客户端的掩码
        total_mask = sum(masks[cid] for cid in client_ids)
        
        # 服务器减去掩码和
        return total_mask

5. 高效安全聚合

5.1 通信效率优化

优化技术效果实现复杂度
预计算减少在线通信
批处理减少消息数量
压缩减少传输数据量
稀疏化只传输重要参数

5.2 预计算掩码

class PrecomputedSecureAggregation:
    def __init__(self):
        self.precomputed_masks = {}
    
    def precompute_masks(self, client_id, n_peers, mask_size):
        """
        离线阶段:预计算大量掩码
        """
        masks = []
        for _ in range(n_peers):
            # 预计算随机掩码
            mask = torch.randn(mask_size)
            masks.append(mask)
        
        self.precomputed_masks[client_id] = masks
    
    def online_mixing(self, model_update, client_id):
        """
        在线阶段:使用预计算的掩码
        """
        masks = self.precomputed_masks[client_id]
        # XOR或加法组合掩码
        combined_mask = sum(masks) % (2**32)
        return model_update + combined_mask

5.3 星型拓扑优化

def star_topology_aggregation(updates, server_id, clients):
    """
    星型拓扑:减少通信复杂度
    
    复杂度从 O(n²) 降到 O(n)
    """
    n = len(clients)
    
    # Step 1: 客户端到服务器的掩码交换
    for client in clients:
        # 客户端生成随机数
        client.random_value = generate_random()
        
        # 客户端发送到服务器
        send(client.random_value, server_id)
    
    # Step 2: 服务器计算掩码和
    mask_sum = sum(client.random_value for client in clients)
    
    # Step 3: 服务器广播掩码和
    broadcast(mask_sum, clients)
    
    # Step 4: 各客户端计算最终掩码
    for client in clients:
        final_mask = mask_sum - client.random_value
        client.masked_update = client.update + final_mask
        send(client.masked_update, server_id)
    
    # Step 5: 服务器聚合
    return sum(client.masked_update for client in clients)

6. 与差分隐私的结合

6.1 先聚合后扰动

def secure_aggregate_with_dp(updates, privacy_config):
    """
    安全聚合 + 差分隐私
    """
    # Step 1: 安全聚合得到真实和
    true_sum = secure_aggregation(updates)
    
    # Step 2: 添加DP噪声
    sensitivity = 计算敏感性(updates)
    sigma = sensitivity * np.sqrt(2 * np.log(1.25 / privacy_config['delta'])) / privacy_config['epsilon']
    noisy_sum = true_sum + torch.randn_like(true_sum) * sigma
    
    return noisy_sum

6.2 先扰动后聚合

def dp_then_secure_aggregate(updates, privacy_config):
    """
    差分隐私 + 安全聚合
    
    每个客户端先扰动自己的更新,再进行安全聚合
    """
    # Step 1: 各客户端添加本地DP噪声
    noisy_updates = []
    for update in updates:
        sensitivity = 计算敏感性(update)
        sigma = sensitivity * np.sqrt(2 * np.log(1.25 / privacy_config['delta'])) / privacy_config['epsilon']
        noisy_update = update + torch.randn_like(update) * sigma
        noisy_updates.append(noisy_update)
    
    # Step 2: 安全聚合(不需要添加额外噪声)
    return secure_aggregation(noisy_updates)

7. 实际系统实现

7.1 Google的安全聚合实现

Google在TensorFlow Federated中实现了安全聚合:

import tensorflow_federated as tff
 
@tff.federated_computation(
    tff.type_at_clients(tff.SequenceType(tf.float32)),
    tff.type_at_server(tf.float32)  # 阈值
)
def secure_mean(client_values, threshold):
    """
    安全均值计算
    """
    return tff.aggregators_secure_sum(
        client_values,
        noise_multiplier=0.1,
        expected_clients_per_round=100,
        bits_precision=16
    )

7.2 PySyft实现

import syft as sy
from syft.lib.python import Integer
 
class SecureFederatedLearning:
    def __init__(self, threshold):
        self.threshold = threshold
        self.hook = sy.TorchHook(torch)
    
    def secure_aggregate(self, workers, model_updates):
        """
        使用PySyft进行安全聚合
        """
        # Step 1: 指针化模型更新
        pointers = []
        for worker, update in zip(workers, model_updates):
            ptr = update.send(worker)
            pointers.append(ptr)
        
        # Step 2: 使用安全聚合器
        secure_sum = sy.FederatedSumMediator(
            secure=True,
            mechanism='secure_sum',
            threshold=self.threshold
        )
        
        # Step 3: 执行安全聚合
        aggregated = secure_sum.execute(pointers)
        
        return aggregated.get()

8. 安全与效率权衡

8.1 安全性等级

等级威胁模型实现要求
基础诚实但好奇加密通信
标准恶意服务器秘密分享
增强共谋攻击阈值密码学
最强量子攻击后量子密码学

8.2 效率对比

方案通信复杂度计算复杂度安全性
无安全O(n)O(n)
朴素掩码O(n²)O(n²)
星型拓扑O(n)O(n)
BonawitzO(n log n)O(n log n)
HE方案O(n)O(n log² n)

9. 参考文献


10. 相关主题