Task Arithmetic:任务向量算术

1. 核心思想

Task Arithmetic 1 提出了一种优雅的模型合并框架:利用任务向量(Task Vector)来表示和操作模型在特定任务上学到的知识。

1.1 任务向量定义

给定预训练模型参数 和微调模型参数 ,任务向量定义为:

这个差值向量编码了模型在目标任务上学到的知识。

2. 任务向量操作

2.1 加法:合并任务

合并多个任务的向量:

其中 是任务 的权重。

// 任务向量合并
torch::Tensor merge_tasks(
    torch::Tensor pretrain,
    std::vector<torch::Tensor> task_vectors,
    std::vector<float> alphas
) {
    torch::Tensor merged = pretrain.clone();
    for (int i = 0; i < task_vectors.size(); i++) {
        merged += alphas[i] * task_vectors[i];
    }
    return merged;
}

2.2 减法:任务遗忘

通过减去任务向量实现「遗忘」:

这在需要移除有害行为或保护隐私时非常有用。

2.3 缩放:控制强度

通过标量控制任务向量的影响强度:

  • :增强任务能力
  • :温和调整
  • :反向任务

3. 任务冲突与解决

3.1 冲突类型

不同任务向量之间可能存在冲突:

  1. 符号冲突 对同一参数有相反方向的要求
  2. 幅度冲突:不同任务需要不同的调整幅度

3.2 解决方案

方案一:逐元素符号检查

方案二:加权投票

根据各任务的重要性或相关性分配权重:

4. 扩展操作

4.1 任务向量内积

利用内积分析任务间的相关性:

高内积表示任务相关,可以更好地合并;低/负内积表示冲突风险。

4.2 任务向量子空间

研究任务向量是否位于相似的子空间:

如果多个任务共享子空间,它们更容易被有效合并。

5. 与其他方法的结合

5.1 Task Arithmetic + TIES

TIES-Merging 可以视为 Task Arithmetic 的增强版,添加了符号对齐和修剪步骤。

5.2 Task Arithmetic + DARE

DARE 通过稀疏化任务向量来减少冲突,与 Task Arithmetic 正交。

def dare_task_arithmetic(pretrain, task_vectors, alphas, threshold=0.8):
    """DARE增强的Task Arithmetic"""
    merged = pretrain.clone()
    
    for tv, alpha in zip(task_vectors, alphas):
        # 稀疏化:保留最大threshold比例的参数
        flat_tv = (alpha * tv).flatten()
        k = int(len(flat_tv) * threshold)
        indices = torch.argsort(flat_tv.abs())[-k:]
        
        mask = torch.zeros_like(flat_tv)
        mask[indices] = 1
        
        # Rescale
        sparse_tv = flat_tv * mask / threshold
        merged += sparse_tv.reshape_as(merged)
    
    return merged

6. 理论分析

6.1 线性假设

Task Arithmetic 假设任务向量在参数空间中形成线性结构:

这一假设在实践中表现良好,但并非严格成立。

6.2 泛化保证

基于PAC-Bayes理论,可以给出任务向量合并的泛化界:2

7. 应用场景

  1. 多任务学习:合并不同任务的专业能力
  2. 持续学习:将新任务添加到已有模型
  3. 安全对齐:移除有害行为同时保留有用能力
  4. 领域适应:添加领域特定知识

8. 参考资料

Footnotes

  1. Ilharco, G., et al. (2022). Editing models with task arithmetic. ICLR 2023.

  2. Daras, G., et al. (2024). Task Arithmetic for Model Editing. arXiv:2406.11385.