1. 研究背景
1.1 模型融合的任务
模型融合旨在将多个微调模型合并为单一模型1:
# 模型融合示例
model1 = load_model('math_expert') # 数学专家
model2 = load_model('code_expert') # 编程专家
merged = merge(model1, model2) # 融合为全能模型1.2 现有方法的局限
| 方法 | 优势 | 问题 |
|---|---|---|
| 平均合并 | 简单 | 知识冲突 |
| TaskVectors | 保留任务 | 任务干扰 |
| Fisher加权 | 重要性感知 | 计算复杂 |
1.3 方向一致性的洞察
DC-Merge的核心发现1:
知识保留的关键在于参数更新的方向一致性
2. 技术框架
2.1 方向一致性问题
定义(方向一致性):设 是融合目标, 是各专家模型,则方向一致性定义为:
其中 是预训练参数。
2.2 核心算法
def dc_merge(models, pre_trained_model):
"""
DC-Merge: 方向一致性模型融合
Args:
models: 专家模型列表
pre_trained_model: 预训练模型
Returns:
merged_model: 融合后的模型
"""
# 1. 计算各模型的方向(相对于预训练)
directions = []
for model in models:
delta = get_model_delta(model, pre_trained_model)
directions.append(delta / delta.norm())
# 2. 计算平均方向
avg_direction = sum(directions) / len(directions)
avg_direction = avg_direction / avg_direction.norm()
# 3. 方向一致性加权
consistency_weights = []
for d in directions:
weight = torch.dot(d, avg_direction).item() # 方向一致性分数
consistency_weights.append(max(weight, 0)) # 去除负相关
# 4. 加权融合
weights = F.softmax(torch.tensor(consistency_weights), dim=0)
merged_state = {}
for key in models[0].state_dict().keys():
merged_state[key] = sum(
w * model.state_dict()[key]
for w, model in zip(weights, models)
)
merged_model = copy.deepcopy(pre_trained_model)
merged_model.load_state_dict(merged_state)
return merged_model3. 深入分析
3.1 方向一致性与知识保留
┌─────────────────────────────────────────────────────────────────────────┐
│ 方向一致性与知识保留 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 情况1: 方向一致 │
│ │
│ θ₀ ──────────────────► θ₁ │
│ ──────────────────► θ₂ │
│ ──────────────────► θ* │
│ │
│ cos(θ*-θ₀, θ₁-θ₀) ≈ 1, cos(θ*-θ₀, θ₂-θ₀) ≈ 1 │
│ → 融合效果好,保留两个专家的知识 │
│ │
│ 情况2: 方向冲突 │
│ │
│ θ₀ ──────────────────► θ₁ │
│ ◄────────────────── θ₂ │
│ │
│ cos(θ*-θ₀, θ₁-θ₀) ≈ 1, cos(θ*-θ₀, θ₂-θ₀) ≈ -1 │
│ → 融合效果差,需要处理冲突 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
3.2 数学推导
定理(方向一致性上界):融合模型的效用满足:
其中 是融合权重。
4. 代码实现
4.1 完整实现
class DCMerge:
"""
DC-Merge: 方向一致性模型融合
"""
def __init__(self, pre_trained_model):
self.pre_trained = pre_trained_model
self.pre_state = {k: v.clone() for k, v in pre_trained_model.state_dict().items()}
def merge(self, expert_models, temperature=1.0):
"""
融合多个专家模型
Args:
expert_models: 专家模型列表
temperature: 温度参数,控制权重分布
"""
# 计算方向
directions = []
for model in expert_models:
delta = {
k: (model.state_dict()[k] - self.pre_state[k]).flatten()
for k in model.state_dict().keys()
}
# 拼接所有参数方向
direction = torch.cat([d.flatten() for d in delta.values()])
directions.append(direction / direction.norm())
# 计算方向一致性
avg_direction = sum(directions) / len(directions)
avg_direction = avg_direction / avg_direction.norm()
# 计算一致性分数
consistency_scores = []
for d in directions:
score = torch.dot(d, avg_direction).item()
consistency_scores.append(score)
# Softmax权重
weights = F.softmax(
torch.tensor(consistency_scores) / temperature,
dim=0
)
# 加权合并
merged_state = {}
for key in self.pre_state.keys():
# 计算各模型的增量
deltas = []
for model in expert_models:
delta = model.state_dict()[key] - self.pre_state[key]
deltas.append(delta)
# 加权平均
merged_delta = sum(w * d for w, d in zip(weights, deltas))
merged_state[key] = self.pre_state[key] + merged_delta
# 应用合并
merged_model = copy.deepcopy(self.pre_trained)
merged_model.load_state_dict(merged_state)
return merged_model, weights.numpy()4.2 层级别融合
class LayerwiseDCMerge:
"""
层级别的DC-Merge
不同层使用不同的融合策略
"""
def __init__(self, pre_trained_model):
self.pre_trained = pre_trained_model
def merge_layerwise(self, expert_models):
"""
按层融合
"""
merged_state = {}
# 按层遍历
for name, param in self.pre_trained.named_parameters():
# 获取所有模型的该层参数
layer_deltas = []
for model in expert_models:
expert_param = model.state_dict()[name]
pre_param = self.pre_trained.state_dict()[name]
delta = expert_param - pre_param
layer_deltas.append(delta.flatten())
# 拼接方向
layer_directions = [d / d.norm() for d in layer_deltas]
# 计算该层的方向一致性
if len(layer_directions) > 1:
avg_dir = sum(layer_directions) / len(layer_directions)
avg_dir = avg_dir / avg_dir.norm()
scores = [torch.dot(d, avg_dir).item() for d in layer_directions]
weights = F.softmax(torch.tensor(scores), dim=0)
else:
weights = torch.tensor([1.0])
# 加权合并
merged_delta = sum(w * d for w, d in zip(weights, layer_deltas))
merged_param = self.pre_trained.state_dict()[name] + merged_delta.view_as(
self.pre_trained.state_dict()[name]
)
merged_state[name] = merged_param
return merged_state5. 实验结果
5.1 知识保留评估
专家能力测试:
| 模型 | 数学 | 编程 | 科学 | 平均 |
|---|---|---|---|---|
| Math-Expert | 92% | 45% | 58% | 65% |
| Code-Expert | 42% | 89% | 55% | 62% |
| Sci-Expert | 55% | 48% | 91% | 65% |
| 平均合并 | 58% | 55% | 62% | 58% |
| DC-Merge | 78% | 72% | 74% | 75% |
5.2 任务干扰分析
| 方法 | 任务间干扰 | 知识冲突率 |
|---|---|---|
| 平均合并 | 高 | 45% |
| TaskVectors | 中 | 28% |
| Fisher合并 | 低 | 15% |
| DC-Merge | 最低 | 8% |
6. 总结
6.1 主要贡献
- 方向一致性理论:建立融合质量与参数方向的关系
- 简单有效:无需复杂计算
- 知识保留:显著减少任务干扰
6.2 局限性
- 忽略幅度:只考虑方向,未考虑更新幅度
- 层级差异:假设所有层同等重要
- 在线学习:未考虑持续学习场景