1. 引言

在联邦学习中,方差缩减(Variance Reduction)是解决非IID数据问题的核心技术之一。SCAFFOLD和FedNova是两种代表性的方差缩减方法,它们通过引入控制变量或归一化技术来减少客户端更新之间的方差,从而加速收敛。


2. SCAFFOLD算法详解

2.1 算法动机

FedAvg在非IID数据下的主要问题是客户端漂移(Client Drift)。本地训练使得各客户端的模型向各自的最优方向移动,而服务器聚合只能部分纠正这种漂移。

SCAFFOLD的核心思想是:

引入控制变量(Control Variates)来估计并抵消客户端漂移的影响

2.2 控制变量的数学原理

设全局模型为 ,客户端 的控制变量为 ,全局控制变量为

本地更新修正为:

其中:

  • :本地模型与全局模型的差异
  • :对漂移的估计修正

2.3 算法流程

def SCAFFOLD(K, T, C, E, η, global_lr=True):
    """
    参数:
        K: 客户端总数
        T: 通信轮次
        C: 参与比例
        E: 本地epoch
        η: 学习率
        global_lr: 是否使用全局学习率
    """
    # 初始化
    w_0 = 初始化模型()
    c_0 = 零向量()  # 全局控制变量
    c_i_0 = {k: 零向量() for k in range(K)}  # 各客户端控制变量
    
    for t in range(T):
        # 选择客户端
        S_t = 选择客户端(K, C)
        
        # 并行本地训练
        for k in S_t:
            # 1. 计算本地梯度
            g_k = ∇F_k(w_t)
            
            # 2. 修正后的梯度
            g_k_corrected = g_k - c_i_0[k] + c_0
            
            # 3. 更新本地模型
            w_k = w_t - η * g_k_corrected
            
            # 4. 更新本地控制变量
            c_i_1[k] = c_i_0[k] - c_0 + (1/(E*η)) * (w_t - w_k)
            
            # 5. 计算模型更新
            Δ_k = w_k - w_t
            上传(Δ_k, c_i_1[k] - c_i_0[k])
        
        # 服务器聚合
        w_t = w_t + Σ_{k∈S_t} (n_k/n) * Δ_k
        c_0 = c_0 + Σ_{k∈S_t} (n_k/n) * (c_i_1[k] - c_i_0[k])
        
        # 更新控制变量缓存
        c_i_0 = c_i_1
    
    return w_T

2.4 收敛性分析

定理(SCAFFOLD收敛性):假设目标函数 -光滑、-强凸,则SCAFFOLD有:

其中 归一化方差,远小于原始FedAvg的方差。

2.5 SCAFFOLD vs FedAvg 对比

特性FedAvgSCAFFOLD
控制变量
方差缩减
非IID适应性
通信开销
收敛速度
超参数敏感性中等较低

3. FedNova算法详解

3.1 算法动机

FedNova观察到FedAvg的一个根本问题:有效步数不一致。不同客户端由于数据量不同,在相同的本地epoch下实际执行的梯度更新步数不同,这导致聚合时的不公平。

3.2 有效步数归一化

设客户端 在一轮中执行了 步有效更新,则:

其中 是客户端 的本地数据量, 是本地epoch数。

3.3 归一化聚合

FedNova的聚合规则为:

这确保了总有效步数的正确归一化。

3.4 算法伪代码

def FedNova(K, T, C, E, η):
    w_0 = 初始化模型()
    
    for t in range(T):
        S_t = 选择客户端(K, C)
        
        for k in S_t:
            # 本地训练,记录有效步数
            w_k_0 = w_t
            n_k = len(D_k)
            τ_k = 0  # 有效步数计数
            
            for epoch in range(E):
                for batch in DataLoader(D_k, shuffle=True):
                    # 计算梯度
                    g = ∇F_k(w_k, batch)
                    # 更新
                    w_k = w_k - η * g
                    τ_k += len(batch)  # 每个样本计一步
            
            # 计算归一化更新
            Δ_k = w_k - w_t
            上传(Δ_k, τ_k)
        
        # 归一化聚合
        total_tau = Σ_{k∈S_t} τ_k
        Δ = Σ_{k∈S_t} (τ_k / total_tau) * Δ_k
        w_t = w_t + Δ
    
    return w_T

3.5 收敛性分析

定理(FedNova收敛性):FedNova在非IID数据下有:

