1. 引言
在传统的联邦学习中,所有客户端共享同一个全局模型。然而,由于数据分布的异构性,一个模型难以同时满足所有客户端的需求。个性化联邦学习(Personalized Federated Learning, pFL)应运而生,旨在为每个客户端学习定制化的模型。
2. 个性化联邦学习的动机
2.1 为什么需要个性化?
考虑一个跨医院联合建模的场景:
- 医院A:主要收治心血管疾病
- 医院B:主要收治呼吸系统疾病
- 医院C:主要收治神经系统疾病
使用统一的全局模型会导致:
- 性能次优:模型在各类别上都无法达到最优
- 服务不公:某些医院的特殊需求被忽视
- 客户端流失:服务不佳的客户端可能退出协作
2.2 个性化 vs 全局联邦学习
| 方面 | 全局联邦学习 | 个性化联邦学习 |
|---|---|---|
| 模型数量 | 1个全局模型 | K个个性化模型 |
| 适用场景 | 数据分布相似 | 数据分布差异大 |
| 通信开销 | 较低 | 较高 |
| 客户端适应 | 差 | 好 |
3. 个性化联邦学习方法分类
个性化联邦学习方法
│
├── 基于微调的方法
│ ├── 全局模型微调
│ └── 本地适配层
│
├── 基于正则化的方法
│ ├── 知识蒸馏
│ └── 相似性正则
│
├── 基于聚类的方法
│ ├── 硬聚类
│ └── 软聚类
│
├── 基于元学习的方法
│ ├── MAML-based
│ └── Reptile-based
│
└── 基于混合架构的方法
├── 共享+私有参数
└── 知识路由
4. 基于微调的个性化方法
4.1 Per-FedAvg(基于MAML)
Per-FedAvg将模型无关元学习(Model-Agnostic Meta-Learning, MAML)的思想引入联邦学习。
核心思想:学习一个好的初始化,使得每个客户端只需少量本地更新就能获得个性化模型。
目标函数:
其中 是本地适应学习率。
算法流程:
def PerFedAvg(K, T, C, E, η, α):
"""
Per-FedAvg算法
Args:
K: 客户端数量
T: 通信轮次
C: 参与比例
E: 本地epoch
η: 全局学习率
α: 本地适应学习率
"""
θ = 初始化全局参数()
for t in range(T):
S_t = 选择客户端(K, C)
for k in S_t:
# Step 1: 本地适应(内循环)
θ_k = θ - α * ∇F_k(θ)
# Step 2: 计算元梯度
grad_k = ∇F_k(θ_k)
# 使用链式法则得到对θ的梯度
# θ_k = θ - α * ∇F_k(θ)
# ∂L/∂θ = (I - α * Hessian(F_k)) * ∂L/∂θ_k
上传(grad_k, θ_k)
# Step 3: 全局更新
θ = θ - η * Σ_{k∈S_t} (n_k/n) * grad_k
return θ4.2 FedRep(表示分解)
FedRep将模型参数分解为:
- 全局共享表示层:
- 本地私有分类器:
核心思想:利用各客户端数据学习一个共享的表示,然后每个客户端拥有自己的分类头。
目标函数:
交替优化:
- 固定 ,优化各
- 固定 ,优化
def FedRep(K, T, C, E, η):
θ_g = 初始化全局表示()
θ_k = {k: 初始化私有分类器() for k in range(K)}
for t in range(T):
S_t = 选择客户端(K, C)
for k in S_t:
# 1. 固定全局表示,训练本地分类器
for epoch in range(E):
# 本地训练分类器
# ...
# 2. 固定分类器,训练全局表示
for epoch in range(E):
# 本地训练表示
# ...
# 3. 聚合全局表示
θ_g = Σ_{k∈S_t} (n_k/n) * θ_g^k4.3 pFedMe(个性化FedAvg)
pFedMe使用** Moreau enveloppe**来分解个性化模型和全局模型:
其中 控制个性化程度。
算法流程:
def pFedMe(K, T, C, E, η, λ, β):
θ = 初始化全局模型()
for t in range(T):
S_t = 选择客户端(K, C)
for k in S_t:
# 1. 本地优化:最小化L_k + 正则项
θ_k = θ # 初始化
for epoch in range(E):
# 梯度下降
grad = ∇L_k(θ_k) + λ * (θ_k - θ)
θ_k = θ_k - η * grad
# 2. 计算用于聚合的更新
Δ_k = θ_k - θ # 个性化部分
上传(θ_k, Δ_k)
# 3. 全局聚合
# 聚合Δ_k得到新的θ
Δ = Σ_{k∈S_t} (n_k/n) * Δ_k
θ = θ + β * Δ # β是全局学习率5. 基于聚类的个性化方法
5.1 IFCA(迭代联邦聚类)
IFCA通过迭代聚类来发现客户端的隐式分组:
def IFCA(K, T, C, E, G, η):
"""
IFCA算法
Args:
G: 聚类数量
"""
# 初始化G个聚类中心
centroids = [初始化模型() for _ in range(G)]
for t in range(T):
S_t = 选择客户端(K, C)
for k in S_t:
# 1. 找到最近的聚类
losses = [evaluate(centroids[g], D_k) for g in range(G)]
g_k = argmin(losses)
# 2. 更新本地模型
local_model = deepcopy(centroids[g_k])
for epoch in range(E):
# 本地训练
# ...
# 3. 计算更新
Δ_k = local_model - centroids[g_k]
上传(Δ_k, g_k)
# 4. 更新聚类中心
for g in range(G):
participants = [k for k in S_t if cluster_assignments[k] == g]
if participants:
centroids[g] = 更新聚类中心(participants)5.2 FedEM(联邦集成方法)
FedEM将每个客户端建模为多个模型的混合:
其中 是客户端 第 个模型的权重。
6. 基于知识蒸馏的方法
6.1 FedMD(知识蒸馏)
FedMD使用知识蒸馏来传递全局知识:
def FedMD(K, T, C, T_distill, η):
"""
FedMD算法
"""
# Step 1: 每个客户端用公共数据集初始化
public_data = 加载公共数据集()
# Step 2: 联邦训练
for t in range(T):
S_t = 选择客户端(K, C)
# 各客户端用公共数据训练
for k in S_t:
# 软标签蒸馏
teacher_probs = 全局模型预测(public_data)
学生模型蒸馏(客户端模型, public_data, teacher_probs)
# 聚合
# ...
# Step 3: 知识蒸馏阶段
for round in range(T_distill):
# 各客户端用公共数据执行知识蒸馏
# ...6.2 FedDF(联邦蒸馏)
FedDF通过互蒸馏来整合知识:
def FedDF(K, T, C):
"""
FedDF: 使用联邦蒸馏进行知识迁移
"""
for t in range(T):
S_t = 选择客户端(K, C)
# Step 1: 各客户端本地训练
local_models = []
for k in S_t:
model_k = 本地训练(global_model, D_k)
local_models.append(model_k)
# Step 2: 互蒸馏
# 使用集成预测作为软标签
ensemble_predictions = 平均预测(local_models, unlabeled_data)
for k in S_t:
# 蒸馏到本地模型
distill_loss = KL(模型_k(数据), ensemble_predictions)
更新模型_k(蒸馏_loss)
# Step 3: 聚合蒸馏后的模型
global_model = 加权平均(local_models)7. 代码实现
7.1 Per-FedAvg完整实现
import torch
import torch.nn as nn
from typing import List, Dict
import copy
class PerFedAvg:
def __init__(
self,
model_fn,
clients_data: List[torch.utils.data.Dataset],
global_lr: float = 1.0,
local_lr: float = 0.1,
device: str = 'cpu'
):
self.global_lr = global_lr
self.local_lr = local_lr
self.device = device
self.global_model = model_fn().to(device)
self.clients_data = clients_data
self.n_clients = len(clients_data)
def compute_meta_gradient(self, client_id: int, E: int) -> Dict:
"""
计算Per-FedAvg的元梯度
"""
# 复制全局模型
theta = {
k: v.clone()
for k, v in self.global_model.state_dict().items()
}
# Step 1: 计算F_k(θ)
loss_before = self.evaluate_loss(client_id, theta)
grad_before = torch.autograd.grad(
loss_before, theta.values(), create_graph=True
)
# Step 2: 计算θ_k = θ - α * ∇F_k(θ)
theta_k = {
k: v - self.local_lr * g
for (k, v), g in zip(theta.items(), grad_before)
}
# Step 3: 计算F_k(θ_k)
loss_after = self.evaluate_loss(client_id, theta_k)
# Step 4: 计算元梯度 ∂F_k(θ_k)/∂θ
meta_grads = torch.autograd.grad(
loss_after, theta.values()
)
return {
'meta_grads': {
k: g.detach() for k, g in zip(theta.keys(), meta_grads)
},
'n_samples': len(self.clients_data[client_id])
}
def evaluate_loss(self, client_id: int, state_dict: Dict) -> torch.Tensor:
"""评估客户端损失"""
model = copy.deepcopy(self.global_model)
model.load_state_dict(state_dict)
model.to(self.device)
model.eval()
dataloader = torch.utils.data.DataLoader(
self.clients_data[client_id],
batch_size=64
)
total_loss = 0
n_samples = 0
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
for x, y in dataloader:
x, y = x.to(self.device), y.to(self.device)
output = model(x)
loss = criterion(output, y)
total_loss += loss.item() * len(y)
n_samples += len(y)
return total_loss / n_samples
def aggregate(self, gradients: List[Dict]):
"""聚合元梯度"""
total_samples = sum(g['n_samples'] for g in gradients)
aggregated = {}
for k in self.global_model.state_dict().keys():
aggregated[k] = sum(
g['meta_grads'][k] * g['n_samples'] / total_samples
for g in gradients
)
# 应用梯度
with torch.no_grad():
for k in self.global_model.state_dict().keys():
self.global_model.state_dict()[k] -= self.global_lr * aggregated[k]
def get_personalized_model(self, client_id: int, E: int = 5):
"""获取客户端的个性化模型"""
model = copy.deepcopy(self.global_model)
model.to(self.device)
model.train()
dataloader = torch.utils.data.DataLoader(
self.clients_data[client_id],
batch_size=32,
shuffle=True
)
optimizer = torch.optim.SGD(model.parameters(), lr=self.local_lr)
criterion = nn.CrossEntropyLoss()
for epoch in range(E):
for x, y in dataloader:
x, y = x.to(self.device), y.to(self.device)
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
return model
def fit(self, num_rounds: int, C: float, E: int):
"""联邦学习训练"""
for round_idx in range(num_rounds):
n_participants = max(1, int(self.n_clients * C))
participants = torch.randperm(self.n_clients)[:n_participants]
gradients = []
for cid in participants:
grad = self.compute_meta_gradient(cid, E)
gradients.append(grad)
self.aggregate(gradients)
if (round_idx + 1) % 10 == 0:
print(f"Round {round_idx + 1}/{num_rounds}")8. 方法对比与选择指南
8.1 方法对比
| 方法 | 通信开销 | 计算开销 | 个性化程度 | 适用场景 |
|---|---|---|---|---|
| Per-FedAvg | 中等 | 高 | 高 | 快速适应场景 |
| FedRep | 中等 | 中等 | 高 | 表示学习场景 |
| pFedMe | 中等 | 中等 | 可调 | 通用场景 |
| IFCA | 中等 | 中等 | 离散 | 聚类明显场景 |
| FedMD | 高 | 高 | 中等 | 知识迁移场景 |
8.2 选择指南
个性化方法选择流程图
开始
│
▼
数据分布是否已知?
│
├─ 是 ──► 聚类方法(IFCA)
│
└─ 否 ──► 通信资源是否充足?
│
├─ 是 ──► 知识蒸馏(FedMD)
│
└─ 否 ──► 需要多快适应?
│
├─ 极快 ──► Per-FedAvg
│
└─ 一般 ──► pFedMe / FedRep
9. 参考文献
10. 相关主题
- federated-learning-fundamentals — 联邦学习基础
- fedavg-fedprox-algorithms — FedAvg与FedProx算法
- federated-learning-non-iid-heterogeneity — 非IID数据问题
- scaffold-fednova-algorithms — 方差缩减算法