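1. Background
1.1 Challenges of Model Merging
Merging large language models poses distinctive challenges[1]:
- Scale: with billions of parameters, the models are hard to manipulate directly
- Task diversity: different tasks may conflict with one another
- Alignment requirements: the alignment properties of the pre-trained model must be preserved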
1.2 Limitations of Existing Methods
| Method | Alignment preservation | Task performance |
|---|---|---|
| Simple averaging | ❌ | Medium |
| TaskVectors | ❌ | Medium |
| Fisher merging | ✅ | Low |
| AlignMerge | ✅ | High |
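2. Core Idea of AlignMerge
2.1 Key Insight
The key finding of AlignMerge[1]:
Fisher information can guide model merging so that the merged model preserves its alignment with the pre-trained model.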
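To make "Fisher-guided" concrete, the sketch below implements the standard per-coordinate Fisher-weighted average, θ*_j = Σᵢ Fᵢⱼ θᵢⱼ / Σᵢ Fᵢⱼ. This is the rule usually meant by "Fisher merging" (presumably the baseline in Section 1.2), not AlignMerge itself. Parameters are plain Python lists, and the name `fisher_weighted_average` is illustrative.

```python
def fisher_weighted_average(params, fishers, eps=1e-8):
    """Per-coordinate Fisher-weighted average of k parameter vectors.

    params:  k parameter vectors (lists of floats), one per model
    fishers: k diagonal Fisher vectors of the same length
    Returns theta*_j = sum_i F_ij * theta_ij / sum_i F_ij for every coordinate j.
    """
    merged = []
    for j in range(len(params[0])):
        num = sum(f[j] * p[j] for p, f in zip(params, fishers))
        den = sum(f[j] for f in fishers)
        merged.append(num / max(den, eps))
    return merged


# Expert 1 is "confident" (high Fisher) on coordinate 0, expert 2 on coordinate 1.
theta_1, theta_2 = [1.0, 0.0], [0.0, 1.0]
F_1, F_2 = [9.0, 1.0], [1.0, 9.0]
print(fisher_weighted_average([theta_1, theta_2], [F_1, F_2]))  # [0.9, 0.9]
```

Each expert dominates the coordinates where its Fisher is large, whereas a plain average would give [0.5, 0.5].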
2.2 Technical Framework
```
┌─────────────────────────────────────────────────────────────────┐
│                      AlignMerge Framework                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Inputs:                                                        │
│    - Pre-trained model θ₀                                       │
│    - Expert models θ₁, θ₂, ..., θₖ                              │
│    - Fisher information of each model F₁, F₂, ..., Fₖ           │
│                                                                 │
│                                │                                │
│                                ▼                                │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ Fisher-guided space transformation                        │  │
│  │                                                           │  │
│  │ Alignment preservation: ensure the direction of (θᵢ - θ₀) │  │
│  │                         agrees with the Fisher direction  │  │
│  │                                                           │  │
│  └───────────────────────────────────────────────────────────┘  │
│                                │                                │
│                                ▼                                │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │ Geometry-constrained merging                              │  │
│  │                                                           │  │
│  │ min ||θ - θ₀||²_F  s.t. θ stays aligned with each expert  │  │
│  │                                                           │  │
│  └───────────────────────────────────────────────────────────┘  │
│                                │                                │
│                                ▼                                │
│  Output: θ*                                                     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```
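The objective in the second box is written ||θ - θ₀||²_F, but the text does not define the subscript. Assuming it denotes a norm weighted by a diagonal Fisher F, the distance is Σⱼ Fⱼ (θⱼ - θ₀ⱼ)², so moving along high-Fisher coordinates costs more than an equal Euclidean step along low-Fisher ones. A minimal sketch under that assumption (the helper name is hypothetical):

```python
def fisher_sq_distance(theta, theta_0, fisher):
    """Diagonal-Fisher weighted squared distance: sum_j F_j * (theta_j - theta0_j)**2."""
    return sum(f * (t - t0) ** 2 for t, t0, f in zip(theta, theta_0, fisher))


theta_0 = [0.0, 0.0]
fisher = [10.0, 0.1]    # coordinate 0 matters much more to the model than coordinate 1
step_a = [0.5, 0.0]     # step of length 0.5 along the high-Fisher coordinate
step_b = [0.0, 0.5]     # step of the same length along the low-Fisher coordinate
print(fisher_sq_distance(step_a, theta_0, fisher))  # 2.5
print(fisher_sq_distance(step_b, theta_0, fisher))  # 0.025
```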
3. Technical Details
3.1 Computing Fisher Information
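Before the autograd version, the definition can be checked by hand on a toy model. The sketch assumes a hypothetical two-feature logistic model p(y=1|x,w) = σ(w·x), for which ∇_w log p(y|x,w) = (y - σ(w·x))·x. The diagonal empirical Fisher is then the mean of the squared per-sample gradients, the same recipe the class below applies with autograd.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def empirical_fisher_diag(w, data):
    """Diagonal empirical Fisher of a toy logistic model p(y=1 | x, w) = sigmoid(w·x).

    For y in {0, 1}, grad_w log p(y | x, w) = (y - sigmoid(w·x)) * x, so the
    diagonal Fisher is the mean of the squared per-sample gradients.
    """
    fisher = [0.0] * len(w)
    for x, y in data:
        p = sigmoid(sum(wj * xj for wj, xj in zip(w, x)))
        for j, xj in enumerate(x):
            g = (y - p) * xj    # per-sample gradient of the log-likelihood
            fisher[j] += g * g  # keep only the diagonal (squared gradient)
    return [f / len(data) for f in fisher]


w = [0.0, 0.0]  # at w = 0, sigmoid(w·x) = 0.5 for every x
data = [([1.0, 0.0], 1), ([1.0, 2.0], 0)]
print(empirical_fisher_diag(w, data))  # [0.25, 0.5]
```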
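```python
import torch
import torch.nn.functional as F


class FisherInformationCalculator:
    """
    Diagonal Fisher information calculator.
    """

    def __init__(self, model, dataloader):
        self.model = model
        self.dataloader = dataloader

    def compute_fisher(self, num_samples=1000):
        """
        Estimate the diagonal of the Fisher information matrix.

        Fisher information is defined as
            F = E[∇log p(y|x,θ) ∇log p(y|x,θ)^T]
        Only the diagonal (mean squared gradient of each parameter) is kept,
        using the observed labels (empirical Fisher).

        Assumes each batch is a dict with 'input_ids' and 'labels' of shape
        [B, T] (no padding) and that the model returns logits of shape [B, T, V].
        """
        fisher = {
            name: torch.zeros_like(param)
            for name, param in self.model.named_parameters()
        }
        sample_count = 0
        self.model.eval()
        for batch in self.dataloader:
            if sample_count >= num_samples:
                break
            # Forward pass
            logits = self.model(batch['input_ids'])
            log_probs = F.log_softmax(logits, dim=-1)
            labels = batch['labels']
            positions = torch.arange(labels.size(1), device=labels.device)
            for i in range(labels.size(0)):
                if sample_count >= num_samples:
                    break
                # Log-likelihood of the i-th sequence: sum_t log p(y_t | x, θ)
                log_prob = log_probs[i, positions, labels[i]].sum()
                # Backward pass; keep the graph for the remaining samples in the batch
                self.model.zero_grad()
                log_prob.backward(retain_graph=True)
                # Accumulate squared gradients
                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        fisher[name] += param.grad.detach() ** 2
                sample_count += 1
        # Normalize by the number of samples
        for name in fisher:
            fisher[name] /= max(sample_count, 1)
        return fisher
```
3.2 Fisher-Guided Merging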
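The core of this step can be previewed on flat vectors in plain Python: divide each task vector θᵢ - θ₀ by √Fᵢ, average the results, project each normalized direction onto that average, and multiply back by √Fᵢ. This is a toy reading of steps 1 to 3 of `align_merge` below, assuming the projection is normalized by ⟨a, a⟩ and mapped back to parameter space; lists stand in for tensors and the helper names are hypothetical.

```python
import math


def project(d, a, eps=1e-12):
    """Orthogonal projection of vector d onto the direction of a: (<d, a> / <a, a>) a."""
    dot = sum(di * ai for di, ai in zip(d, a))
    norm_sq = max(sum(ai * ai for ai in a), eps)
    return [dot / norm_sq * ai for ai in a]


def fisher_normalized_projection(theta_0, experts, fishers, eps=1e-8):
    """Toy version of steps 1-3 on flat vectors (lists stand in for tensors)."""
    # Step 1: Fisher-normalized task vectors (θᵢ - θ₀) / sqrt(Fᵢ)
    dirs = [
        [(t - t0) / max(math.sqrt(f), eps) for t, t0, f in zip(theta_i, theta_0, f_i)]
        for theta_i, f_i in zip(experts, fishers)
    ]
    # Step 2: average direction in the normalized space
    avg = [sum(col) / len(dirs) for col in zip(*dirs)]
    # Step 3: project onto the average, then multiply back by sqrt(Fᵢ)
    return [
        [p * max(math.sqrt(f), eps) for p, f in zip(project(d, avg), f_i)]
        for d, f_i in zip(dirs, fishers)
    ]


theta_0 = [0.0, 0.0]
experts = [[1.0, 1.0], [1.0, -1.0]]  # agree on coordinate 0, conflict on coordinate 1
fishers = [[1.0, 1.0], [1.0, 1.0]]
print(fisher_normalized_projection(theta_0, experts, fishers))  # [[1.0, 0.0], [1.0, 0.0]]
```

The two experts agree on the first coordinate and conflict on the second (+1 vs -1); the projection keeps the shared component and discards the conflicting one.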
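```python
import torch


def align_merge(pre_trained, experts, fishers, eps=1e-8):
    """
    AlignMerge: Fisher-guided model merging.

    Args:
        pre_trained: state dict of the pre-trained model θ₀
        experts: list of expert state dicts θ₁, ..., θₖ
        fishers: list of diagonal Fisher dicts F₁, ..., Fₖ (same keys and shapes)
    Returns:
        merged state dict θ*
    """
    # 1. Compute directions: task vectors normalized by sqrt(F)
    directions = []
    for expert, fisher in zip(experts, fishers):
        direction = {}
        for name in pre_trained.keys():
            # Increment (task vector)
            delta = expert[name] - pre_trained[name]
            # Fisher normalization
            fisher_norm = fisher[name].sqrt().clamp(min=eps)
            direction[name] = delta / fisher_norm
        directions.append(direction)

    # 2. Compute the average direction
    avg_direction = {}
    for name in pre_trained.keys():
        avg_direction[name] = sum(d[name] for d in directions) / len(directions)

    # 3. Alignment-preserving projection
    aligned_directions = []
    for direction, fisher in zip(directions, fishers):
        aligned = {}
        for name in pre_trained.keys():
            avg = avg_direction[name]
            # Row-wise projection onto the average direction: (<d, a> / <a, a>) a
            dot = (direction[name] * avg).sum(dim=-1, keepdim=True)
            norm_sq = (avg * avg).sum(dim=-1, keepdim=True).clamp(min=eps)
            projected = (dot / norm_sq) * avg
            # Map back from the Fisher-normalized space to parameter space
            aligned[name] = projected * fisher[name].sqrt().clamp(min=eps)
        aligned_directions.append(aligned)

    # 4. Fisher-weighted merge
    merged = {}
    for name in pre_trained.keys():
        # Normalizing constant: total Fisher mass of this tensor across experts
        total_fisher = sum(fisher[name].sum() for fisher in fishers).clamp(min=eps)
        # Weighted combination
        merged[name] = pre_trained[name].clone()
        for aligned, fisher in zip(aligned_directions, fishers):
            weight = fisher[name].sum() / total_fisher
            merged[name] += weight * aligned[name]
    return merged
```
4. Alignment Preservation Analysis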
4.1 Alignment Metric
Definition (alignment score):
4.2 Geometric Interpretation
```
┌─────────────────────────────────────────────────────────────────┐
│                Geometric view of Fisher guidance                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│                                θ* (merge target)                │
│                               ╱│╲                               │
│                              ╱ │ ╲                              │
│                             ╱  │  ╲                             │
│                            ╱   │   ╲                            │
│                           ╱    │    ╲                           │
│                   θ₁ ───►      │      ◄─── θ₂                   │
│                  (expert 1)    │    (expert 2)                  │
│                                │                                │
│                                θ₀                               │
│                          (pre-trained)                          │
│                                                                 │
│ Direction consistency: θᵢ - θ₀ points roughly along θ* - θ₀     │
│ Fisher guidance: Fisher information adjusts the merge direction │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```
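The "direction consistency" in the figure can be quantified with a cosine similarity between each task vector θᵢ - θ₀ and the merged update θ* - θ₀, optionally under a diagonal-Fisher inner product Σⱼ Fⱼ uⱼ vⱼ. This is one plausible measure chosen for illustration, not the alignment score of Section 4.1, whose definition is not given here; the function names are hypothetical.

```python
import math


def cosine(u, v, weights=None, eps=1e-12):
    """Cosine similarity; with `weights`, uses the inner product sum_j w_j * u_j * v_j."""
    w = weights if weights is not None else [1.0] * len(u)
    dot = sum(wj * a * b for wj, a, b in zip(w, u, v))
    norm_u = math.sqrt(sum(wj * a * a for wj, a in zip(w, u)))
    norm_v = math.sqrt(sum(wj * b * b for wj, b in zip(w, v)))
    return dot / max(norm_u * norm_v, eps)


def direction_consistency(theta_0, experts, theta_star, fisher=None):
    """Cosine between each task vector θᵢ - θ₀ and the merged update θ* - θ₀."""
    merged_dir = [s - t0 for s, t0 in zip(theta_star, theta_0)]
    return [
        cosine([e - t0 for e, t0 in zip(theta_i, theta_0)], merged_dir, fisher)
        for theta_i in experts
    ]


theta_0 = [0.0, 0.0]
experts = [[1.0, 0.0], [0.0, 1.0]]
theta_star = [0.5, 0.5]  # plain average of the two experts
print(direction_consistency(theta_0, experts, theta_star))  # both ≈ 0.707 (1/√2)
```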
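5. Experimental Results
5.1 Alignment Preservation
Alignment score with respect to the pre-trained model:
| Method | Alignment score ↑ | Perplexity ↓ |
|---|---|---|
| Simple averaging | 0.42 | 15.3 |
| TaskVectors | 0.38 | 16.1 |
| Fisher merging | 0.78 | 18.2 |
| AlignMerge | 0.85 | 14.8 |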
5.2 Task Performance
Multi-task evaluation:
| Method | Math | Coding | Dialogue | Average |
|---|---|---|---|---|
| Expert averaging | 72% | 68% | 65% | 68% |
| TaskVectors | 75% | 71% | 62% | 69% |
| Fisher merging | 70% | 65% | 70% | 68% |
| AlignMerge | 78% | 74% | 72% | 75% |
5.3 Ablation Study
| Components | Alignment score | Task performance |
|---|---|---|
| None | 0.45 | 65% |
| + Direction alignment | 0.72 | 71% |
| + Fisher weighting | 0.85 | 75% |
6. Code Implementation
6.1 Complete Implementation
```python
import torch
import torch.nn.functional as F


class AlignMergeModelMerging:
    """
    AlignMerge model merging.
    """

    def __init__(self, pre_trained_model):
        self.pre_trained = pre_trained_model
        # Snapshot of the pre-trained weights θ₀
        self.pre_state = {
            k: v.clone() for k, v in pre_trained_model.state_dict().items()
        }

    @staticmethod
    def compute_fisher_diag(model, dataloader, num_samples=1000):
        """
        Compute the diagonal Fisher information of `model`.

        Assumes each batch is a dict with 'input_ids' and 'labels' of shape
        [B, T] (no padding) and that model(input_ids) returns logits [B, T, V].
        """
        fisher = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
        }
        sample_count = 0
        model.eval()
        for batch in dataloader:
            if sample_count >= num_samples:
                break
            log_probs = F.log_softmax(model(batch['input_ids']), dim=-1)
            labels = batch['labels']
            positions = torch.arange(labels.size(1), device=labels.device)
            for i in range(labels.size(0)):
                if sample_count >= num_samples:
                    break
                # Log-likelihood of the i-th sequence under the model
                log_prob = log_probs[i, positions, labels[i]].sum()
                model.zero_grad()
                log_prob.backward(retain_graph=True)
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        fisher[name] += param.grad.detach() ** 2
                sample_count += 1
        for name in fisher:
            fisher[name] /= max(sample_count, 1)
        return fisher

    def merge(self, expert_models, expert_dataloaders, eps=1e-8):
        """
        Merge the expert models.

        Args:
            expert_models: list of expert models
            expert_dataloaders: one Fisher dataloader per expert
        Returns:
            merged state dict
        """
        # Diagonal Fisher of each expert, computed on that expert's data
        fishers = [
            self.compute_fisher_diag(model, dataloader)
            for model, dataloader in zip(expert_models, expert_dataloaders)
        ]
        # Merge trainable parameters only; buffers are copied from θ₀
        names = list(fishers[0].keys())

        # Task vectors δᵢ = θᵢ - θ₀
        directions = []
        for expert in expert_models:
            expert_state = expert.state_dict()
            directions.append(
                {name: expert_state[name] - self.pre_state[name] for name in names}
            )

        # Average direction
        avg_direction = {
            name: sum(d[name] for d in directions) / len(directions)
            for name in names
        }

        # Alignment-preserving projection
        aligned_directions = []
        for direction, fisher in zip(directions, fishers):
            aligned = {}
            for name in names:
                # Project onto the Fisher-normalized average direction
                fisher_sqrt = fisher[name].sqrt().clamp(min=eps)
                dir_normalized = direction[name] / fisher_sqrt
                avg_normalized = avg_direction[name] / fisher_sqrt
                # Projection coefficient <d, a> / <a, a> in the normalized space
                dot = (dir_normalized * avg_normalized).sum()
                norm_sq = (avg_normalized * avg_normalized).sum().clamp(min=eps)
                aligned[name] = (dot / norm_sq) * avg_direction[name]
            aligned_directions.append(aligned)

        # Fisher-weighted merge
        merged = {k: v.clone() for k, v in self.pre_state.items()}
        for name in names:
            total_fisher = sum(f[name].sum() for f in fishers).clamp(min=eps)
            for aligned, fisher in zip(aligned_directions, fishers):
                weight = fisher[name].sum() / total_fisher
                merged[name] += weight * aligned[name]
        return merged
```
7. Summary
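7.1 Main Contributions
- Fisher guidance: uses Fisher information to steer the merge direction
- Alignment preservation: keeps the merged model aligned with the pre-trained model
- Task performance: improves task performance while preserving alignment
7.2 Limitations
- Computational cost: Fisher information must be computed for every expert
- Diagonal approximation: uses the diagonal Fisher instead of the full matrix
- Task balancing: the balance among multiple tasks still needs tuning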
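The cost behind the diagonal approximation is easy to quantify. Assuming a hypothetical 7B-parameter model with fp32 entries (neither figure comes from the text), a diagonal Fisher needs one value per parameter, while a dense Fisher needs one per parameter pair:

```python
def fisher_memory_bytes(num_params, bytes_per_entry=4):
    """Storage for a diagonal vs. a dense Fisher matrix over num_params parameters."""
    diagonal = num_params * bytes_per_entry     # one entry per parameter
    full = num_params ** 2 * bytes_per_entry    # one entry per parameter pair
    return diagonal, full


diag, full = fisher_memory_bytes(7_000_000_000)  # hypothetical 7B model, fp32 entries
print(f"diagonal: {diag / 1e9:.0f} GB")  # diagonal: 28 GB
print(f"full: {full / 1e18:.0f} EB")     # full: 196 EB
```

Keeping the full matrix is therefore infeasible at LLM scale, and even the diagonal adds one model-sized tensor per expert.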