引言

组合泛化(Compositional Generalization)是人类智力的核心特征之一:指从已知组件构建新表达的能力。形式上,给定有限的原语(primitives)集合,组合泛化使人类能够产生和理解无限数量的新组合。

对于Transformer语言模型,组合泛化意味着:模型能否将训练中见过的子任务/函数组合起来,处理训练中未出现的组合?这一问题对于构建真正理解语言和执行复杂推理的AI系统至关重要。1

组合泛化的形式化

组合结构定义

设原语集合为 ,每个原语是一个从输入到输出的映射。组合操作 允许组合这些原语:

组合泛化任务:给定训练分布 ,其中仅包含单个原语的应用和某些特定组合,测试分布 包含新组合的评估。

数据生成过程

研究者设计了可解释的合成数据生成过程来研究组合泛化:

# 伪代码:组合任务数据生成
def generate_compositional_task():
    # 原语集合
    primitives = [reverse, double, add_n, subtract_n]
    
    # 训练:部分组合
    train_combos = [
        [reverse], [double], [add_5],
        [reverse, double], [double, add_3]
    ]
    
    # 测试:新组合(未在训练中出现)
    test_combos = [
        [add_5, reverse],      # 新顺序
        [reverse, add_5, double],  # 更长组合
        [subtract_7, reverse]  # 新原语组合
    ]

自回归Transformer的组合能力

实验设置

研究者训练自回归Transformer模型,使用标准的next-token预测目标,在包含原语组合的序列上训练。1

关键发现

  1. 组合结构可学习:自回归Transformer能够从少量数据中学习组合结构,并泛化到指数级或组合级数量的新函数
  2. 指数级泛化:模型不仅能泛化见过的组合,还能泛化到数量远超训练样本的未见组合
  3. 组合效率差异:不同原语组合方式的效率差异显著

中间输出的作用

实验对比了两种生成策略:

策略描述组合泛化效果
无中间输出直接端到端预测最终结果较弱
生成中间输出在序列中插入每个原语的输出显著更强

机制解释:生成中间输出迫使模型显式地分解组合计算,每一步专注于执行单个原语。这种分解使得:

  1. 训练信号更清晰
  2. 错误更容易定位和修正
  3. 组合结构更易被模型”理解”

训练数据偏差的影响

组合顺序偏差

实验揭示了一个重要现象:训练数据中组合的顺序偏差会显著影响模型的组合泛化能力。

设训练分布中组合 出现的频率为 。研究发现:

  • 不平衡时,模型倾向于优先执行高频组合
  • 某些低频组合即使在测试时给出,也可能被忽略

偏差修正策略

为缓解组合顺序偏差,可以采用:

  1. 平衡采样:确保训练中各组合出现频率相近
  2. 课程学习:从简单组合逐渐引入复杂组合
  3. 数据增强:通过变换生成更多样的组合样本

注意力层与前馈层的角色

角色分工发现

通过分析Transformer的内部机制,研究者发现:

  • 注意力层(Attention Layers):负责选择要应用的”能力”
  • 前馈层(Feed-Forward Layers):负责执行选定的能力

机制解释

以一个简单的”加法和反转”组合为例:

输入: "add_5 to string, then reverse"

注意力层作用:
1. 识别需要执行的两个操作
2. 确定操作的相对顺序
3. 为每个操作分配计算资源

前馈层作用:
1. 实现"加5"的数值计算
2. 实现"反转"的字符串操作

电路分析

通过因果追踪(Causal Tracing)和路径修补(Path Patching)技术,研究者识别出执行组合任务的关键电路:

  1. 任务路由电路:识别需要执行的操作序列
  2. 操作选择电路:根据当前上下文选择正确操作
  3. 数据流电路:在操作间传递中间结果
┌─────────────────────────────────────────────────────────────┐
│                      任务路由电路                             │
│  输入 → 注意力头1 → 任务嵌入 → 注意力头2 → 操作选择信号        │
└─────────────────────────────────────────────────────────────┘
                            ↓
┌─────────────────────────────────────────────────────────────┐
│                      操作执行电路                             │
│  操作选择信号 → MLP层1 → 数值计算 → MLP层2 → 字符串操作       │
└─────────────────────────────────────────────────────────────┘

组合能力的可扩展性

样本效率

组合泛化任务的样本效率是一个关键问题。研究发现:

  • 简单组合:O(1) 样本复杂度(通过少量示例即可泛化)
  • 复杂组合:O(k) 样本复杂度,其中 k 是原语数量
  • 深层组合:O(2^d) 样本复杂度,d 是组合深度

与系统泛化的关系

组合泛化与系统泛化(Systematic Generalization)密切相关:

维度组合泛化系统泛化
定义新组合的泛化完整系统的泛化
关注点原语的重组语义结构的泛化
挑战组合爆炸结构一致性

实践应用

改进组合泛化的策略

基于研究发现,可以采用以下策略提升Transformer的组合泛化能力:

1. 中间输出监督

在训练时强制模型生成中间输出:

原始输出: [final_result]
改进输出: [intermediate_1, intermediate_2, ..., final_result]

2. 组合意识训练

  • 显式增加多样化组合的样本
  • 使用课程学习从简单到复杂
  • 引入对比学习区分正确/错误组合

3. 模块化架构

将Transformer分解为:

  • 原语执行器:独立执行每个原语
  • 组合器:学习如何组合原语执行器

代码示例:组合感知训练

class CompositionalLoss(nn.Module):
    """
    组合感知损失函数,鼓励模型学习组合结构
    """
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha  # 中间输出权重
        
    def forward(self, outputs, targets, intermediate_outputs=None):
        # 最终输出损失
        final_loss = F.cross_entropy(outputs, targets)
        
        # 如果有中间输出,增加组合损失
        if intermediate_outputs is not None:
            # 鼓励中间输出与预期中间结果一致
            compositional_loss = F.mse_loss(
                intermediate_outputs, 
                expected_intermediates
            )
            return final_loss + self.alpha * compositional_loss
        return final_loss

与其他泛化类型的比较

泛化类型定义Transformer表现挑战
分布内泛化同分布测试优秀最小
分布外泛化偏移分布中等中等
组合泛化新组合可变较高
系统泛化结构泛化较弱最高

总结

自回归Transformer的组合泛化研究揭示了几个关键发现:

  1. 能力存在:Transformer确实具有组合泛化能力,能够从有限训练中泛化到指数级新组合
  2. 中间输出关键:生成中间输出显著提升组合泛化性能
  3. 角色分工:注意力层负责选择,前馈层负责执行
  4. 偏差敏感:训练数据中的组合偏差显著影响泛化

这些发现为构建更强组合能力的语言模型提供了理论和实践指导。

参考

Footnotes

  1. Compositional Capabilities of Autoregressive Transformers: A Study on Synthetic, Interpretable Tasks. ICML 2024 2