1. 引言

在传统的联邦学习中,所有客户端共享同一个全局模型。然而,由于数据分布的异构性,一个模型难以同时满足所有客户端的需求。个性化联邦学习(Personalized Federated Learning, pFL)应运而生,旨在为每个客户端学习定制化的模型。


2. 个性化联邦学习的动机

2.1 为什么需要个性化?

考虑一个跨医院联合建模的场景:

  • 医院A:主要收治心血管疾病
  • 医院B:主要收治呼吸系统疾病
  • 医院C:主要收治神经系统疾病

使用统一的全局模型会导致:

  1. 性能次优:模型在各类别上都无法达到最优
  2. 服务不公:某些医院的特殊需求被忽视
  3. 客户端流失:服务不佳的客户端可能退出协作

2.2 个性化 vs 全局联邦学习

方面全局联邦学习个性化联邦学习
模型数量1个全局模型K个个性化模型
适用场景数据分布相似数据分布差异大
通信开销较低较高
客户端适应

3. 个性化联邦学习方法分类

个性化联邦学习方法
│
├── 基于微调的方法
│   ├── 全局模型微调
│   └── 本地适配层
│
├── 基于正则化的方法
│   ├── 知识蒸馏
│   └── 相似性正则
│
├── 基于聚类的方法
│   ├── 硬聚类
│   └── 软聚类
│
├── 基于元学习的方法
│   ├── MAML-based
│   └── Reptile-based
│
└── 基于混合架构的方法
    ├── 共享+私有参数
    └── 知识路由

4. 基于微调的个性化方法

4.1 Per-FedAvg(基于MAML)

Per-FedAvg将模型无关元学习(Model-Agnostic Meta-Learning, MAML)的思想引入联邦学习。

核心思想:学习一个好的初始化,使得每个客户端只需少量本地更新就能获得个性化模型。

目标函数

其中 是本地适应学习率。

算法流程

def PerFedAvg(K, T, C, E, η, α):
    """
    Per-FedAvg算法
    
    Args:
        K: 客户端数量
        T: 通信轮次
        C: 参与比例
        E: 本地epoch
        η: 全局学习率
        α: 本地适应学习率
    """
    θ = 初始化全局参数()
    
    for t in range(T):
        S_t = 选择客户端(K, C)
        
        for k in S_t:
            # Step 1: 本地适应(内循环)
            θ_k = θ - α * ∇F_k(θ)
            
            # Step 2: 计算元梯度
            grad_k = ∇F_k(θ_k)
            # 使用链式法则得到对θ的梯度
            # θ_k = θ - α * ∇F_k(θ)
            # ∂L/∂θ = (I - α * Hessian(F_k)) * ∂L/∂θ_k
            
            上传(grad_k, θ_k)
        
        # Step 3: 全局更新
        θ = θ - η * Σ_{k∈S_t} (n_k/n) * grad_k
    
    return θ

4.2 FedRep(表示分解)

FedRep将模型参数分解为:

  • 全局共享表示层
  • 本地私有分类器

核心思想:利用各客户端数据学习一个共享的表示,然后每个客户端拥有自己的分类头。

目标函数

交替优化

  1. 固定 ,优化各
  2. 固定 ,优化
def FedRep(K, T, C, E, η):
    θ_g = 初始化全局表示()
    θ_k = {k: 初始化私有分类器() for k in range(K)}
    
    for t in range(T):
        S_t = 选择客户端(K, C)
        
        for k in S_t:
            # 1. 固定全局表示,训练本地分类器
            for epoch in range(E):
                # 本地训练分类器
                # ...
            
            # 2. 固定分类器,训练全局表示
            for epoch in range(E):
                # 本地训练表示
                # ...
        
        # 3. 聚合全局表示
        θ_g = Σ_{k∈S_t} (n_k/n) * θ_g^k

4.3 pFedMe(个性化FedAvg)

pFedMe使用** Moreau enveloppe**来分解个性化模型和全局模型:

其中 控制个性化程度。

算法流程