其中 是异构性参数。

3.6 FedNova的理论保证

FedNova提供以下理论保证:

  1. 有限步数保证:在 轮内达到 -次优的概率至少为
  2. 线性加速:当客户端数量增加时,收敛速率线性加速
  3. 偏差修正:消除由于有效步数不一致导致的系统性偏差

4. SCAFFOLD与FedNova的对比

4.1 核心思想对比

方面SCAFFOLDFedNova
核心创新控制变量有效步数归一化
解决的问题梯度方向偏移聚合权重错误
数学框架方差缩减偏差修正
通信开销控制变量额外开销无额外开销

4.2 实际性能对比

在CIFAR-10数据集上的实验结果:

算法IID准确率非IID准确率收敛轮数
FedAvg85.2%62.3%500+
SCAFFOLD84.8%78.6%200
FedNova85.0%76.2%250

4.3 适用场景

场景推荐算法
高度异构数据SCAFFOLD
客户端数据量差异大FedNova
通信资源有限FedNova
需要快速收敛SCAFFOLD

5. 代码实现

5.1 SCAFFOLD实现

import torch
import torch.nn as nn
from typing import List, Dict
import numpy as np
 
class SCAFFOLD:
    def __init__(
        self,
        model_fn,
        clients_data: List[torch.utils.data.Dataset],
        lr: float = 0.01,
        mu: float = 0.0,  # 近端正则系数
        device: str = 'cpu'
    ):
        self.lr = lr
        self.mu = mu
        self.device = device
        
        # 全局模型
        self.global_model = model_fn().to(device)
        self.global_controls = self._init_controls()
        
        # 客户端数据
        self.clients_data = clients_data
        self.n_clients = len(clients_data)
        
        # 客户端控制变量
        self.client_controls = [
            self._init_controls() for _ in range(self.n_clients)
        ]
    
    def _init_controls(self):
        """初始化控制变量(与模型参数同维度)"""
        return {
            k: torch.zeros_like(v)
            for k, v in self.global_model.state_dict().items()
        }
    
    def client_update(self, client_id: int, E: int) -> Dict:
        """单个客户端的SCAFFOLD更新"""
        local_model = type(self.global_model)(
            *self.global_model.__dict__['_modules'].values()
        ).to(self.device)
        local_model.load_state_dict(self.global_model.state_dict())
        
        local_controls = {
            k: v.clone() for k, v in self.client_controls[client_id].items()
        }
        global_controls = {
            k: v.clone() for k, v in self.global_controls.items()
        }
        
        dataloader = torch.utils.data.DataLoader(
            self.clients_data[client_id],
            batch_size=32,
            shuffle=True
        )
        optimizer = torch.optim.SGD(local_model.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()
        
        # 本地训练
        for epoch in range(E):
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                optimizer.zero_grad()
                
                output = local_model(batch_x)
                loss = criterion(output, batch_y)
                
                # 计算原始梯度
                loss.backward()
                
                # SCAFFOLD: 修正梯度
                with torch.no_grad():
                    for param, lc, gc in zip(
                        local_model.parameters(),
                        local_controls.values(),
                        global_controls.values()
                    ):
                        # 修正后的梯度
                        corrected_grad = param.grad + lc - gc
                        param.grad = corrected_grad
                
                optimizer.step()
        
        # 计算模型和控制变量更新
        model_update = {
            k: v - self.global_model.state_dict()[k]
            for k, v in local_model.state_dict().items()
        }
        
        control_update = {
            k: lc - gc + (1 / (E * self.lr)) * mu
            for (k, lc), gc in zip(local_controls.items(), global_controls.values())
        }
        
        return {
            'model_update': model_update,
            'control_update': control_update,
            'n_samples': len(self.clients_data[client_id])
        }
    
    def aggregate(self, updates: List[Dict]):
        """聚合客户端更新"""
        total_samples = sum(u['n_samples'] for u in updates)
        
        # 归一化聚合模型更新
        aggregated_model = {}
        for k in self.global_model.state_dict().keys():
            aggregated_model[k] = sum(
                u['model_update'][k] * u['n_samples'] / total_samples
                for u in updates
            )
        
        # 归一化聚合控制变量更新
        aggregated_control = {}
        for k in self.global_controls.keys():
            aggregated_control[k] = sum(
                u['control_update'][k] * u['n_samples'] / total_samples
                for u in updates
            )
        
        # 应用更新
        with torch.no_grad():
            for k in self.global_model.state_dict().keys():
                self.global_model.state_dict()[k] += aggregated_model[k]
                self.global_controls[k] += aggregated_control[k]
    
    def fit(self, num_rounds: int, C: float, E: int):
        """联邦学习训练循环"""
        for round_idx in range(num_rounds):
            # 选择客户端
            n_participants = max(1, int(self.n_clients * C))
            participants = np.random.choice(
                self.n_clients, n_participants, replace=False
            )
            
            # 并行本地训练
            updates = []
            for cid in participants:
                update = self.client_update(cid, E)
                updates.append(update)
            
            # 聚合
            self.aggregate(updates)
            
            if (round_idx + 1) % 10 == 0:
                print(f"Round {round_idx + 1}/{num_rounds} completed")

5.2 FedNova实现

class FedNova:
    def __init__(
        self,
        model_fn,
        clients_data: List[torch.utils.data.Dataset],
        lr: float = 0.01,
        device: str = 'cpu'
    ):
        self.lr = lr
        self.device = device
        self.global_model = model_fn().to(device)
        self.clients_data = clients_data
        self.n_clients = len(clients_data)
    
    def client_update(self, client_id: int, E: int) -> Dict:
        """单个客户端的FedNova更新"""
        local_model = type(self.global_model)(
            *self.global_model.__dict__['_modules'].values()
        ).to(self.device)
        local_model.load_state_dict(self.global_model.state_dict())
        
        dataloader = torch.utils.data.DataLoader(
            self.clients_data[client_id],
            batch_size=32,
            shuffle=True
        )
        optimizer = torch.optim.SGD(local_model.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()
        
        tau_k = 0  # 有效步数
        
        for epoch in range(E):
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                optimizer.zero_grad()
                
                output = local_model(batch_x)
                loss = criterion(output, batch_y)
                loss.backward()
                optimizer.step()
                
                tau_k += len(batch_x)  # 每个样本计一步
        
        # 归一化更新
        model_update = {
            k: v - self.global_model.state_dict()[k]
            for k, v in local_model.state_dict().items()
        }
        
        return {
            'model_update': model_update,
            'tau': tau_k,
            'n_samples': len(self.clients_data[client_id])
        }
    
    def aggregate(self, updates: List[Dict]):
        """FedNova归一化聚合"""
        total_tau = sum(u['tau'] for u in updates)
        
        aggregated = {}
        for k in self.global_model.state_dict().keys():
            aggregated[k] = sum(
                u['model_update'][k] * u['tau'] / total_tau
                for u in updates
            )
        
        with torch.no_grad():
            for k in self.global_model.state_dict().keys():
                self.global_model.state_dict()[k] += aggregated[k]
    
    def fit(self, num_rounds: int, C: float, E: int):
        """联邦学习训练循环"""
        for round_idx in range(num_rounds):
            n_participants = max(1, int(self.n_clients * C))
            participants = np.random.choice(
                self.n_clients, n_participants, replace=False
            )
            
            updates = []
            for cid in participants:
                update = self.client_update(cid, E)
                updates.append(update)
            
            self.aggregate(updates)
            
            if (round_idx + 1) % 10 == 0:
                print(f"Round {round_idx + 1}/{num_rounds} completed")

6. 扩展与变体

6.1 SCAFFOLD+ (Periodic Client Participation)

SCAFFOLD+针对周期性客户端参与场景进行了优化:

def SCAFFOLD_plus(K, T, C, E, η, period):
    """
    周期性客户端参与的SCAFFOLD
    """
    # ... 初始化 ...
    
    for t in range(T):
        if t % period == 0:
            # 完整SCAFFOLD更新
            # ...
        else:
            # 简化的SCAFFOLD更新(减少控制变量通信)
            # ...

6.2 FedNova与SCAFFOLD的结合

两种方法可以结合使用:

def SCAFFOLD_Nova_Hybrid:
    """
    结合SCAFFOLD和FedNova的混合方法
    """
    # 使用FedNova的归一化
    # 结合SCAFFOLD的控制变量
    # ...

6.3 通信高效的变体

变体压缩比精度损失
Quantized SCAFFOLD4-8x<2%
Sparse SCAFFOLD10-100x<5%
Sketched SCAFFOLD8-16x<3%

7. 参考文献


8. 相关主题