引言
组合泛化(Compositional Generalization)是人类智力的核心特征之一:指从已知组件构建新表达的能力。形式上,给定有限的原语(primitives)集合,组合泛化使人类能够产生和理解无限数量的新组合。
对于Transformer语言模型,组合泛化意味着:模型能否将训练中见过的子任务/函数组合起来,处理训练中未出现的组合?这一问题对于构建真正理解语言和执行复杂推理的AI系统至关重要。1
组合泛化的形式化
组合结构定义
设原语集合为 ,每个原语是一个从输入到输出的映射。组合操作 允许组合这些原语:
组合泛化任务:给定训练分布 ,其中仅包含单个原语的应用和某些特定组合,测试分布 包含新组合的评估。
数据生成过程
研究者设计了可解释的合成数据生成过程来研究组合泛化:
# 伪代码:组合任务数据生成
def generate_compositional_task():
# 原语集合
primitives = [reverse, double, add_n, subtract_n]
# 训练:部分组合
train_combos = [
[reverse], [double], [add_5],
[reverse, double], [double, add_3]
]
# 测试:新组合(未在训练中出现)
test_combos = [
[add_5, reverse], # 新顺序
[reverse, add_5, double], # 更长组合
[subtract_7, reverse] # 新原语组合
]自回归Transformer的组合能力
实验设置
研究者训练自回归Transformer模型,使用标准的next-token预测目标,在包含原语组合的序列上训练。1
关键发现:
- 组合结构可学习:自回归Transformer能够从少量数据中学习组合结构,并泛化到指数级或组合级数量的新函数
- 指数级泛化:模型不仅能泛化见过的组合,还能泛化到数量远超训练样本的未见组合
- 组合效率差异:不同原语组合方式的效率差异显著
中间输出的作用
实验对比了两种生成策略:
| 策略 | 描述 | 组合泛化效果 |
|---|---|---|
| 无中间输出 | 直接端到端预测最终结果 | 较弱 |
| 生成中间输出 | 在序列中插入每个原语的输出 | 显著更强 |
机制解释:生成中间输出迫使模型显式地分解组合计算,每一步专注于执行单个原语。这种分解使得:
- 训练信号更清晰
- 错误更容易定位和修正
- 组合结构更易被模型”理解”
训练数据偏差的影响
组合顺序偏差
实验揭示了一个重要现象:训练数据中组合的顺序偏差会显著影响模型的组合泛化能力。
设训练分布中组合 出现的频率为 。研究发现:
- 当 与 不平衡时,模型倾向于优先执行高频组合
- 某些低频组合即使在测试时给出,也可能被忽略
偏差修正策略
为缓解组合顺序偏差,可以采用:
- 平衡采样:确保训练中各组合出现频率相近
- 课程学习:从简单组合逐渐引入复杂组合
- 数据增强:通过变换生成更多样的组合样本
注意力层与前馈层的角色
角色分工发现
通过分析Transformer的内部机制,研究者发现:
- 注意力层(Attention Layers):负责选择要应用的”能力”
- 前馈层(Feed-Forward Layers):负责执行选定的能力
机制解释
以一个简单的”加法和反转”组合为例:
输入: "add_5 to string, then reverse"
注意力层作用:
1. 识别需要执行的两个操作
2. 确定操作的相对顺序
3. 为每个操作分配计算资源
前馈层作用:
1. 实现"加5"的数值计算
2. 实现"反转"的字符串操作
电路分析
通过因果追踪(Causal Tracing)和路径修补(Path Patching)技术,研究者识别出执行组合任务的关键电路:
- 任务路由电路:识别需要执行的操作序列
- 操作选择电路:根据当前上下文选择正确操作
- 数据流电路:在操作间传递中间结果
┌─────────────────────────────────────────────────────────────┐
│ 任务路由电路 │
│ 输入 → 注意力头1 → 任务嵌入 → 注意力头2 → 操作选择信号 │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ 操作执行电路 │
│ 操作选择信号 → MLP层1 → 数值计算 → MLP层2 → 字符串操作 │
└─────────────────────────────────────────────────────────────┘
组合能力的可扩展性
样本效率
组合泛化任务的样本效率是一个关键问题。研究发现:
- 简单组合:O(1) 样本复杂度(通过少量示例即可泛化)
- 复杂组合:O(k) 样本复杂度,其中 k 是原语数量
- 深层组合:O(2^d) 样本复杂度,d 是组合深度
与系统泛化的关系
组合泛化与系统泛化(Systematic Generalization)密切相关:
| 维度 | 组合泛化 | 系统泛化 |
|---|---|---|
| 定义 | 新组合的泛化 | 完整系统的泛化 |
| 关注点 | 原语的重组 | 语义结构的泛化 |
| 挑战 | 组合爆炸 | 结构一致性 |
实践应用
改进组合泛化的策略
基于研究发现,可以采用以下策略提升Transformer的组合泛化能力:
1. 中间输出监督
在训练时强制模型生成中间输出:
原始输出: [final_result]
改进输出: [intermediate_1, intermediate_2, ..., final_result]
2. 组合意识训练
- 显式增加多样化组合的样本
- 使用课程学习从简单到复杂
- 引入对比学习区分正确/错误组合
3. 模块化架构
将Transformer分解为:
- 原语执行器:独立执行每个原语
- 组合器:学习如何组合原语执行器
代码示例:组合感知训练
class CompositionalLoss(nn.Module):
"""
组合感知损失函数,鼓励模型学习组合结构
"""
def __init__(self, alpha=0.5):
super().__init__()
self.alpha = alpha # 中间输出权重
def forward(self, outputs, targets, intermediate_outputs=None):
# 最终输出损失
final_loss = F.cross_entropy(outputs, targets)
# 如果有中间输出,增加组合损失
if intermediate_outputs is not None:
# 鼓励中间输出与预期中间结果一致
compositional_loss = F.mse_loss(
intermediate_outputs,
expected_intermediates
)
return final_loss + self.alpha * compositional_loss
return final_loss与其他泛化类型的比较
| 泛化类型 | 定义 | Transformer表现 | 挑战 |
|---|---|---|---|
| 分布内泛化 | 同分布测试 | 优秀 | 最小 |
| 分布外泛化 | 偏移分布 | 中等 | 中等 |
| 组合泛化 | 新组合 | 可变 | 较高 |
| 系统泛化 | 结构泛化 | 较弱 | 最高 |
总结
自回归Transformer的组合泛化研究揭示了几个关键发现:
- 能力存在:Transformer确实具有组合泛化能力,能够从有限训练中泛化到指数级新组合
- 中间输出关键:生成中间输出显著提升组合泛化性能
- 角色分工:注意力层负责选择,前馈层负责执行
- 偏差敏感:训练数据中的组合偏差显著影响泛化
这些发现为构建更强组合能力的语言模型提供了理论和实践指导。