1. 引言

FedAvg(Federated Averaging)是由McMahan等人于2017年提出的联邦学习基础算法,也是后续众多算法的基石。FedProx则针对非IID数据场景对FedAvg进行了改进。本节将从算法原理、收敛性分析和实现细节三个维度进行深入探讨。


2. FedAvg算法详解

2.1 算法背景与动机

传统分布式SGD需要在每轮通信中传输大量梯度数据,这在联邦场景下通信成本过高。FedAvg的核心思想是:

在通信轮次之间允许客户端执行多轮本地梯度下降,以减少通信频率

设客户端 在第 轮的本地模型为 ,全局模型为 ,则:

经过 个epoch的本地训练后,客户端上传模型更新:

服务器通过加权平均聚合更新:

其中 是第 轮参与的客户端集合, 是客户端 的数据量。

2.2 算法伪代码

def FedAvg(K, T, C, E, η):
    """
    参数:
        K: 客户端总数
        T: 通信轮次
        C: 每轮参与客户端比例 (0 < C ≤ 1)
        E: 本地epoch数
        η: 学习率
    
    返回:
        w_T: 最终全局模型
    """
    # 初始化全局模型
    w_0 = 初始化模型参数()
    
    for t in range(T):
        # Step 1: 服务器选择客户端子集
        S_t = 随机选择(K * C) 个客户端
        
        # Step 2: 并行本地训练
        for k in S_t in parallel:
            # 接收当前全局模型
            w_t_k = w_t
            
            # 本地训练 E 个 epoch
            for e in range(E):
                # 随机打乱本地数据
                B = shuffle(D_k)
                # mini-batch SGD
                for batch in batchify(B, batch_size):
                    g = ∇F_k(w_t_k, batch)
                    w_t_k = w_t_k - η * g
            
            # 计算模型更新
            Δ_k = w_t_k - w_t
            上传(Δ_k, n_k)
        
        # Step 3: 服务器聚合
        w_t = w_t + Σ_{k∈S_t} (n_k / Σ n_j) * Δ_k
    
    return w_T

2.3 关键参数分析

参数含义影响
客户端参与比例越大通信越频繁,但参与多样性越好
本地epoch数越大通信越少,但可能导致客户端漂移
学习率影响本地训练步长

2.4 FedAvg的几何解释

从几何角度看,FedAvg在以下空间操作:

损失函数 landscape
    │
    │          全局极小值
    │              ★
    │            ╱  ╲
    │          ╱      ╲
    │        ╱    ★    ╲      ★ = 本地极小值
    │      ╱              ╲   (不同客户端)
    │    ╱                  ╲
    │───★───────────────────────▶ 参数空间
    │  客户端1的极小值

本地训练使得各客户端模型向各自的本地极小值移动,而服务器聚合则将这些移动”拉回”到全局方向。


3. 收敛性分析

3.1 符号定义

设:

  • :全局目标函数
  • :全局最优值
  • :异构性上界

3.2 IID数据下的收敛性

定理(FedAvg收敛性):假设数据IID、目标函数光滑且梯度有界,则FedAvg在步长 下有:

直观理解

  • 第一项 :标准SGD的收敛速率
  • 第二项 :本地训练带来的额外误差,随 增大而增大

3.3 非IID数据下的收敛性

当数据非IID时,收敛性分析更加复杂。定义客户端漂移

引理:本地训练 个epoch后,客户端 的漂移上界为:

服务器聚合时,这种漂移会部分抵消,但不会完全消失。

3.4 收敛速率与参数关系

设置收敛速率通信复杂度
标准SGD每步通信
FedAvg () 次通信
FedAvg () 次通信

关键发现:当 时,FedAvg可以达到与标准SGD相当的收敛速率。


4. FedProx算法

4.1 算法动机

FedAvg在IID数据下表现良好,但在非IID数据下会因客户端漂移导致收敛困难甚至发散。FedProx通过在目标函数中引入近端正则项来解决这一问题。

4.2 算法定义

FedProx的本地目标函数为:

其中:

  • :客户端 的本地损失
  • :近端正则项
  • :当前全局模型
  • :近端系数(超参数)

4.3 与FedAvg的关系

时,FedProx退化为标准FedSGD。

时,FedProx退化为FedAvg。

FedProx的核心改进在于:

  1. 近端正则化:限制本地模型偏离全局模型太远
  2. 自适应步长:允许客户端根据本地条件调整步长

4.4 算法伪代码

