1. 定义与背景

1.1 联邦学习的起源

联邦学习(Federated Learning, FL)由McMahan等人在2017年提出,旨在解决数据孤岛隐私保护问题。其核心思想是:

“将模型送到数据所在的地方,而非将数据汇聚到模型所在的地方”

传统机器学习需要将分散在各地的数据收集到中央服务器进行训练,这面临:

  • 隐私泄露风险:敏感数据在传输和存储过程中可能被攻击
  • 通信带宽瓶颈:大规模数据的传输成本高昂
  • 数据主权问题:各地数据法规限制数据跨境流动

1.2 联邦学习的形式化定义

设共有 个客户端参与训练,每个客户端 拥有本地数据集 ,目标是学习全局模型参数 ,使得全局目标函数最小化:

其中 是客户端 的权重, 是客户端 的本地经验风险。

1.3 联邦学习 vs 传统分布式学习

特性传统分布式学习联邦学习
数据分布数据可集中存储数据始终留在本地
节点可信度通常可信可能存在恶意节点
通信模式高带宽、低延迟低带宽、高延迟
数据异构性通常IID通常非IID
节点规模数个-数十个数百-数百万
系统异构性环境一致设备差异大

2. 联邦学习的核心挑战

2.1 统计异构性(Statistical Heterogeneity)

非独立同分布(Non-IID)数据是联邦学习面临的核心挑战。设 为客户端 的数据分布,理想情况下 对所有 成立。但在实际场景中:

  • 特征分布偏移(Feature Distribution Skew):不同客户端的特征分布不同
  • 标签分布偏移(Label Distribution Skew):不同客户端的标签分布不同
  • 数量偏移(Quantity Skew):不同客户端的数据量差异巨大
  • 同步偏移(Temporal Skew):数据随时间变化

非IID数据会导致:

  1. 客户端漂移(Client Drift):本地模型偏离全局最优
  2. 收敛困难:全局收敛速度下降甚至发散
  3. 性能下降:最终模型在某些客户端上表现不佳

2.2 系统异构性(System Heterogeneity)

异构类型具体表现影响
计算能力不同设备的CPU/GPU性能差异本地训练时间不一致
存储能力可用内存和存储空间不同限制模型规模
通信带宽网络条件差异巨大通信成为瓶颈
连接稳定性设备可能随时离线部分设备无法参与
电池状态移动设备电池有限影响参与意愿

2.3 隐私与安全挑战

2.3.1 隐私泄露风险

即使只上传模型参数,攻击者仍可通过以下方式获取隐私信息:

  1. 梯度反演攻击(Gradient Inversion Attack)

    • 通过上传的梯度重建原始训练数据
    • Zhu等人在2019年证明可以100%恢复原始图像
  2. 成员推断攻击(Membership Inference Attack)

    • 判断某样本是否参与训练
    • 利用模型对训练数据的”记忆”
  3. 模型反演攻击(Model Inversion Attack)

    • 从模型参数恢复敏感训练数据

2.3.2 恶意攻击

  1. 拜占庭攻击(Byzantine Attack)

    • 恶意客户端上传任意错误的模型更新
    • 可导致全局模型完全失效
  2. 后门攻击(Backdoor Attack)

    • 在模型中植入隐藏的后门
    • 在特定输入下触发恶意行为
  3. 模型投毒(Model Poisoning)

    • 通过修改本地模型影响全局模型
    • 可与拜占庭攻击结合

2.4 通信效率挑战

联邦学习中通信是主要瓶颈:

  • 模型参数量可达数十亿
  • 通信轮次可能需要数千次
  • 移动设备带宽有限且不稳定

3. 联邦学习的分类体系

3.1 按数据分区方式

联邦学习
├── 水平联邦学习 (Horizontal FL)
│   ├── 各方拥有相同特征,不同样本
│   └── 场景:跨银行、跨医院的联合建模
│
├── 垂直联邦学习 (Vertical FL)
│   ├── 各方拥有相同样本,不同特征
│   └── 场景:银行+电商联合建模
│
└── 联邦迁移学习 (Federated Transfer Learning)
    ├── 各方样本和特征都部分重叠
    └── 场景:跨领域知识迁移

3.2 按网络拓扑结构

拓扑类型特点典型场景
中心化FL中心服务器协调传统FL场景
去中心化FL点对点通信分布式网络
分层FL多级聚合大规模部署
星型拓扑中心-边缘结构物联网场景

3.3 按参与方式

