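1. Background

1.1 Challenges of Model Merging

Merging large language models poses distinct challenges [1]:

  • Parameter scale: models with billions of parameters are hard to manipulate directly
  • Task diversity: different tasks may conflict with one another
  • Alignment requirement: the merged model must retain the alignment properties acquired in pre-training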

1.2 Limitations of Existing Methods

Method            Alignment preservation   Task performance
Average merging
TaskVectors
Fisher merging
AlignMerge

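2. Core Idea of AlignMerge

2.1 Key Insight

The core finding of AlignMerge [1]:

Fisher information can guide model merging so that the merged model stays aligned with the pre-trained model.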

2.2 Technical Framework

┌─────────────────────────────────────────────────────────────────────────┐
│                          AlignMerge Framework                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Inputs:                                                                │
│    - Pre-trained model θ₀                                               │
│    - Expert models θ₁, θ₂, ..., θₖ                                      │
│    - Fisher information of each model F₁, F₂, ..., Fₖ                   │
│                                                                         │
│    │                                                                    │
│    ▼                                                                    │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │               Fisher-Guided Space Transformation                │    │
│  │                                                                 │    │
│  │   Preserve alignment: keep (θᵢ - θ₀) along the Fisher direction │    │
│  │                                                                 │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│    │                                                                    │
│    ▼                                                                    │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                Geometrically Constrained Merging                │    │
│  │                                                                 │    │
│  │   min || θ - θ₀ ||²_F  s.t. θ stays aligned with each expert    │    │
│  │                                                                 │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│    │                                                                    │
│    ▼                                                                    │
│  Output: θ*                                                             │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
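One possible reading of the constrained objective in the framework, written out as a sketch. It assumes the subscript F denotes a Fisher-weighted norm (the source does not define it, and it could also be read as a Frobenius norm) and that F is approximated by its diagonal, as in the code later in this document:

```latex
\| \theta - \theta_0 \|_F^2
  = (\theta - \theta_0)^\top F \, (\theta - \theta_0)
  \approx \sum_j F_{jj} \, (\theta_j - \theta_{0,j})^2
```

Under this reading, parameters with a large F_jj are penalized more heavily for moving away from θ₀, while parameters with a small F_jj are free to move toward the experts.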

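3. Technical Details

3.1 Computing Fisher Information

import torch
import torch.nn.functional as F


class FisherInformationCalculator:
    """
    Estimates the diagonal of the Fisher information matrix.
    """
    def __init__(self, model, dataloader):
        self.model = model
        self.dataloader = dataloader

    def compute_fisher(self, num_samples=1000):
        """
        Estimate the diagonal Fisher information.

        Definition of Fisher information:
        F = E[∇log p(y|x,θ) ∇log p(y|x,θ)^T]

        Only the diagonal (squared per-sample gradients) is kept. Labels come
        from the data, so this is the empirical Fisher.
        Assumes model(batch) returns logits of shape (batch, seq_len, vocab)
        and batch['labels'] has shape (batch, seq_len).
        """
        fisher = {}
        sample_count = 0

        self.model.eval()
        for name, param in self.model.named_parameters():
            fisher[name] = torch.zeros_like(param)

        for batch in self.dataloader:
            if sample_count >= num_samples:
                break

            # Forward pass
            outputs = self.model(batch)

            # Per-token log-probabilities
            log_probs = F.log_softmax(outputs, dim=-1)
            labels = batch['labels']
            for i, y in enumerate(labels):
                if sample_count >= num_samples:
                    break

                # Log-likelihood of sample i: sum over its label tokens
                positions = torch.arange(len(y), device=log_probs.device)
                log_prob = log_probs[i, positions, y].sum()

                # Backward pass; keep the graph while other samples in
                # this batch still need it
                self.model.zero_grad()
                log_prob.backward(retain_graph=(i + 1 < len(labels)))

                # Accumulate squared gradients
                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        fisher[name] += param.grad.detach() ** 2

                sample_count += 1

        # Average over the samples seen (guards against an empty dataloader)
        for name in fisher:
            fisher[name] /= max(sample_count, 1)

        return fisher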
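To make the "average of squared per-sample gradients" concrete, here is a minimal pure-Python sketch on a two-weight logistic regression, where the gradient of the log-likelihood has a closed form. The model, the helper names, and the toy data are all hypothetical and chosen only for illustration; this is not the paper's setup.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def diagonal_empirical_fisher(weights, samples):
    """Average of squared per-sample gradients of log p(y | x, w)."""
    fisher = [0.0] * len(weights)
    for x, y in samples:
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)))
        # For logistic regression: d/dw_j log p(y | x, w) = (y - p) * x_j
        grad = [(y - p) * xi for xi in x]
        for j, g in enumerate(grad):
            fisher[j] += g * g
    return [f / len(samples) for f in fisher]


# Hypothetical toy data: the second feature is always zero,
# so the second weight never receives any gradient.
toy_weights = [0.0, 0.0]
toy_samples = [([1.0, 0.0], 1), ([2.0, 0.0], 0)]
toy_fisher = diagonal_empirical_fisher(toy_weights, toy_samples)
print(toy_fisher)
```

The second entry comes out as zero: the data carries no information about that weight, which is exactly the kind of parameter a Fisher-based method treats as unimportant.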

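3.2 Fisher-Guided Merging

def align_merge(pre_trained, experts, fishers):
    """
    AlignMerge: Fisher-guided model merging

    Args:
        pre_trained: state dict of the pre-trained model
        experts: list of expert state dicts
        fishers: list of diagonal Fisher estimates, one per expert
    """
    # Merge only floating-point tensors covered by every Fisher estimate;
    # other entries (e.g. integer buffers) are copied from the pre-trained model
    names = [
        name for name, value in pre_trained.items()
        if value.is_floating_point() and all(name in f for f in fishers)
    ]

    # 1. Compute Fisher-normalized directions
    directions = []
    for expert, fisher in zip(experts, fishers):
        direction = {}
        for name in names:
            # Delta from the pre-trained weights
            delta = expert[name] - pre_trained[name]

            # Fisher normalization
            fisher_norm = fisher[name].sqrt().clamp(min=1e-8)
            direction[name] = delta / fisher_norm
        directions.append(direction)

    # 2. Compute the average direction
    avg_direction = {
        name: sum(d[name] for d in directions) / len(directions)
        for name in names
    }

    # 3. Alignment-preserving projection
    aligned_directions = []
    for direction, fisher in zip(directions, fishers):
        aligned = {}
        for name in names:
            avg = avg_direction[name]
            # Project onto the average direction along the last dimension;
            # dividing by ||avg||² makes this a true projection
            dot = (direction[name] * avg).sum(dim=-1, keepdim=True)
            avg_sq = (avg * avg).sum(dim=-1, keepdim=True).clamp(min=1e-12)
            # Undo the Fisher normalization so the update is back in
            # parameter space before it is added to the weights
            fisher_norm = fisher[name].sqrt().clamp(min=1e-8)
            aligned[name] = (dot / avg_sq) * avg * fisher_norm
        aligned_directions.append(aligned)

    # 4. Fisher-weighted merge
    merged = {name: value.clone() for name, value in pre_trained.items()}
    for name in names:
        # Total Fisher mass, used to normalize the weights
        total_fisher = sum(fisher[name].sum() for fisher in fishers)

        # Weighted combination (uniform weights if all Fisher mass is zero)
        for aligned, fisher in zip(aligned_directions, fishers):
            if total_fisher > 0:
                weight = fisher[name].sum() / total_fisher
            else:
                weight = 1.0 / len(fishers)
            merged[name] += weight * aligned[name]

    return merged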
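The final weighting step can be checked by hand on a tiny example. The sketch below isolates only step 4 (per-tensor weights proportional to each expert's total Fisher mass) for one flat parameter vector; it deliberately leaves out the normalization and projection steps, and the function name and numbers are hypothetical.

```python
def fisher_weighted_merge(base, experts, fishers):
    """Add each expert's delta to `base`, weighted by its share of Fisher mass."""
    masses = [sum(f) for f in fishers]
    total = sum(masses)
    merged = list(base)
    for expert, mass in zip(experts, masses):
        # Fall back to uniform weights if no expert has any Fisher mass
        weight = mass / total if total > 0 else 1.0 / len(experts)
        for j, (e, b) in enumerate(zip(expert, base)):
            merged[j] += weight * (e - b)
    return merged


base = [0.0, 0.0]
experts = [[1.0, 0.0], [0.0, 2.0]]
fishers = [[3.0, 0.0], [0.5, 0.5]]  # Fisher masses 3.0 and 1.0
merged = fisher_weighted_merge(base, experts, fishers)
print(merged)
```

Here the first expert holds three quarters of the Fisher mass, so its delta is scaled by 0.75 and the second expert's by 0.25. Because the weight is a single scalar per tensor, it does not distinguish which individual entries each expert considers important.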

4. Alignment Preservation Analysis

4.1 Alignment Metric

Definition (alignment score)

4.2 Geometric Interpretation

┌─────────────────────────────────────────────────────────────────────────┐
│               Geometric Interpretation of Fisher Guidance               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│                        θ* (merge target)                                │
│                         ╱│╲                                             │
│                        ╱ │ ╲                                            │
│                       ╱  │  ╲                                           │
│                      ╱   │   ╲                                          │
│                     ╱    │    ╲                                         │
│               θ₁───►     │     ◄──θ₂                                    │
│             (expert 1)   │   (expert 2)                                 │
│                          │                                              │
│                         θ₀                                              │
│                    (pre-trained)                                        │
│                                                                         │
│  Direction agreement: θᵢ - θ₀ points roughly the same way as θ* - θ₀    │
│  Fisher guidance: Fisher information adjusts the merge direction        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
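The "direction agreement" in the diagram can be illustrated with a cosine similarity between each expert's task vector θᵢ - θ₀ and the merged update θ* - θ₀. This is only an assumed stand-in for illustration: the source does not give the formula for its alignment score, and the helper names and toy vectors below are hypothetical.

```python
import math


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0  # No direction to compare against
    return dot / (norm_u * norm_v)


def direction_agreement(base, expert, merged):
    """Cosine between (expert - base) and (merged - base)."""
    task_vector = [e - b for e, b in zip(expert, base)]
    merged_vector = [m - b for m, b in zip(merged, base)]
    return cosine_similarity(task_vector, merged_vector)


base = [0.0, 0.0]
expert_1 = [1.0, 0.0]
expert_2 = [0.0, 1.0]
merged = [1.0, 1.0]
agreement_1 = direction_agreement(base, expert_1, merged)
agreement_2 = direction_agreement(base, expert_2, merged)
print(agreement_1, agreement_2)
```

With two orthogonal experts and a merge target halfway between them, both agreements equal cos 45°, so neither expert's direction is favored.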

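5. Experimental Results

5.1 Alignment Preservation

Alignment score with respect to the pre-trained model

Method            Alignment score ↑   Perplexity
Average merging   0.42                15.3
TaskVectors       0.38                16.1
Fisher merging    0.78                18.2
AlignMerge        0.85                14.8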

5.2 Task Performance

Multi-task evaluation

Method             Math   Coding   Dialogue   Average
Expert averaging   72%    68%      65%        68%
TaskVectors        75%    71%      62%        69%
Fisher merging     70%    65%      70%        68%
AlignMerge         78%    74%      72%        75%
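The Average column can be re-derived from the three task scores. A quick check, assuming the column is the unweighted mean rounded to the nearest percent (the source does not say how it was computed):

```python
# Math, coding, and dialogue scores copied from the table above (percent)
scores = {
    "Expert averaging": (72, 68, 65),
    "TaskVectors": (75, 71, 62),
    "Fisher merging": (70, 65, 70),
    "AlignMerge": (78, 74, 72),
}

# Unweighted mean per method, rounded to the nearest whole percent
averages = {
    method: round(sum(values) / len(values))
    for method, values in scores.items()
}

for method, average in averages.items():
    print(f"{method}: {average}%")
```

Under that assumption the recomputed means match the reported column for all four methods.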

5.3 Ablation Study

Components              Alignment score   Task performance
-                       0.45              65%
+ Direction alignment   0.72              71%
+ Fisher weighting      0.85              75%

6. Code Implementation

6.1 Full Implementation

import torch
import torch.nn.functional as F


class AlignMergeModelMerging:
    """
    AlignMerge model merging
    """
    def __init__(self, pre_trained_model):
        self.pre_trained = pre_trained_model
        self.pre_state = {
            k: v.clone() for k, v in pre_trained_model.state_dict().items()
        }

    def compute_fisher_diag(self, model, dataloader, num_samples=1000):
        """
        Estimate the diagonal Fisher information of `model`.

        Assumes model(batch) returns logits of shape (batch, seq_len, vocab)
        and batch['labels'] has shape (batch, seq_len).
        """
        fisher = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
        }

        sample_count = 0
        model.eval()

        for batch in dataloader:
            if sample_count >= num_samples:
                break

            outputs = model(batch)
            # Use log-probabilities of the labels: the mean of the softmax
            # outputs is a constant and would give zero gradients
            log_probs = F.log_softmax(outputs, dim=-1)
            labels = batch['labels']

            for i, y in enumerate(labels):
                if sample_count >= num_samples:
                    break
                positions = torch.arange(len(y), device=log_probs.device)
                log_prob = log_probs[i, positions, y].sum()
                model.zero_grad()
                log_prob.backward(retain_graph=True)

                for name, param in model.named_parameters():
                    if param.grad is not None:
                        fisher[name] += param.grad.detach() ** 2

                sample_count += 1

        # Average over the samples seen (guards against an empty dataloader)
        for name in fisher:
            fisher[name] /= max(sample_count, 1)

        return fisher

    def merge(self, expert_models, expert_dataloaders):
        """
        Merge expert models

        Args:
            expert_models: list of expert models
            expert_dataloaders: one dataloader per expert, used to
                estimate that expert's Fisher information
        """
        # Estimate each expert's Fisher information on its own model
        fishers = [
            self.compute_fisher_diag(model, dataloader)
            for model, dataloader in zip(expert_models, expert_dataloaders)
        ]

        # Merge only entries that have a Fisher estimate (the parameters);
        # buffers are copied from the pre-trained model unchanged
        names = [n for n in self.pre_state if all(n in f for f in fishers)]

        # Compute directions (deltas from the pre-trained weights)
        directions = []
        for expert in expert_models:
            expert_state = expert.state_dict()
            directions.append({
                name: expert_state[name] - self.pre_state[name]
                for name in names
            })

        # Compute the average direction
        avg_direction = {
            name: sum(d[name] for d in directions) / len(directions)
            for name in names
        }

        # Alignment-preserving projection
        aligned_directions = []
        for direction, fisher in zip(directions, fishers):
            aligned = {}
            for name in names:
                # Project onto the Fisher-normalized average direction
                fisher_sqrt = fisher[name].sqrt().clamp(min=1e-8)
                dir_normalized = direction[name] / fisher_sqrt
                avg_normalized = avg_direction[name] / fisher_sqrt

                # Projection coefficient <d, a> / <a, a> in normalized space
                dot = (dir_normalized * avg_normalized).sum()
                avg_sq = (avg_normalized * avg_normalized).sum().clamp(min=1e-12)
                aligned[name] = (dot / avg_sq) * avg_direction[name]
            aligned_directions.append(aligned)

        # Fisher-weighted merge
        merged = {k: v.clone() for k, v in self.pre_state.items()}
        for name in names:
            total_fisher = sum(f[name].sum() for f in fishers)

            for aligned, fisher in zip(aligned_directions, fishers):
                if total_fisher > 0:
                    weight = fisher[name].sum() / total_fisher
                else:
                    weight = 1.0 / len(fishers)
                merged[name] += weight * aligned[name]

        return merged

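7. Summary

7.1 Main Contributions

  1. Fisher guidance: uses Fisher information to steer the merge direction
  2. Alignment preservation: keeps the merged model aligned with the pre-trained model
  3. Task performance: improves task performance while preserving alignment

7.2 Limitations

  1. Computational cost: Fisher information has to be computed for every expert
  2. Diagonal approximation: uses the diagonal of the Fisher matrix rather than the full matrix
  3. Task balance: the balance between tasks still requires tuning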
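The diagonal approximation in limitation 2 is largely a storage question, which a back-of-the-envelope calculation makes clear. The sketch below assumes a hypothetical 1-billion-parameter model stored in 32-bit floats; the function name and figures are illustrative only.

```python
def fisher_storage_bytes(num_params, bytes_per_value=4):
    """Bytes needed for a diagonal vs. a full Fisher matrix."""
    diagonal = num_params * bytes_per_value               # one value per parameter
    full = num_params * num_params * bytes_per_value      # one value per parameter pair
    return diagonal, full


# Hypothetical model size: 1 billion parameters, float32 values
diag_bytes, full_bytes = fisher_storage_bytes(1_000_000_000)
print(f"diagonal: {diag_bytes / 1e9:.0f} GB")
print(f"full:     {full_bytes / 1e18:.0f} EB")
```

The diagonal costs as much memory as one extra copy of the model, while the full matrix grows quadratically and is out of reach at this scale, which is why off-diagonal parameter interactions are dropped.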

References

Footnotes

  1. AlignMerge: Alignment-Preserving LLM Merging via Fisher-Guided Geometric Constraints, arXiv:2512.16245