def FedProx(K, T, C, E, η, μ):
    w_0 = 初始化模型参数()
    
    for t in range(T):
        S_t = 随机选择(K * C) 个客户端
        
        for k in S_t in parallel:
            # 接收全局模型
            w_t_k = w_t
            # 记录上次更新(用于近端正则)
            w_t_k_old = w_t
            
            for e in range(E):
                # 计算梯度
                g = ∇F_k(w_t_k) + μ * (w_t_k - w_t)
                # 更新模型
                w_t_k = w_t_k - η * g
            
            # 上传更新
            Δ_k = w_t_k - w_t
            上传(Δ_k, n_k)
        
        # 聚合
        w_t = w_t + Σ_{k∈S_t} (n_k / Σ n_j) * Δ_k
    
    return w_T

4.5 收敛性分析

定理(FedProx收敛性):假设目标函数 -光滑,则FedProx有:

其中 是梯度 Lipschitz 常数, 是异构性参数。

关键结论

  • 时,FedProx可以达到 的收敛速率
  • 近端项 控制了客户端漂移的影响

4.6 FedAvg vs FedProx 对比

特性FedAvgFedProx
近端正则
非IID适应
收敛稳定性依赖数据分布更稳定
超参数

5. 代码实现

5.1 PyTorch实现:FedAvg

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import numpy as np
from typing import List, Callable
 