参与方式描述优缺点
全量参与所有客户端每轮都参与收敛稳定,但通信开销大
随机抽样每轮随机选择部分客户端通信高效,但方差大
分层抽样按能力分层,层内随机平衡效率与质量
基于重要性选择”重要”客户端可加速收敛,但需先验知识

4. 联邦学习的基本流程

4.1 标准FedAvg算法流程

def FedAvg(K, T, η):
    """
    K: 客户端数量
    T: 总通信轮次
    η: 学习率
    """
    w_0 = 初始化全局模型
    
    for t in range(T):
        # 1. 服务器选择客户端子集 S_t
        S_t = select_clients(K, fraction=C)
        
        # 2. 服务器向选中客户端分发当前模型
        for k in S_t:
            send(w_t to client k)
        
        # 3. 各客户端执行本地训练
        for k in S_t:
            w_{t+1}^k = 本地训练(w_t, D_k, η)
        
        # 4. 客户端上传本地更新
        for k in S_t:
            Δ_k = w_{t+1}^k - w_t
            upload(Δ_k to server)
        
        # 5. 服务器聚合更新
        w_{t+1} = w_t + Σ_{k∈S_t} (n_k/n) * Δ_k
    
    return w_T

4.2 单轮通信的详细过程

┌─────────────────────────────────────────────────────────────┐
│                     服务器端                                 │
│  ┌──────────┐    分发模型    ┌──────────┐                 │
│  │ w_t      │ ────────────▶ │ Client 1 │                 │
│  └──────────┘               └──────────┘                 │
│       ▲                          │                        │
│       │                          │ 本地训练                │
│       │                          ▼                        │
│       │                     ┌──────────┐                 │
│       │                     │ w_t+1^1  │                 │
│       │                     └──────────┘                 │
│       │                          │                        │
│       │                     上传更新                      │
│       │                          │                        │
│       └──────────────────────────┘                        │
│                    聚合更新                                 │
└─────────────────────────────────────────────────────────────┘

4.3 本地训练的具体实现

def local_train(w, D_k, η, E):
    """
    w: 接收到的全局模型
    D_k: 客户端k的本地数据
    η: 学习率
    E: 本地epoch数
    """
    for epoch in range(E):
        for batch in DataLoader(D_k):
            # 计算梯度
            g = ∇F_k(w, batch)
            # 更新本地模型
            w = w - η * g
    return w

5. 联邦学习的评估指标

5.1 模型性能指标

指标定义适用场景
测试准确率全局模型在测试集上的准确率分类任务
AUC-ROC曲线下面积分类任务
MSE/MAE均方/绝对误差回归任务
困惑度语言模型性能NLP任务

5.2 效率指标

指标定义优化目标
通信轮次达到目标性能所需的通信次数越少越好
通信成本每轮传输的数据量越小越好
本地计算量客户端本地计算量根据设备能力
收敛速度随轮次的性能提升越快越好

5.3 隐私与安全指标

指标定义衡量标准
隐私预算ε-δ差分隐私参数越小越隐私
攻击成功率攻击者成功窃取信息的概率越低越好
鲁棒性对恶意攻击的抵抗力越强越好

6. 联邦学习的应用场景

6.1 移动端应用

应用数据来源隐私保护需求
Gboard输入法用户打字习惯保护输入隐私
健康App用户健康数据医疗隐私法规
个性化推荐用户行为数据商业机密保护

6.2 医疗健康

  • 跨医院联合诊断:多家医院联合训练诊断模型
  • 罕见病研究:汇聚稀缺病例数据
  • 药物发现:加速新药研发

6.3 金融领域

  • 反欺诈模型:银行间共享欺诈模式
  • 信用评估:多源数据融合评估
  • 量化投资:保护投资策略

6.4 物联网

  • 智能家居:设备协同学习用户习惯
  • 自动驾驶:车队共享驾驶经验
  • 工业物联网:预测性维护

7. 与其他隐私保护技术的关系

7.1 联邦学习 vs 差分隐私

特性联邦学习差分隐私
保护对象数据不离开本地添加噪声保护单个记录
组合方式可叠加使用可在FL中使用DP
隐私保证取决于实现数学保证

7.2 联邦学习 vs 安全多方计算

特性联邦学习安全多方计算
计算方式本地计算+参数聚合密态计算
通信开销模型参数通信交互式协议
组合方式可结合MPC可在FL中使用MPC

7.3 三者结合

现代隐私保护系统通常采用多层防护:

数据隐私保护体系
├── 联邦学习:数据不出本地
├── 差分隐私:输出扰动保护
└── 安全多方计算:计算过程保护

8. 参考文献


9. 相关主题