引言
标准 Softmax 自注意力的计算复杂度为 ,其中 为序列长度。这一二次复杂度严重限制了 Transformer 在高分辨率视觉任务(如高分辨率图像分割、检测)中的可扩展性。当输入分辨率从 提升到 (约 33.2 万 Token)时, 的计算代价使得标准注意力机制几乎不可用。
Infinite Self-Attention (InfSA)1 提出了一种根本性的重新定义:不再将注意力视为 Token 间的逐对相似度计算,而是将其建模为内容自适应 Token 图上的谱扩散过程(spectral diffusion process)。这种视角不仅解决了计算效率问题,更揭示了注意力机制与经典图论中**中心性度量(centrality measures)**之间的深层联系。
问题:二次复杂度的瓶颈
标准 Softmax 注意力
给定查询 、键 、值 ,标准自注意力的计算为:
其中 为每头维度。Softmax 操作需要对整个 注意力矩阵进行计算和存储。
高分辨率场景下的困境
| 分辨率 | Token 数 | 标准注意力计算量 | 可行性 |
|---|---|---|---|
| 196 | ✓ 正常 | ||
| 1024 | ⚠ 受限 | ||
| 4096 | ✗ 困难 | ||
| 33.2 万 | ✗ OOM |
核心挑战
- 内存爆炸:注意力矩阵 需要 存储
- 计算爆炸:矩阵乘法 需要 操作
- 梯度反向传播:训练时同样面临 的内存开销
InfSA 谱公式
从相似度到扩散
InfSA 的核心思想是放弃显式计算注意力矩阵,转而关注其谱性质(spectral properties)——即注意力矩阵的特征向量和特征值。
定义隐式注意力算子(implicit attention operator):
注意:这里 不经过 Softmax 归一化,是原始的相似度矩阵。
扩散过程的直觉
将 个 Token 视为图上的节点, 定义了节点间的边权重。信息在图上的传播遵循以下迭代规则:
其中:
- 为 步的隐状态( 维向量)
- 为衰减因子
- 为偏置向量(通常为全 1 向量)
谱扩散的数学定义
InfSA 的核心是求解以下不动点方程:
解析解为:
这就是 Neumann 级数(Neumann series) 的形式。
Neumann 级数:累积多跳交互
定义
Neumann 级数是无穷级数 的闭式求和。当谱半径 时(即 ),级数收敛。
多跳交互的解释
| 幂次 | 矩阵 | 解释 |
|---|---|---|
| 单位矩阵 | Token 的自连接 | |
| 一阶邻居 | 直接交互 | |
| 二阶邻居 | 朋友的朋友 | |
| 阶邻居 | 跳路径上的聚合 |
Neumann 级数 实际上是折扣累积:所有 跳交互的和,每跳乘以 折扣因子。
折扣因子的作用
折扣因子 平衡了:
- 局部性:较小的 强调近邻交互
- 全局性:较大的 允许信息传播到更远的节点
典型的 值在 之间。
计算实现
def infsa_neumann_series(S, b, alpha, num_terms=50):
"""
通过截断 Neumann 级数计算 InfSA 输出
Args:
S: 相似度矩阵 (N x N)
b: 偏置向量 (N,)
alpha: 折扣因子
num_terms: 截断项数
"""
h = b.clone()
power = b.clone()
for t in range(1, num_terms + 1):
power = S @ power # S^t * b
h = h + (alpha ** t) * power
return h问题:直接计算仍需要 存储 。这引出了 Linear-InfSA 的设计。
图中心性连接
InfSA 的谱公式揭示了自注意力与经典图论中**中心性度量(centrality measures)**之间的深刻联系。
1. 特征向量中心性(Eigenvector Centrality)
特征向量中心性定义节点 的重要性为其邻居重要性的加权和:
这等价于求解最大特征值对应的特征向量:
联系:当 时,InfSA 的不动点 投影到主特征向量方向上。
2. Katz 中心性
Katz 中心性是折扣路径计数的直接度量:
联系:InfSA 的 Neumann 级数形式上就是 Katz 中心性,只是通常取 (全 1 向量)。
3. PageRank
PageRank 是 Katz 中心性的随机游走解释:
其中 是阻尼因子。
联系:InfSA 可以视为 ** Personalized PageRank** 的变体,其中折扣因子 扮演阻尼因子的角色。
统一视角
| 方法 | 公式 | InfSA 联系 |
|---|---|---|
| 特征向量中心性 | ||
| Katz 中心性 | 完全匹配 | |
| PageRank | ,归一化版本 |
马尔可夫链解释
吸收马尔可夫链
InfSA 的 Neumann 核与**吸收马尔可夫链(absorbing Markov chain)**的基本矩阵密切相关。
考虑随机游走在 Token 图上:
- 转移概率:
- 吸收态:所有 Token 作为”吸收态”(通过偏置 持续注入概率)
基本矩阵
对于吸收马尔可夫链,**基本矩阵(fundamental matrix)**定义为:
其中 是转移概率矩阵。
中心性的概率解释
基本矩阵的第 行元素:
解释: 是从节点 出发,经过折扣随机游走,期望访问节点 的总次数。
因此,InfSA 中 Token 的中心性得分 可以解释为:
从节点 出发的随机游走,在被”吸收”前,期望访问所有节点的总次数。
直观理解
Token 图上的随机游走:
[A] ---0.8---> [B]
| |
0.5 0.6
v v
[C] ---0.7---> [D]
从 A 出发的随机游走:
- 直接访问 B、C、D:期望 0.5 次
- 访问 B 后再访问 C:期望 0.8×0.5 = 0.4 次
- 访问 C 后再访问 B:期望 0.5×0.8 = 0.4 次
- ...(所有路径的折扣累积)
A 的中心性 = 所有路径的折扣期望访问次数
Linear-InfSA:O(N) 线性变体
核心思想
Linear-InfSA 的关键创新是不显式构造注意力矩阵 ,而是通过固定大小的辅助状态直接追踪 Neumann 级数的不动点。
固定大小辅助状态
定义每个 Token 的辅助状态 ( 是每头维度),满足:
其中 是 Token 的键向量。
迭代形式
def linear_infsa_step(k, u, alpha=0.7):
"""
Linear-InfSA 单步迭代
Args:
k: 键向量 (d_h,)
u: 辅助状态 (d_h,)
alpha: 折扣因子
Returns:
u_new: 更新后的辅助状态
"""
# 缩放后的内积
k_norm = k / (k.norm() + 1e-8)
scale = alpha * (k_norm @ u)
# 折扣累积
u_new = scale * k_norm + k
return u_new复杂度分析
| 组件 | 标准注意力 | Linear-InfSA |
|---|---|---|
| 空间复杂度 | ||
| 时间复杂度(每步) | ||
| 辅助状态 | 无 | ,与 无关 |
| 矩阵存储 | 无需存储 |
伪代码
class LinearInfSA:
"""
Linear-InfSA 实现
关键特性:
- 固定大小辅助状态 O(d_h)
- 无需存储注意力矩阵
- 支持流式推理
"""
def __init__(self, d_model, num_heads, alpha=0.7):
self.d_h = d_model // num_heads
self.alpha = alpha
# QKV 投影
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 辅助状态(每头每 Token 一个)
# 实际上在 Linear-InfSA 中我们只保留一个全局状态
# 然后对每个 Token 维护其投影
def forward(self, x, state=None):
B, N, D = x.shape
# QKV 投影
Q = self.W_q(x) # (B, N, D)
K = self.W_k(x) # (B, N, D)
V = self.W_v(x) # (B, N, D)
# 辅助状态追踪(核心创新)
if state is None:
# 初始化:所有 Token 的状态为 0
state = torch.zeros(B, N, self.d_h, device=x.device)
# 收集加权值
output = []
for i in range(N):
# 更新状态
k_i = K[:, i, :] # (B, d_h)
u_i = state[:, i, :] # (B, d_h)
# Linear-InfSA 迭代
scale = self.alpha * (k_i * u_i).sum(dim=-1, keepdim=True) / (k_i.norm(dim=-1, keepdim=True) + 1e-8)
u_i_new = scale * k_i + k_i
state[:, i, :] = u_i_new
# 加权聚合
v_i = V[:, i, :]
output.append(v_i)
# 由于线性结构,可以向量化
# 这里展示原理,实践中使用高效实现
return torch.stack(output, dim=1), state训练稳定性
Linear-InfSA 在训练中表现出良好的稳定性:
- 梯度流动:辅助状态的递归定义确保梯度可以有效反向传播
- 谱归一化:折扣因子 保证数值稳定
- Drop-in 兼容:可以替换标准注意力层,无需修改训练流程
实验结果
ImageNet-1K 分类性能
| 模型 | 深度 | 参数量 | GFLOPs@224 | Top-1 Acc |
|---|---|---|---|---|
| Softmax ViT-S | 4 | 22M | 45 | 81.5% |
| Linear-InfSA ViT-S | 4 | 22M | 47 | 84.7% |
| Softmax ViT-B | 12 | 87M | 175 | 84.5% |
| Linear-InfSA ViT-B | 12 | 87M | 178 | 86.2% |
关键发现:+3.2% 的架构增益,源于更好的 Token 重要性建模。
高分辨率推理
| 分辨率 | 标准 ViT | Linear-InfSA ViT |
|---|---|---|
| ✓ 正常运行 | ✓ 正常运行 | |
| ✓ 正常运行 | ✓ 正常运行 | |
| ⚠ 显存不足 | ✓ 正常运行 | |
| ✗ OOM | ✓ 训练稳定 | |
| ✗ OOM | ✓ 唯一完成 |
效率对比(A100 40GB)
| 指标 | 标准 ViT-S | Linear-InfSA ViT-S | 提升 |
|---|---|---|---|
| 吞吐量 | 18 images/s | 231 images/s | 13× |
| 能耗 | 11.3 J/image | 0.87 J/image | 13× |
| 显存占用 | 36 GB | 2.8 GB | 13× |
分布偏移鲁棒性
| 数据集 | Softmax ViT-B | Linear-InfSA ViT-B |
|---|---|---|
| ImageNet-1K | 84.5% | 86.2% |
| ImageNet-V2 | 76.8% | 79.8% |
| 差距 | -7.7% | -6.4% |
发现:InfSA 变体在分布偏移下表现出更好的泛化能力,差距从 7.7% 缩小到 6.4%。
线性近似的准确性
Linear-InfSA 的近似与真实主特征向量的余弦相似度:
这表明 的辅助状态足以精确捕捉 注意力矩阵的谱性质。
与相关工作的联系
线性注意力
InfSA 与其他线性注意力方法(如 线性注意力)的区别在于理论基础:
| 方法 | 近似策略 | 理论基础 |
|---|---|---|
| Performer | 随机特征映射 | 核近似 |
| Linear Attention | 核函数重参数化 | 低秩近似 |
| InfSA | 谱扩散不动点 | 图论/马尔可夫链 |
低秩压缩
InfSA 与 注意力矩阵低秩压缩 共享秩压缩的直觉,但采取不同路径:
- 低秩压缩:显式对 进行 SVD 截断
- InfSA:通过谱不动点隐式捕获主特征结构
稀疏注意力
与 稀疏注意力 的对比:
| 特性 | 稀疏注意力 | InfSA |
|---|---|---|
| 复杂度 | 或 | |
| Token 连接 | 结构化/可学习的稀疏模式 | 全连接但折扣累积 |
| 可解释性 | 显式保留的连接 | 图中心性度量 |
实现要点
折扣因子选择
def compute_optimal_alpha(S, margin=0.1):
"""
计算保证收敛的最优折扣因子
谱半径 ρ(S) 决定了最大可行的 alpha
"""
# 幂迭代估计谱半径
eigenvalues = torch.linalg.eigvals(S)
spectral_radius = eigenvalues.abs().max()
return margin / spectral_radius与 Vision Transformer 的集成
class LinearInfSAViT(nn.Module):
"""
使用 Linear-InfSA 的 Vision Transformer
"""
def __init__(self, d_model, num_heads, num_layers, alpha=0.7):
super().__init__()
self.layers = nn.ModuleList([
InfSALayer(d_model, num_heads, alpha)
for _ in range(num_layers)
])
def forward(self, x):
# Patch embedding
x = self.patch_embed(x) # (B, N, D)
# 添加 CLS token
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat([cls_token, x], dim=1)
# 通过 InfSA 层
for layer in self.layers:
x = layer(x)
return x总结
Infinite Self-Attention 提供了一种根本性的注意力重新定义:
- 谱域视角:将注意力建模为内容自适应 Token 图上的扩散过程
- Neumann 级数:折扣累积多跳交互,从一阶邻居到无穷阶
- 图论连接:与特征向量中心性、Katz 中心性、PageRank 统一
- 马尔可夫解释:Token 中心性 = 随机游走期望访问次数
- 线性实现: 复杂度,固定 辅助状态
- 实证验证:13× 吞吐提升,稳定 分辨率,84.7% ImageNet
InfSA 的核心贡献不仅是效率提升,更重要的是揭示了注意力机制与经典图论之间的深层联系,为理解和改进 Transformer 架构提供了新的理论工具。
参考
相关阅读
- 稀疏注意力与长度外推
- 注意力矩阵低秩压缩与 KV Cache 优化
- 谱图神经网络 — 图谱方法的基础理论