1. 引言
FedAvg(Federated Averaging)是由McMahan等人于2017年提出的联邦学习基础算法,也是后续众多算法的基石。FedProx则针对非IID数据场景对FedAvg进行了改进。本节将从算法原理、收敛性分析和实现细节三个维度进行深入探讨。
2. FedAvg算法详解
2.1 算法背景与动机
传统分布式SGD需要在每轮通信中传输大量梯度数据,这在联邦场景下通信成本过高。FedAvg的核心思想是:
在通信轮次之间允许客户端执行多轮本地梯度下降,以减少通信频率
设客户端 在第 轮的本地模型为 ,全局模型为 ,则:
经过 个epoch的本地训练后,客户端上传模型更新:
服务器通过加权平均聚合更新:
其中 是第 轮参与的客户端集合, 是客户端 的数据量。
2.2 算法伪代码
def FedAvg(K, T, C, E, η):
"""
参数:
K: 客户端总数
T: 通信轮次
C: 每轮参与客户端比例 (0 < C ≤ 1)
E: 本地epoch数
η: 学习率
返回:
w_T: 最终全局模型
"""
# 初始化全局模型
w_0 = 初始化模型参数()
for t in range(T):
# Step 1: 服务器选择客户端子集
S_t = 随机选择(K * C) 个客户端
# Step 2: 并行本地训练
for k in S_t in parallel:
# 接收当前全局模型
w_t_k = w_t
# 本地训练 E 个 epoch
for e in range(E):
# 随机打乱本地数据
B = shuffle(D_k)
# mini-batch SGD
for batch in batchify(B, batch_size):
g = ∇F_k(w_t_k, batch)
w_t_k = w_t_k - η * g
# 计算模型更新
Δ_k = w_t_k - w_t
上传(Δ_k, n_k)
# Step 3: 服务器聚合
w_t = w_t + Σ_{k∈S_t} (n_k / Σ n_j) * Δ_k
return w_T2.3 关键参数分析
| 参数 | 含义 | 影响 |
|---|---|---|
| 客户端参与比例 | 越大通信越频繁,但参与多样性越好 | |
| 本地epoch数 | 越大通信越少,但可能导致客户端漂移 | |
| 学习率 | 影响本地训练步长 |
2.4 FedAvg的几何解释
从几何角度看,FedAvg在以下空间操作:
损失函数 landscape
│
│ 全局极小值
│ ★
│ ╱ ╲
│ ╱ ╲
│ ╱ ★ ╲ ★ = 本地极小值
│ ╱ ╲ (不同客户端)
│ ╱ ╲
│───★───────────────────────▶ 参数空间
│ 客户端1的极小值
本地训练使得各客户端模型向各自的本地极小值移动,而服务器聚合则将这些移动”拉回”到全局方向。
3. 收敛性分析
3.1 符号定义
设:
- :全局目标函数
- :全局最优值
- :异构性上界
3.2 IID数据下的收敛性
定理(FedAvg收敛性):假设数据IID、目标函数光滑且梯度有界,则FedAvg在步长 下有:
直观理解:
- 第一项 :标准SGD的收敛速率
- 第二项 :本地训练带来的额外误差,随 增大而增大
3.3 非IID数据下的收敛性
当数据非IID时,收敛性分析更加复杂。定义客户端漂移:
引理:本地训练 个epoch后,客户端 的漂移上界为:
服务器聚合时,这种漂移会部分抵消,但不会完全消失。
3.4 收敛速率与参数关系
| 设置 | 收敛速率 | 通信复杂度 |
|---|---|---|
| 标准SGD | 每步通信 | |
| FedAvg () | 次通信 | |
| FedAvg () | 次通信 |
关键发现:当 时,FedAvg可以达到与标准SGD相当的收敛速率。
4. FedProx算法
4.1 算法动机
FedAvg在IID数据下表现良好,但在非IID数据下会因客户端漂移导致收敛困难甚至发散。FedProx通过在目标函数中引入近端正则项来解决这一问题。
4.2 算法定义
FedProx的本地目标函数为:
其中:
- :客户端 的本地损失
- :近端正则项
- :当前全局模型
- :近端系数(超参数)
4.3 与FedAvg的关系
当 且 时,FedProx退化为标准FedSGD。
当 且 时,FedProx退化为FedAvg。
FedProx的核心改进在于:
- 近端正则化:限制本地模型偏离全局模型太远
- 自适应步长:允许客户端根据本地条件调整步长
4.4 算法伪代码
def FedProx(K, T, C, E, η, μ):
w_0 = 初始化模型参数()
for t in range(T):
S_t = 随机选择(K * C) 个客户端
for k in S_t in parallel:
# 接收全局模型
w_t_k = w_t
# 记录上次更新(用于近端正则)
w_t_k_old = w_t
for e in range(E):
# 计算梯度
g = ∇F_k(w_t_k) + μ * (w_t_k - w_t)
# 更新模型
w_t_k = w_t_k - η * g
# 上传更新
Δ_k = w_t_k - w_t
上传(Δ_k, n_k)
# 聚合
w_t = w_t + Σ_{k∈S_t} (n_k / Σ n_j) * Δ_k
return w_T4.5 收敛性分析
定理(FedProx收敛性):假设目标函数 -光滑,则FedProx有:
其中 是梯度 Lipschitz 常数, 是异构性参数。
关键结论:
- 当 时,FedProx可以达到 的收敛速率
- 近端项 控制了客户端漂移的影响
4.6 FedAvg vs FedProx 对比
| 特性 | FedAvg | FedProx |
|---|---|---|
| 近端正则 | 无 | 有 |
| 非IID适应 | 差 | 好 |
| 收敛稳定性 | 依赖数据分布 | 更稳定 |
| 超参数 |
5. 代码实现
5.1 PyTorch实现:FedAvg
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import numpy as np
from typing import List, Callable
class FederatedAveraging:
def __init__(
self,
model: nn.Module,
clients_data: List[Subset],
client_weights: List[float] = None,
device: str = 'cpu'
):
"""
初始化联邦学习框架
Args:
model: 基础模型架构
clients_data: 各客户端的数据集列表
client_weights: 各客户端的权重(默认按数据量比例)
device: 计算设备
"""
self.global_model = model.to(device)
self.clients_data = clients_data
self.device = device
# 计算客户端权重
if client_weights is None:
self.weights = [len(d) / sum(len(d) for d in clients_data)
for d in clients_data]
else:
self.weights = client_weights
def client_update(
self,
client_id: int,
local_epochs: int,
batch_size: int,
lr: float
) -> dict:
"""
单个客户端的本地训练
Args:
client_id: 客户端索引
local_epochs: 本地训练轮数
batch_size: 批大小
lr: 学习率
Returns:
包含更新量和数据量的字典
"""
# 复制全局模型到本地
local_model = type(self.global_model)(
*self.global_model.__dict__['_modules'].values()
).to(self.device)
local_model.load_state_dict(self.global_model.state_dict())
# 准备数据和优化器
dataloader = DataLoader(
self.clients_data[client_id],
batch_size=batch_size,
shuffle=True
)
optimizer = torch.optim.SGD(local_model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# 本地训练
local_model.train()
for epoch in range(local_epochs):
for batch_x, batch_y in dataloader:
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
optimizer.zero_grad()
output = local_model(batch_x)
loss = criterion(output, batch_y)
loss.backward()
optimizer.step()
# 计算模型更新
update = {
k: v - self.global_model.state_dict()[k]
for k, v in local_model.state_dict().items()
}
return {
'update': update,
'n_samples': len(self.clients_data[client_id])
}
def aggregate(self, client_updates: List[dict]):
"""
聚合客户端更新
Args:
client_updates: 各客户端的更新列表
"""
total_samples = sum(u['n_samples'] for u in client_updates)
# 加权平均
aggregated_update = {
k: sum(
u['update'][k] * u['n_samples'] / total_samples
for u in client_updates
)
for k in self.global_model.state_dict().keys()
}
# 应用更新
with torch.no_grad():
for k in self.global_model.state_dict().keys():
self.global_model.state_dict()[k] += aggregated_update[k]
def fit(
self,
num_rounds: int,
participation_ratio: float,
local_epochs: int,
batch_size: int,
lr: float,
criterion: nn.Module = None
):
"""
联邦学习训练主循环
Args:
num_rounds: 通信轮数
participation_ratio: 每轮参与客户端比例
local_epochs: 本地epoch数
batch_size: 批大小
lr: 学习率
"""
K = len(self.clients_data)
num_participants = max(1, int(K * participation_ratio))
for round_idx in range(num_rounds):
# 随机选择客户端
participant_ids = np.random.choice(K, num_participants, replace=False)
# 并行本地训练
updates = []
for cid in participant_ids:
update = self.client_update(cid, local_epochs, batch_size, lr)
updates.append(update)
# 聚合
self.aggregate(updates)
# 打印进度
if (round_idx + 1) % 10 == 0:
print(f"Round {round_idx + 1}/{num_rounds} completed")
def evaluate(self, test_data, batch_size: int = 64):
"""
评估全局模型
"""
self.global_model.eval()
dataloader = DataLoader(test_data, batch_size=batch_size)
criterion = nn.CrossEntropyLoss()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_x, batch_y in dataloader:
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
output = self.global_model(batch_x)
loss = criterion(output, batch_y)
total_loss += loss.item() * len(batch_y)
pred = output.argmax(dim=1)
correct += (pred == batch_y).sum().item()
total += len(batch_y)
return {
'loss': total_loss / total,
'accuracy': correct / total
}5.2 PyTorch实现:FedProx
class FederatedProx(FederatedAveraging):
def __init__(self, *args, mu: float = 1.0, **kwargs):
super().__init__(*args, **kwargs)
self.mu = mu # 近端正则系数
def client_update(
self,
client_id: int,
local_epochs: int,
batch_size: int,
lr: float
) -> dict:
"""
FedProx的本地训练:包含近端正则项
"""
# 复制全局模型
local_model = type(self.global_model)(
*self.global_model.__dict__['_modules'].values()
).to(self.device)
local_model.load_state_dict(self.global_model.state_dict())
# 保存当前全局模型作为近端参考
global_model_snapshot = {
k: v.clone() for k, v in self.global_model.state_dict().items()
}
# 准备数据
dataloader = DataLoader(
self.clients_data[client_id],
batch_size=batch_size,
shuffle=True
)
optimizer = torch.optim.SGD(local_model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# 本地训练
local_model.train()
for epoch in range(local_epochs):
for batch_x, batch_y in dataloader:
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
optimizer.zero_grad()
# 前向传播
output = local_model(batch_x)
loss = criterion(output, batch_y)
# 近端正则项
prox_loss = 0
for (name, param), (global_name, global_param) in zip(
local_model.named_parameters(),
global_model_snapshot.items()
):
prox_loss += torch.sum(
(param - global_param) ** 2
)
prox_loss = (self.mu / 2) * prox_loss
# 总损失
total_loss = loss + prox_loss
total_loss.backward()
optimizer.step()
# 计算更新
update = {
k: v - self.global_model.state_dict()[k]
for k, v in local_model.state_dict().items()
}
return {
'update': update,
'n_samples': len(self.clients_data[client_id])
}5.3 使用示例
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
# 生成模拟数据(模拟非IID场景)
def create_non_iid_data(num_clients=10, samples_per_client=500):
"""创建模拟的非IID数据"""
X_list, y_list = [], []
for i in range(num_clients):
# 每个客户端生成不同分布的数据
X_client, y_client = make_classification(
n_samples=samples_per_client,
n_features=20,
n_informative=15,
n_classes=2,
class_sep=0.5 + 0.3 * (i % 3), # 不同客户端的分离度不同
random_state=i
)
X_list.append(TensorDataset(
torch.FloatTensor(X_client),
torch.LongTensor(y_client)
))
return X_list
# 训练示例
def main():
# 创建数据和模型
clients_data = create_non_iid_data(num_clients=10)
model = nn.Sequential(
nn.Linear(20, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 2)
)
# 初始化联邦学习
fl = FederatedProx(
model=model,
clients_data=clients_data,
mu=0.1, # 近端正则系数
device='cuda' if torch.cuda.is_available() else 'cpu'
)
# 联邦训练
fl.fit(
num_rounds=100,
participation_ratio=0.5,
local_epochs=5,
batch_size=32,
lr=0.01
)
# 评估
# ... (省略测试数据创建)
# results = fl.evaluate(test_data)
# print(f"Test Accuracy: {results['accuracy']:.4f}")
if __name__ == "__main__":
main()6. 实践注意事项
6.1 超参数调优
| 参数 | 推荐范围 | 调优建议 |
|---|---|---|
| 0.01 - 0.2 | 取决于客户端总数 | |
| 1 - 20 | 非IID时建议较小值 | |
| 0.01 - 0.1 | 通常小于标准SGD | |
| (FedProx) | 0.01 - 1.0 | 与异构程度相关 |
6.2 收敛诊断
监控以下指标判断训练是否正常:
- 客户端更新范数:过大说明不稳定
- 客户端间差异:过大说明异构性严重
- 全局损失下降:监控是否收敛
6.3 常见问题与解决方案
| 问题 | 原因 | 解决方案 |
|---|---|---|
| 收敛慢 | 本地epoch过多 | 减少 |
| 客户端漂移 | 非IID严重 | 使用FedProx |
| 通信瓶颈 | 模型太大 | 模型压缩、稀疏化 |
| 客户端掉线 | 网络不稳定 | 容错机制设计 |
7. 参考文献
8. 相关主题
- federated-learning-fundamentals — 联邦学习基础
- federated-learning-non-iid-heterogeneity — 非IID数据问题
- scaffold-fednova-algorithms — SCAFFOLD与FedNova算法
- personalized-federated-learning — 个性化联邦学习