def pFedMe(K, T, C, E, η, λ, β):
    θ = 初始化全局模型()
    
    for t in range(T):
        S_t = 选择客户端(K, C)
        
        for k in S_t:
            # 1. 本地优化:最小化L_k + 正则项
            θ_k = θ  # 初始化
            for epoch in range(E):
                # 梯度下降
                grad = ∇L_k(θ_k) + λ * (θ_k - θ)
                θ_k = θ_k - η * grad
            
            # 2. 计算用于聚合的更新
            Δ_k = θ_k - θ  # 个性化部分
            上传(θ_k, Δ_k)
        
        # 3. 全局聚合
        # 聚合Δ_k得到新的θ
        Δ = Σ_{k∈S_t} (n_k/n) * Δ_k
        θ = θ + β * Δ  # β是全局学习率

5. 基于聚类的个性化方法

5.1 IFCA(迭代联邦聚类)

IFCA通过迭代聚类来发现客户端的隐式分组:

def IFCA(K, T, C, E, G, η):
    """
    IFCA算法
    
    Args:
        G: 聚类数量
    """
    # 初始化G个聚类中心
    centroids = [初始化模型() for _ in range(G)]
    
    for t in range(T):
        S_t = 选择客户端(K, C)
        
        for k in S_t:
            # 1. 找到最近的聚类
            losses = [evaluate(centroids[g], D_k) for g in range(G)]
            g_k = argmin(losses)
            
            # 2. 更新本地模型
            local_model = deepcopy(centroids[g_k])
            for epoch in range(E):
                # 本地训练
                # ...
            
            # 3. 计算更新
            Δ_k = local_model - centroids[g_k]
            上传(Δ_k, g_k)
        
        # 4. 更新聚类中心
        for g in range(G):
            participants = [k for k in S_t if cluster_assignments[k] == g]
            if participants:
                centroids[g] = 更新聚类中心(participants)

5.2 FedEM(联邦集成方法)

FedEM将每个客户端建模为多个模型的混合

其中 是客户端 个模型的权重。


6. 基于知识蒸馏的方法

6.1 FedMD(知识蒸馏)

FedMD使用知识蒸馏来传递全局知识:

def FedMD(K, T, C, T_distill, η):
    """
    FedMD算法
    """
    # Step 1: 每个客户端用公共数据集初始化
    public_data = 加载公共数据集()
    
    # Step 2: 联邦训练
    for t in range(T):
        S_t = 选择客户端(K, C)
        
        # 各客户端用公共数据训练
        for k in S_t:
            # 软标签蒸馏
            teacher_probs = 全局模型预测(public_data)
            学生模型蒸馏(客户端模型, public_data, teacher_probs)
        
        # 聚合
        # ...
    
    # Step 3: 知识蒸馏阶段
    for round in range(T_distill):
        # 各客户端用公共数据执行知识蒸馏
        # ...

6.2 FedDF(联邦蒸馏)

FedDF通过互蒸馏来整合知识:

def FedDF(K, T, C):
    """
    FedDF: 使用联邦蒸馏进行知识迁移
    """
    for t in range(T):
        S_t = 选择客户端(K, C)
        
        # Step 1: 各客户端本地训练
        local_models = []
        for k in S_t:
            model_k = 本地训练(global_model, D_k)
            local_models.append(model_k)
        
        # Step 2: 互蒸馏
        # 使用集成预测作为软标签
        ensemble_predictions = 平均预测(local_models, unlabeled_data)
        
        for k in S_t:
            # 蒸馏到本地模型
            distill_loss = KL(模型_k(数据), ensemble_predictions)
            更新模型_k(蒸馏_loss)
        
        # Step 3: 聚合蒸馏后的模型
        global_model = 加权平均(local_models)

7. 代码实现

7.1 Per-FedAvg完整实现

import torch
import torch.nn as nn
from typing import List, Dict
import copy
 
