概述
Transformer架构已成为深度学习的基石,但其**表达能力(Expressivity)**的理论边界长期缺乏完整理解。本文系统性地分析注意力机制的计算能力,回答一个核心问题:Transformer能计算什么?不能计算什么?123
主要内容包括:
- 注意力机制的计算模型形式化
- 表达能力下界:Transformer能做什么
- 表达能力上界:Transformer的限制
- 深度与宽度的权衡
- 与其他计算模型(电路复杂度、形式语言)的联系
计算模型形式化
Transformer层作为计算单元
单层Transformer可以形式化为:
其中Attention操作:
Turing完备性讨论
关键问题:多层Transformer是否Turing完备?
结论:是的,但需要特定条件。
定理1(Turing完备性):
具有以下组件的多层Transformer是Turing完备的:
- 足够深度(线性增长于输入规模)
- 任意精度的数值表示
- 适当的位置编码
注意:标准浮点精度下,多层Transformer的Turing完备性是一个开放问题。
表达能力下界
什么是Transformer可以计算的?
正则语言识别
定理2(正则语言识别):
固定深度的Transformer可以识别所有正则语言。
具体地:
- 深度:识别星号-闭包语言(如 )
- 深度:识别所有正则语言
# 示例:识别 a^n b^n 的Transformer伪代码
class A_N_B_N_Recognizer(torch.nn.Module):
"""
识别语言 {a^n b^n | n >= 0}
使用注意力追踪a的数量,然后验证b的数量匹配
"""
def __init__(self, d_model=128, n_heads=4):
super().__init__()
self.attention = MultiHeadAttention(d_model, n_heads)
self.ffn = torch.nn.Sequential(
torch.nn.Linear(d_model, d_model * 4),
torch.nn.GELU(),
torch.nn.Linear(d_model * 4, 2) # accept/reject
)
def forward(self, x, mask=None):
# x: (batch, seq_len, d_model) 其中 x[i] 编码 token a 或 b
# 使用注意力追踪 a 的位置
# ...
# 理论保证:存在一个深度为 O(1) 的Transformer识别此语言算术运算
定理3(加法运算):
深度的Transformer可以计算两个 位整数的加法。1
定理4(乘法运算):
深度的Transformer可以计算两个 位整数的乘法。
计数能力
定理5(精确计数):
深度的Transformer可以精确计数序列中某类token的数量。
def counting_attention(query_pattern, keys):
"""
计数注意力模式
query: 询问计数
keys: 输入token
输出: 计数结果
"""
# 计算每个key与查询的相似度
similarity = torch.matmul(query_pattern, keys.T)
# Softmax归一化
attention = F.softmax(similarity, dim=-1)
# 加权和 = 计数
count = torch.sum(attention, dim=-1)
return count表达能力上界
Transformer的限制
定理6(串行化限制)
标准Transformer无法高效计算某些并行可计算的问题。
例如:
- 某些需要在 深度内完成的问题
- 需要非均匀信息传递的问题
电路复杂度视角
NC¹ vs TC⁰
| 复杂度类 | 定义 | Transformer表达能力 |
|---|---|---|
| AC⁰ | 常数深度、无限制扇入 | 部分满足(需要特殊设计) |
| NC¹ | 对数深度、多项式大小 | 固定深度Transformer无法达到 |
| TC⁰ | 阈值电路 | 更强于标准Transformer |
| AC⁰[⊕] | 常数深度、异或门 | 某些Transformer可达 |
定理7(深度限制):
深度的Transformer的表达能力等价于某个非均匀AC⁰电路。
正则语言的理论限制
定理8(限制):
固定深度Transformer无法区分某些正则语言对。
例如:
- vs
- 需要跨距离的精确计数,超出固定深度Transformer能力
位置编码依赖性
定理9(位置编码重要性):
注意力的表达能力严重依赖位置编码。
- 无位置编码:无法区分序列顺序(排列不变)
- 绝对位置编码:表达能力受限于编码容量
- 相对位置编码:可建模更复杂的序列关系
深度-宽度权衡
表达能力与资源权衡
定理10(深度-宽度权衡):
对于识别长度为 的序列上的任意布尔函数:
- 深度 与 宽度 满足权衡关系
- 增加深度可减少所需宽度,反之亦然
最优深度分析
def compute_optimal_depth(n, circuit_class='NC1'):
"""
计算识别n位输入的最优深度
假设使用d层Transformer,每层可计算AC^0电路
理论结果:
- AC^0: d = O(1)
- NC^1: d = O(log n)
- TC^0: d = O(log n)
"""
if circuit_class == 'AC0':
return float('inf') # AC^0 无法识别所有NC^1函数
elif circuit_class == 'NC1':
return int(math.log2(n)) + 1 # 对数深度
elif circuit_class == 'TC0':
return int(math.log2(n)) + 1
else:
return None宽度受限的深度下界
定理11(深度下界):
若Transformer宽度受限(如 ),则:
才能识别所有正则语言。
与电路复杂度的联系
统一框架
Transformer的计算能力可以用电路复杂度来刻画:
| Transformer配置 | 电路复杂度类 | 能计算的问题 |
|---|---|---|
| 深度 | AC⁰ | 简单模式匹配 |
| 深度+非线性 | TC⁰子集 | 计数、多数函数 |
| 深度 | NC¹ | 算术运算、树结构 |
| 深度 | AC¹ | 更复杂计算 |
证明技术
电路复杂度方法的核心:
- 随机 Restrictions:通过随机删除输入变量,证明深度下界
- 通信复杂性:将电路下界转化为通信复杂性下界
- 多项式方法:利用代数方法证明下界
与形式语言的联系
Chomsky层级中的表达能力
| 语言类 | 自动机 | Transformer深度需求 |
|---|---|---|
| 正则语言 | DFA/NFA | |
| 上下文无关 | PDA | |
| 上下文相关 | LBA | |
| 递归可枚举 | 图灵机 | 无限 |
形式语言识别的深度需求
class LanguageComplexityAnalyzer:
"""
分析语言复杂度与Transformer深度需求
"""
COMPLEXITY_MAP = {
'regular': 1, # O(1) 深度
'star-free': 1, # O(1) 深度
'odd-length': 1, # O(1) 深度
'a^n b^n': 1, # O(1) 深度(通过计数)
'a^n b^n c^n': 2, # 需要 O(log n) 深度
'palindrome': 2, # 需要 O(log n) 深度
'balanced-parens': 2, # 需要 O(log n) 深度
}
@classmethod
def get_minimum_depth(cls, language_type):
return cls.COMPLEXITY_MAP.get(language_type, None)实际应用:表达能力工程
设计原则
-
任务-架构匹配:
- 简单模式匹配 → 浅层Transformer
- 复杂计数/算术 → 深层Transformer
-
归纳偏置注入:
- 结构先验 → 减少所需深度
- 位置编码 → 捕获序列依赖
-
模块化设计:
- 专用注意力头处理特定模式
- 分层抽象减少总深度需求
实践建议
class ExpressivityAwareTransformer(torch.nn.Module):
"""
表达能力感知的Transformer设计
"""
def __init__(self, task_complexity, d_model=512, n_heads=8):
super().__init__()
# 根据任务复杂度选择深度
if task_complexity == 'low':
n_layers = 4
elif task_complexity == 'medium':
n_layers = 12
elif task_complexity == 'high':
n_layers = 24
else:
raise ValueError(f"Unknown complexity: {task_complexity}")
self.layers = torch.nn.ModuleList([
TransformerLayer(d_model, n_heads)
for _ in range(n_layers)
])
# 根据任务添加特定归纳偏置
if task_complexity in ['medium', 'high']:
self.add_structural_bias()
def add_structural_bias(self):
# 添加结构化归纳偏置(如相对位置编码)
pass开放问题
- 固定精度Turing完备性:标准浮点精度下的Turing完备性仍未证明
- 精确电路等价:Transformer与电路类的精确对应关系
- 训练动态与表达能力:训练是否能达到理论表达上限
- 高效近似:如何在减少资源消耗的同时保持表达能力