class FederatedAveraging:
    def __init__(
        self,
        model: nn.Module,
        clients_data: List[Subset],
        client_weights: List[float] = None,
        device: str = 'cpu'
    ):
        """
        初始化联邦学习框架
        
        Args:
            model: 基础模型架构
            clients_data: 各客户端的数据集列表
            client_weights: 各客户端的权重(默认按数据量比例)
            device: 计算设备
        """
        self.global_model = model.to(device)
        self.clients_data = clients_data
        self.device = device
        
        # 计算客户端权重
        if client_weights is None:
            self.weights = [len(d) / sum(len(d) for d in clients_data) 
                          for d in clients_data]
        else:
            self.weights = client_weights
    
    def client_update(
        self,
        client_id: int,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> dict:
        """
        单个客户端的本地训练
        
        Args:
            client_id: 客户端索引
            local_epochs: 本地训练轮数
            batch_size: 批大小
            lr: 学习率
        
        Returns:
            包含更新量和数据量的字典
        """
        # 复制全局模型到本地
        local_model = type(self.global_model)(
            *self.global_model.__dict__['_modules'].values()
        ).to(self.device)
        local_model.load_state_dict(self.global_model.state_dict())
        
        # 准备数据和优化器
        dataloader = DataLoader(
            self.clients_data[client_id],
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = torch.optim.SGD(local_model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        
        # 本地训练
        local_model.train()
        for epoch in range(local_epochs):
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                optimizer.zero_grad()
                output = local_model(batch_x)
                loss = criterion(output, batch_y)
                loss.backward()
                optimizer.step()
        
        # 计算模型更新
        update = {
            k: v - self.global_model.state_dict()[k] 
            for k, v in local_model.state_dict().items()
        }
        
        return {
            'update': update,
            'n_samples': len(self.clients_data[client_id])
        }
    
    def aggregate(self, client_updates: List[dict]):
        """
        聚合客户端更新
        
        Args:
            client_updates: 各客户端的更新列表
        """
        total_samples = sum(u['n_samples'] for u in client_updates)
        
        # 加权平均
        aggregated_update = {
            k: sum(
                u['update'][k] * u['n_samples'] / total_samples
                for u in client_updates
            )
            for k in self.global_model.state_dict().keys()
        }
        
        # 应用更新
        with torch.no_grad():
            for k in self.global_model.state_dict().keys():
                self.global_model.state_dict()[k] += aggregated_update[k]
    
    def fit(
        self,
        num_rounds: int,
        participation_ratio: float,
        local_epochs: int,
        batch_size: int,
        lr: float,
        criterion: nn.Module = None
    ):
        """
        联邦学习训练主循环
        
        Args:
            num_rounds: 通信轮数
            participation_ratio: 每轮参与客户端比例
            local_epochs: 本地epoch数
            batch_size: 批大小
            lr: 学习率
        """
        K = len(self.clients_data)
        num_participants = max(1, int(K * participation_ratio))
        
        for round_idx in range(num_rounds):
            # 随机选择客户端
            participant_ids = np.random.choice(K, num_participants, replace=False)
            
            # 并行本地训练
            updates = []
            for cid in participant_ids:
                update = self.client_update(cid, local_epochs, batch_size, lr)
                updates.append(update)
            
            # 聚合
            self.aggregate(updates)
            
            # 打印进度
            if (round_idx + 1) % 10 == 0:
                print(f"Round {round_idx + 1}/{num_rounds} completed")
    
    def evaluate(self, test_data, batch_size: int = 64):
        """
        评估全局模型
        """
        self.global_model.eval()
        dataloader = DataLoader(test_data, batch_size=batch_size)
        criterion = nn.CrossEntropyLoss()
        
        total_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                output = self.global_model(batch_x)
                loss = criterion(output, batch_y)
                total_loss += loss.item() * len(batch_y)
                pred = output.argmax(dim=1)
                correct += (pred == batch_y).sum().item()
                total += len(batch_y)
        
        return {
            'loss': total_loss / total,
            'accuracy': correct / total
        }

5.2 PyTorch实现:FedProx

class FederatedProx(FederatedAveraging):
    def __init__(self, *args, mu: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.mu = mu  # 近端正则系数
    
    def client_update(
        self,
        client_id: int,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> dict:
        """
        FedProx的本地训练:包含近端正则项
        """
        # 复制全局模型
        local_model = type(self.global_model)(
            *self.global_model.__dict__['_modules'].values()
        ).to(self.device)
        local_model.load_state_dict(self.global_model.state_dict())
        
        # 保存当前全局模型作为近端参考
        global_model_snapshot = {
            k: v.clone() for k, v in self.global_model.state_dict().items()
        }
        
        # 准备数据
        dataloader = DataLoader(
            self.clients_data[client_id],
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = torch.optim.SGD(local_model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        
        # 本地训练
        local_model.train()
        for epoch in range(local_epochs):
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                optimizer.zero_grad()
                
                # 前向传播
                output = local_model(batch_x)
                loss = criterion(output, batch_y)
                
                # 近端正则项
                prox_loss = 0
                for (name, param), (global_name, global_param) in zip(
                    local_model.named_parameters(),
                    global_model_snapshot.items()
                ):
                    prox_loss += torch.sum(
                        (param - global_param) ** 2
                    )
                prox_loss = (self.mu / 2) * prox_loss
                
                # 总损失
                total_loss = loss + prox_loss
                
                total_loss.backward()
                optimizer.step()
        
        # 计算更新
        update = {
            k: v - self.global_model.state_dict()[k]
            for k, v in local_model.state_dict().items()
        }
        
        return {
            'update': update,
            'n_samples': len(self.clients_data[client_id])
        }

5.3 使用示例

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
 
# 生成模拟数据(模拟非IID场景)
def create_non_iid_data(num_clients=10, samples_per_client=500):
    """创建模拟的非IID数据"""
    X_list, y_list = [], []
    
    for i in range(num_clients):
        # 每个客户端生成不同分布的数据
        X_client, y_client = make_classification(
            n_samples=samples_per_client,
            n_features=20,
            n_informative=15,
            n_classes=2,
            class_sep=0.5 + 0.3 * (i % 3),  # 不同客户端的分离度不同
            random_state=i
        )
        X_list.append(TensorDataset(
            torch.FloatTensor(X_client),
            torch.LongTensor(y_client)
        ))
    
    return X_list
 
# 训练示例
def main():
    # 创建数据和模型
    clients_data = create_non_iid_data(num_clients=10)
    
    model = nn.Sequential(
        nn.Linear(20, 64),
        nn.ReLU(),
        nn.Linear(64, 32),
        nn.ReLU(),
        nn.Linear(32, 2)
    )
    
    # 初始化联邦学习
    fl = FederatedProx(
        model=model,
        clients_data=clients_data,
        mu=0.1,  # 近端正则系数
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )
    
    # 联邦训练
    fl.fit(
        num_rounds=100,
        participation_ratio=0.5,
        local_epochs=5,
        batch_size=32,
        lr=0.01
    )
    
    # 评估
    # ... (省略测试数据创建)
    # results = fl.evaluate(test_data)
    # print(f"Test Accuracy: {results['accuracy']:.4f}")
 
if __name__ == "__main__":
    main()

6. 实践注意事项

6.1 超参数调优

参数推荐范围调优建议
0.01 - 0.2取决于客户端总数
1 - 20非IID时建议较小值
0.01 - 0.1通常小于标准SGD
(FedProx)0.01 - 1.0与异构程度相关

6.2 收敛诊断

监控以下指标判断训练是否正常:

  1. 客户端更新范数:过大说明不稳定
  2. 客户端间差异:过大说明异构性严重
  3. 全局损失下降:监控是否收敛

6.3 常见问题与解决方案

问题原因解决方案
收敛慢本地epoch过多减少
客户端漂移非IID严重使用FedProx
通信瓶颈模型太大模型压缩、稀疏化
客户端掉线网络不稳定容错机制设计

7. 参考文献


8. 相关主题