class PerFedAvg:
    def __init__(
        self,
        model_fn,
        clients_data: List[torch.utils.data.Dataset],
        global_lr: float = 1.0,
        local_lr: float = 0.1,
        device: str = 'cpu'
    ):
        self.global_lr = global_lr
        self.local_lr = local_lr
        self.device = device
        self.global_model = model_fn().to(device)
        self.clients_data = clients_data
        self.n_clients = len(clients_data)
    
    def compute_meta_gradient(self, client_id: int, E: int) -> Dict:
        """
        计算Per-FedAvg的元梯度
        """
        # 复制全局模型
        theta = {
            k: v.clone() 
            for k, v in self.global_model.state_dict().items()
        }
        
        # Step 1: 计算F_k(θ)
        loss_before = self.evaluate_loss(client_id, theta)
        grad_before = torch.autograd.grad(
            loss_before, theta.values(), create_graph=True
        )
        
        # Step 2: 计算θ_k = θ - α * ∇F_k(θ)
        theta_k = {
            k: v - self.local_lr * g 
            for (k, v), g in zip(theta.items(), grad_before)
        }
        
        # Step 3: 计算F_k(θ_k)
        loss_after = self.evaluate_loss(client_id, theta_k)
        
        # Step 4: 计算元梯度 ∂F_k(θ_k)/∂θ
        meta_grads = torch.autograd.grad(
            loss_after, theta.values()
        )
        
        return {
            'meta_grads': {
                k: g.detach() for k, g in zip(theta.keys(), meta_grads)
            },
            'n_samples': len(self.clients_data[client_id])
        }
    
    def evaluate_loss(self, client_id: int, state_dict: Dict) -> torch.Tensor:
        """评估客户端损失"""
        model = copy.deepcopy(self.global_model)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        
        dataloader = torch.utils.data.DataLoader(
            self.clients_data[client_id],
            batch_size=64
        )
        
        total_loss = 0
        n_samples = 0
        criterion = nn.CrossEntropyLoss()
        
        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                output = model(x)
                loss = criterion(output, y)
                total_loss += loss.item() * len(y)
                n_samples += len(y)
        
        return total_loss / n_samples
    
    def aggregate(self, gradients: List[Dict]):
        """聚合元梯度"""
        total_samples = sum(g['n_samples'] for g in gradients)
        
        aggregated = {}
        for k in self.global_model.state_dict().keys():
            aggregated[k] = sum(
                g['meta_grads'][k] * g['n_samples'] / total_samples
                for g in gradients
            )
        
        # 应用梯度
        with torch.no_grad():
            for k in self.global_model.state_dict().keys():
                self.global_model.state_dict()[k] -= self.global_lr * aggregated[k]
    
    def get_personalized_model(self, client_id: int, E: int = 5):
        """获取客户端的个性化模型"""
        model = copy.deepcopy(self.global_model)
        model.to(self.device)
        model.train()
        
        dataloader = torch.utils.data.DataLoader(
            self.clients_data[client_id],
            batch_size=32,
            shuffle=True
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=self.local_lr)
        criterion = nn.CrossEntropyLoss()
        
        for epoch in range(E):
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()
                optimizer.step()
        
        return model
    
    def fit(self, num_rounds: int, C: float, E: int):
        """联邦学习训练"""
        for round_idx in range(num_rounds):
            n_participants = max(1, int(self.n_clients * C))
            participants = torch.randperm(self.n_clients)[:n_participants]
            
            gradients = []
            for cid in participants:
                grad = self.compute_meta_gradient(cid, E)
                gradients.append(grad)
            
            self.aggregate(gradients)
            
            if (round_idx + 1) % 10 == 0:
                print(f"Round {round_idx + 1}/{num_rounds}")

8. 方法对比与选择指南

8.1 方法对比

方法通信开销计算开销个性化程度适用场景
Per-FedAvg中等快速适应场景
FedRep中等中等表示学习场景
pFedMe中等中等可调通用场景
IFCA中等中等离散聚类明显场景
FedMD中等知识迁移场景

8.2 选择指南

个性化方法选择流程图

开始
  │
  ▼
数据分布是否已知?
  │
  ├─ 是 ──► 聚类方法(IFCA)
  │
  └─ 否 ──► 通信资源是否充足?
              │
              ├─ 是 ──► 知识蒸馏(FedMD)
              │
              └─ 否 ──► 需要多快适应?
                          │
                          ├─ 极快 ──► Per-FedAvg
                          │
                          └─ 一般 ──► pFedMe / FedRep

9. 参考文献


10. 相关主题