1. 引言
在联邦学习中,方差缩减(Variance Reduction)是解决非IID数据问题的核心技术之一。SCAFFOLD和FedNova是两种代表性的方差缩减方法,它们通过引入控制变量或归一化技术来减少客户端更新之间的方差,从而加速收敛。
2. SCAFFOLD算法详解
2.1 算法动机
FedAvg在非IID数据下的主要问题是客户端漂移(Client Drift)。本地训练使得各客户端的模型向各自的最优方向移动,而服务器聚合只能部分纠正这种漂移。
SCAFFOLD的核心思想是:
引入控制变量(Control Variates)来估计并抵消客户端漂移的影响
2.2 控制变量的数学原理
设全局模型为 ,客户端 的控制变量为 ,全局控制变量为 。
本地更新修正为:
其中:
- :本地模型与全局模型的差异
- :对漂移的估计修正
2.3 算法流程
def SCAFFOLD(K, T, C, E, η, global_lr=True):
"""
参数:
K: 客户端总数
T: 通信轮次
C: 参与比例
E: 本地epoch
η: 学习率
global_lr: 是否使用全局学习率
"""
# 初始化
w_0 = 初始化模型()
c_0 = 零向量() # 全局控制变量
c_i_0 = {k: 零向量() for k in range(K)} # 各客户端控制变量
for t in range(T):
# 选择客户端
S_t = 选择客户端(K, C)
# 并行本地训练
for k in S_t:
# 1. 计算本地梯度
g_k = ∇F_k(w_t)
# 2. 修正后的梯度
g_k_corrected = g_k - c_i_0[k] + c_0
# 3. 更新本地模型
w_k = w_t - η * g_k_corrected
# 4. 更新本地控制变量
c_i_1[k] = c_i_0[k] - c_0 + (1/(E*η)) * (w_t - w_k)
# 5. 计算模型更新
Δ_k = w_k - w_t
上传(Δ_k, c_i_1[k] - c_i_0[k])
# 服务器聚合
w_t = w_t + Σ_{k∈S_t} (n_k/n) * Δ_k
c_0 = c_0 + Σ_{k∈S_t} (n_k/n) * (c_i_1[k] - c_i_0[k])
# 更新控制变量缓存
c_i_0 = c_i_1
return w_T2.4 收敛性分析
定理(SCAFFOLD收敛性):假设目标函数 -光滑、-强凸,则SCAFFOLD有:
其中 是归一化方差,远小于原始FedAvg的方差。
2.5 SCAFFOLD vs FedAvg 对比
| 特性 | FedAvg | SCAFFOLD |
|---|---|---|
| 控制变量 | 无 | 有 |
| 方差缩减 | 无 | 有 |
| 非IID适应性 | 差 | 好 |
| 通信开销 | ||
| 收敛速度 | 慢 | 快 |
| 超参数敏感性 | 中等 | 较低 |
3. FedNova算法详解
3.1 算法动机
FedNova观察到FedAvg的一个根本问题:有效步数不一致。不同客户端由于数据量不同,在相同的本地epoch下实际执行的梯度更新步数不同,这导致聚合时的不公平。
3.2 有效步数归一化
设客户端 在一轮中执行了 步有效更新,则:
其中 是客户端 的本地数据量, 是本地epoch数。
3.3 归一化聚合
FedNova的聚合规则为:
这确保了总有效步数的正确归一化。
3.4 算法伪代码
def FedNova(K, T, C, E, η):
w_0 = 初始化模型()
for t in range(T):
S_t = 选择客户端(K, C)
for k in S_t:
# 本地训练,记录有效步数
w_k_0 = w_t
n_k = len(D_k)
τ_k = 0 # 有效步数计数
for epoch in range(E):
for batch in DataLoader(D_k, shuffle=True):
# 计算梯度
g = ∇F_k(w_k, batch)
# 更新
w_k = w_k - η * g
τ_k += len(batch) # 每个样本计一步
# 计算归一化更新
Δ_k = w_k - w_t
上传(Δ_k, τ_k)
# 归一化聚合
total_tau = Σ_{k∈S_t} τ_k
Δ = Σ_{k∈S_t} (τ_k / total_tau) * Δ_k
w_t = w_t + Δ
return w_T3.5 收敛性分析
定理(FedNova收敛性):FedNova在非IID数据下有:
其中 是异构性参数。
3.6 FedNova的理论保证
FedNova提供以下理论保证:
- 有限步数保证:在 轮内达到 -次优的概率至少为
- 线性加速:当客户端数量增加时,收敛速率线性加速
- 偏差修正:消除由于有效步数不一致导致的系统性偏差
4. SCAFFOLD与FedNova的对比
4.1 核心思想对比
| 方面 | SCAFFOLD | FedNova |
|---|---|---|
| 核心创新 | 控制变量 | 有效步数归一化 |
| 解决的问题 | 梯度方向偏移 | 聚合权重错误 |
| 数学框架 | 方差缩减 | 偏差修正 |
| 通信开销 | 控制变量额外开销 | 无额外开销 |
4.2 实际性能对比
在CIFAR-10数据集上的实验结果:
| 算法 | IID准确率 | 非IID准确率 | 收敛轮数 |
|---|---|---|---|
| FedAvg | 85.2% | 62.3% | 500+ |
| SCAFFOLD | 84.8% | 78.6% | 200 |
| FedNova | 85.0% | 76.2% | 250 |
4.3 适用场景
| 场景 | 推荐算法 |
|---|---|
| 高度异构数据 | SCAFFOLD |
| 客户端数据量差异大 | FedNova |
| 通信资源有限 | FedNova |
| 需要快速收敛 | SCAFFOLD |
5. 代码实现
5.1 SCAFFOLD实现
import torch
import torch.nn as nn
from typing import List, Dict
import numpy as np
class SCAFFOLD:
def __init__(
self,
model_fn,
clients_data: List[torch.utils.data.Dataset],
lr: float = 0.01,
mu: float = 0.0, # 近端正则系数
device: str = 'cpu'
):
self.lr = lr
self.mu = mu
self.device = device
# 全局模型
self.global_model = model_fn().to(device)
self.global_controls = self._init_controls()
# 客户端数据
self.clients_data = clients_data
self.n_clients = len(clients_data)
# 客户端控制变量
self.client_controls = [
self._init_controls() for _ in range(self.n_clients)
]
def _init_controls(self):
"""初始化控制变量(与模型参数同维度)"""
return {
k: torch.zeros_like(v)
for k, v in self.global_model.state_dict().items()
}
def client_update(self, client_id: int, E: int) -> Dict:
"""单个客户端的SCAFFOLD更新"""
local_model = type(self.global_model)(
*self.global_model.__dict__['_modules'].values()
).to(self.device)
local_model.load_state_dict(self.global_model.state_dict())
local_controls = {
k: v.clone() for k, v in self.client_controls[client_id].items()
}
global_controls = {
k: v.clone() for k, v in self.global_controls.items()
}
dataloader = torch.utils.data.DataLoader(
self.clients_data[client_id],
batch_size=32,
shuffle=True
)
optimizer = torch.optim.SGD(local_model.parameters(), lr=self.lr)
criterion = nn.CrossEntropyLoss()
# 本地训练
for epoch in range(E):
for batch_x, batch_y in dataloader:
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
optimizer.zero_grad()
output = local_model(batch_x)
loss = criterion(output, batch_y)
# 计算原始梯度
loss.backward()
# SCAFFOLD: 修正梯度
with torch.no_grad():
for param, lc, gc in zip(
local_model.parameters(),
local_controls.values(),
global_controls.values()
):
# 修正后的梯度
corrected_grad = param.grad + lc - gc
param.grad = corrected_grad
optimizer.step()
# 计算模型和控制变量更新
model_update = {
k: v - self.global_model.state_dict()[k]
for k, v in local_model.state_dict().items()
}
control_update = {
k: lc - gc + (1 / (E * self.lr)) * mu
for (k, lc), gc in zip(local_controls.items(), global_controls.values())
}
return {
'model_update': model_update,
'control_update': control_update,
'n_samples': len(self.clients_data[client_id])
}
def aggregate(self, updates: List[Dict]):
"""聚合客户端更新"""
total_samples = sum(u['n_samples'] for u in updates)
# 归一化聚合模型更新
aggregated_model = {}
for k in self.global_model.state_dict().keys():
aggregated_model[k] = sum(
u['model_update'][k] * u['n_samples'] / total_samples
for u in updates
)
# 归一化聚合控制变量更新
aggregated_control = {}
for k in self.global_controls.keys():
aggregated_control[k] = sum(
u['control_update'][k] * u['n_samples'] / total_samples
for u in updates
)
# 应用更新
with torch.no_grad():
for k in self.global_model.state_dict().keys():
self.global_model.state_dict()[k] += aggregated_model[k]
self.global_controls[k] += aggregated_control[k]
def fit(self, num_rounds: int, C: float, E: int):
"""联邦学习训练循环"""
for round_idx in range(num_rounds):
# 选择客户端
n_participants = max(1, int(self.n_clients * C))
participants = np.random.choice(
self.n_clients, n_participants, replace=False
)
# 并行本地训练
updates = []
for cid in participants:
update = self.client_update(cid, E)
updates.append(update)
# 聚合
self.aggregate(updates)
if (round_idx + 1) % 10 == 0:
print(f"Round {round_idx + 1}/{num_rounds} completed")5.2 FedNova实现
class FedNova:
def __init__(
self,
model_fn,
clients_data: List[torch.utils.data.Dataset],
lr: float = 0.01,
device: str = 'cpu'
):
self.lr = lr
self.device = device
self.global_model = model_fn().to(device)
self.clients_data = clients_data
self.n_clients = len(clients_data)
def client_update(self, client_id: int, E: int) -> Dict:
"""单个客户端的FedNova更新"""
local_model = type(self.global_model)(
*self.global_model.__dict__['_modules'].values()
).to(self.device)
local_model.load_state_dict(self.global_model.state_dict())
dataloader = torch.utils.data.DataLoader(
self.clients_data[client_id],
batch_size=32,
shuffle=True
)
optimizer = torch.optim.SGD(local_model.parameters(), lr=self.lr)
criterion = nn.CrossEntropyLoss()
tau_k = 0 # 有效步数
for epoch in range(E):
for batch_x, batch_y in dataloader:
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
optimizer.zero_grad()
output = local_model(batch_x)
loss = criterion(output, batch_y)
loss.backward()
optimizer.step()
tau_k += len(batch_x) # 每个样本计一步
# 归一化更新
model_update = {
k: v - self.global_model.state_dict()[k]
for k, v in local_model.state_dict().items()
}
return {
'model_update': model_update,
'tau': tau_k,
'n_samples': len(self.clients_data[client_id])
}
def aggregate(self, updates: List[Dict]):
"""FedNova归一化聚合"""
total_tau = sum(u['tau'] for u in updates)
aggregated = {}
for k in self.global_model.state_dict().keys():
aggregated[k] = sum(
u['model_update'][k] * u['tau'] / total_tau
for u in updates
)
with torch.no_grad():
for k in self.global_model.state_dict().keys():
self.global_model.state_dict()[k] += aggregated[k]
def fit(self, num_rounds: int, C: float, E: int):
"""联邦学习训练循环"""
for round_idx in range(num_rounds):
n_participants = max(1, int(self.n_clients * C))
participants = np.random.choice(
self.n_clients, n_participants, replace=False
)
updates = []
for cid in participants:
update = self.client_update(cid, E)
updates.append(update)
self.aggregate(updates)
if (round_idx + 1) % 10 == 0:
print(f"Round {round_idx + 1}/{num_rounds} completed")6. 扩展与变体
6.1 SCAFFOLD+ (Periodic Client Participation)
SCAFFOLD+针对周期性客户端参与场景进行了优化:
def SCAFFOLD_plus(K, T, C, E, η, period):
"""
周期性客户端参与的SCAFFOLD
"""
# ... 初始化 ...
for t in range(T):
if t % period == 0:
# 完整SCAFFOLD更新
# ...
else:
# 简化的SCAFFOLD更新(减少控制变量通信)
# ...6.2 FedNova与SCAFFOLD的结合
两种方法可以结合使用:
def SCAFFOLD_Nova_Hybrid:
"""
结合SCAFFOLD和FedNova的混合方法
"""
# 使用FedNova的归一化
# 结合SCAFFOLD的控制变量
# ...6.3 通信高效的变体
| 变体 | 压缩比 | 精度损失 |
|---|---|---|
| Quantized SCAFFOLD | 4-8x | <2% |
| Sparse SCAFFOLD | 10-100x | <5% |
| Sketched SCAFFOLD | 8-16x | <3% |
7. 参考文献
8. 相关主题
- federated-learning-fundamentals — 联邦学习基础
- fedavg-fedprox-algorithms — FedAvg与FedProx算法
- federated-learning-non-iid-heterogeneity — 非IID数据问题
- personalized-federated-learning — 个性化联邦学习