Circular-Convolutional Attention(CAT)
核心论文:Yoshihiro Yamada. CAT: Circular-Convolutional Attention for Sub-Quadratic Transformers. NeurIPS 2025.
arXiv:2504.06704 | OpenReview:q7hoTSbV1t
作者机构:Preferred Networks, Tokyo, Japan
CAT(Circular-convolutional ATtention,循环卷积注意力)是一种基于快速傅里叶变换(FFT)的次二次(sub-quadratic)自注意力机制。它通过将注意力图重新表述为循环矩阵(circulant matrix),并利用频域卷积定理将计算复杂度从标准注意力的 降低至 ,同时完整保留全局 softmax 加权结构。本文将系统介绍其数学基础、推导细节、PyTorch 实现、实验结果与在 各类注意力变体 中的定位。
1. 概述与动机
1.1 标准自注意力的 瓶颈
给定输入序列 ( 为序列长度, 为特征维度),标准自注意力(Self-Attention)执行:
其中 是 query/key 维度,softmax 沿行归一化,使得 的注意力矩阵每行求和为 。1
注意力的计算瓶颈在两个地方:
- 空间复杂度: 的注意力矩阵 占用 内存。
- 时间复杂度:两次矩阵乘法( 与 )均需 FLOPs。
当 达到 8K、32K 甚至更长(如长文档、视频、基因组)时, 在显存和算力上都成为不可承受的负担。这一瓶颈已被多种应用广泛证实。2
1.2 现有次二次注意力方法谱系
为了打破 屏障,研究者从不同角度切入,形成了三类主流范式:
| 范式 | 代表方法 | 核心思路 | 复杂度 | 代价 |
|---|---|---|---|---|
| 稀疏注意力 | Longformer、BigBird | 仅在局部窗口 + 少量全局 token 上计算 | 等 | 引入块大小等超参;破坏全局加权 |
| 核近似 / 线性注意力 | Linear Attention、Performer (FAVOR+) | 用核函数 把 softmax 拆为 | 丢失精确 softmax 形式,训练不稳定 | |
| 状态空间 / RNN 类 | S4、Mamba | 用结构化状态空间替换注意力 | 与注意力机制解耦,需重新设计训练栈 |
每一类都做出了自己的妥协:稀疏方法丢失全局性,核方法丢失 softmax 形式,SSM 类则与注意力彻底分道扬镳。CAT 提出了一个关键问题:
能否在保留全局 softmax 加权结构的同时,把复杂度压到 ?
1.3 CAT 的核心洞察
CAT 的答案源于一个反直觉的观察:标准的 注意力可以等价改写为一种圆卷积,只要 和 是同一序列的不同线性投影。进一步,若我们只用一个标量投影 (输出维度为 1)来产生循环注意力核,则整个注意力矩阵成为一个循环矩阵(circulant matrix):
而循环矩阵与向量的乘法正是一维圆卷积,由圆卷积定理可知可在 时间内通过 FFT/IFFT 完成。这便是 CAT 的核心数学洞察。3
下图为论文中从标准注意力到 CAT 的两种实现的对照:
┌─────────────────────┐ ┌──────────────────────┐ ┌──────────────────────┐
│ Self-Attention │ │ CAT (Gather) │ │ CAT (FFT) │
│ Q,K,V Linear │ │ Linear W_A, W_V │ │ Linear W_A, W_V │
│ Matmul QK^T │ => │ Softmax(Z) │ => │ FFT(Z*), FFT(V) │
│ Softmax │ │ Roll V via gather │ │ Hadamard 乘 │
│ Matmul · V │ │ (O(N^2) 但 cache 友好) │ │ IFFT │
│ ────────── │ │ │ │ O(N log N) │
│ NxN 注意力图 │ │ N 长度循环核 │ │ N 长度循环核 │
└─────────────────────┘ └──────────────────────┘ └──────────────────────┘
1.4 EIT 框架:工程同构 Transformer
CAT 论文还提出一个更上层的概念——Engineering-Isomorphic Transformers(EITs,工程同构 Transformer):满足以下四个条件的次二次注意力机制:
- Softmax 保持:输出可写成 ,与标准注意力形式同构。
- 次二次复杂度:严格小于 。
- 参数效率:可学习参数量与标准多头注意力相当或更少。
- 极简超参:不引入随序列长度变化的超参(如块大小、稀疏模式)。
CAT 是 EIT 框架下的一种具体实现,保留 softmax 同时复杂度降到 。4
2. 数学基础:循环卷积
2.1 离散圆卷积的定义
设 为两个长度为 的序列,离散圆卷积(circular convolution)定义为:
关键点是模运算 :这使得序列首尾相连、形成环面(torus)拓扑。下标越界时自动回绕到序列开头。
可以等价地写成矩阵-向量乘法形式:
其中 是由 的循环移位生成的 循环矩阵。
2.2 圆卷积定理(Circular Convolution Theorem)
圆卷积定理是 CAT 算法的理论引擎。它陈述:时域的圆卷积等于频域的逐元素乘积。
设 表示离散傅里叶变换(DFT), 为其逆变换。则:
其中 为 Hadamard(逐元素)乘积。等价地:
复杂度对比:
- 直接计算: 个点中每个需 次乘加,。
- FFT 实现:DFT/IDFT 各 ,加上逐元素乘积 ,总计 。
Cooley–Tukey FFT 在 时达到 ,是经典的高效算法。5
2.3 DFT 矩阵与正交性
点 DFT 的矩阵形式是:
是酉矩阵(差一个 归一化因子),。这意味着循环卷积对应的是正交基下的乘法,是数值稳定的线性变换。
更一般地,循环矩阵 可被 对角化:
这是后续在频域高效计算的关键代数恒等式。
2.4 与 群卷积的关系
循环卷积在群论视角下是有限循环群 上的卷积:
这一定义与上述离散形式完全一致。当 时,DFT 把这个群代数 分解为不可约表示(一维特征标 )的直和,圆卷积定理正是这一分解的代数表述。
2.5 与标准(线性)卷积的关系
圆卷积与线性卷积的关系是:
其中 在末尾零填充至长度 。如果序列在两端不”卷起来”,则需要零填充以避免回绕造成的伪影(见 §4.4 边界处理)。
3. CAT 的核心思想
3.1 Self-Attention 的循环重表述
回到自注意力。设 ,CAT 用一个单一的标量投影 把输入映射到一个 维向量:
再施加逐行(实为逐元素)softmax:
关键观察: 被视作一个循环矩阵 的第一行:
而 恰好是 与每一列 的一维圆卷积的批量形式。
3.2 QK^T 与循环卷积的等价性
为什么能用循环矩阵替代 ?我们可以从两个角度理解:
角度 1: 的特例
若 (即合并 query 和 key 投影),则 的第 元素为:
这是一个对称的双线性形式,但本身仍不是循环卷积。
角度 2:参数共享后的显式循环化
当我们只用一个标量投影 时,“QK^T”被替换为单标量 ,注意力矩阵变成 ——每一行都是第一行的循环移位,这就是循环矩阵。
论文中的数学描述十分精确:
Specifically, we learn a single projection matrix to map the input into a vector . We then apply a row-wise softmax, yielding . This serves as the first row of a circulant matrix , thus representing global pairwise interactions.3
与原始 attention 的等价性条件:
- 当序列本身是周期性的(或可视为环形数据,如长视频、循环语料)时,圆卷积准确。
- 当序列是有限线性(如一段文字)时,需要在序列两端做适当处理,详见 §4.4。
3.3 softmax--V 的频域高效计算
将 softmax 后的循环注意力矩阵乘以 ,结合圆卷积定理得到:
算法伪代码:
Algorithm CAT_forward(X, W_A, W_V)
Input: X ∈ R^{N×D}, W_A ∈ R^{D×1}, W_V ∈ R^{D×D}
Output: O ∈ R^{N×D}
1. Z ← X @ W_A // [N, 1]
2. Z* ← softmax(Z, dim=0) // [N, 1], row-wise softmax
3. V ← X @ W_V // [N, D]
4. Ỹ ← FFT(V, dim=0) // [N, D] (复数)
5. Ẑ ← FFT(Z*, dim=0) // [N] (复数)
6. Ŷ ← Ẑ ⊙ Ỹ // [N, D] (复数, Hadamard)
7. O ← IFFT(Ŷ, dim=0).real // [N, D]
8. return O复杂度分析(对长度 维度):
| 操作 | 复杂度 | 备注 |
|---|---|---|
Z = X @ W_A | 仅 投影 | |
softmax(Z) | ||
V = X @ W_V | 标准值投影 | |
FFT(V) | 每列独立 FFT | |
FFT(Z*) | ||
Hadamard | ||
IFFT | ||
| 总 | 对 是次二次 |
相比标准注意力的 ,CAT 把对 的依赖从 降到 。
3.4 多头扩展
CAT 天然支持多头注意力:每头维护独立的 , 个头的输出拼接后过输出投影:
其中 。这与标准 MHA 形式完全一致,且每头都是 。
4. 完整数学推导
4.1 从标准注意力到 CAT
设 。标准多头注意力的单头输出为:
步骤 1:合并 query 与 key。
令 ,并定义:
论文中称此为 Averaged-Key 参数化。我们以更简洁的合并形式()表达。
步骤 2:循环化。
将 视作循环矩阵 的生成向量。形式上:
步骤 3:频域化。
利用圆卷积定理:
记 ,,其中 沿序列维度(dim=0)作用。
4.2 FFT/IFFT 细节
在 PyTorch 中实数序列的 FFT 形式为:
torch.fft.rfft(x, dim=0):实数输入,输出 个复数(利用共轭对称性,节省一半存储)。torch.fft.irfft(y, n=N, dim=0):从 个复数恢复 个实数。
对于复数乘法,PyTorch 自动广播。关键的精度保证是:实数输入 + 圆卷积 + 复数 FFT → 频域乘积在数值上与实数域卷积结果在浮点精度内一致(无离散化误差)。
4.3 数值稳定性与 NaN 防护
实现中需要处理以下数值问题:
- Softmax 溢出:当 很大时 会溢出。标准做法是减去最大值:
- FFT 中的 与 NaN:若输入含 NaN,FFT 会传播 NaN。务必在前向计算前做
torch.nan_to_num或断言。 - rfft 输入为复数时的实部/虚部拼接:论文实现中常将 与 视为”复数”形式以减少一次 IFFT。一种实现是把 沿最后一维复制到 的”实部”,或者直接对 单独做 rfft。
- 复数张量类型:使用
torch.complex64(与float32配套)或torch.complex128。在 GPU 上complex64的实部/虚部分别占用float32内存,但 cuFFT 需要专门的 kernel。
4.4 边界处理
循环卷积假设序列在两端首尾相连。当处理线性序列时,需要补零(zero-padding):
- 将 与 各填充至 长度:、。
- FFT 后做 Hadamard 乘积,IFFT 取前 项即为线性卷积(无回绕)。
但这会使长度变 ,FFT 成本增至 。实践中,论文作者发现直接圆卷积(不做补零)反而效果更好——可能因为:
- 真实序列(如图像 patch、词序列)天然存在弱周期性。
- 强制首尾不连接会引入边界伪影。
- 论文 §5.2 的实验显示圆卷积在 ViT avg pooling 下优于标准 attention。
下表总结了边界处理策略的取舍:
| 策略 | 复杂度 | 边界行为 | 适用场景 |
|---|---|---|---|
| 纯圆卷积 | 首尾相连、可能回绕 | 图像 patch、视频、循环语料 | |
| 零填充到 | 等价线性卷积 | 严格线性序列 | |
| 因果移位 | 强制因果 | 自回归 LM(见 §7) |
4.5 与 Linear Attention 的对比数学
Linear Attention 用核函数 把 近似为 。计算顺序倒置后:
复杂度 ,对 是线性。但代价是:
- 仅为 softmax 的低秩近似(FAVOR+ 用正/负随机特征),无法精确表达 softmax 形式。
- 在长序列上训练不稳定、精度下降。
CAT 与 Linear Attention 的核心区别:
| 维度 | Linear Attention | CAT |
|---|---|---|
| 复杂度 | ||
| Softmax 形式 | 仅近似 | 精确保持 |
| 矩阵结构 | 循环矩阵 | |
| 训练稳定性 | 在大模型上常 NaN | 报告稳定 |
| 全局感受野 | 全局(核近似) | 全局(精确) |
可以视 CAT 为保持 softmax 的次二次折中——比 Linear Attention 多一个 因子,但换来精确性。
4.6 与 Low-Rank 注意力压缩的关系
本仓库的 注意力矩阵低秩压缩方法假设注意力矩阵 存在低秩结构 。CAT 的视角正交:CAT 不是近似 ,而是直接构造一个结构化(循环)。
联系 SVD 与谱方法 角度看,循环矩阵的特征分解:
意味着 CAT 在 DFT 基下只有 个自由度(即 的 个复数),远少于一般 矩阵的 个自由度。这与低秩方法在”减少注意力参数”上有共同目标,但路径不同:低秩通过 SVD 截断,CAT 通过结构对称性。
5. 实现细节
5.1 基础循环卷积 attention 实现
下面是一个最简的 PyTorch 实现,用于教学目的。Gather 版本通过显式构造循环矩阵、避免大矩阵乘法而实测比 naive matmul 快约 10%。
import torch
import torch.nn as nn
import torch.nn.functional as F
def circular_shift_matrix(z_star: torch.Tensor) -> torch.Tensor:
"""
根据首行 z_star 构造 N×N 循环矩阵。
Input: z_star [N, 1]
Output: M [N, N] M[i, j] = z_star[(j - i) mod N]
"""
N = z_star.shape[0]
idx = (torch.arange(N, device=z_star.device).unsqueeze(0)
- torch.arange(N, device=z_star.device).unsqueeze(1)) % N
return z_star[idx] # [N, N, 1] -> squeeze 最后一维
class CATGather(nn.Module):
"""
Gather 版本的 CAT (O(N^2) 但 cache 友好)。
论文报告在 ViT CLIP-L 上对标准 attention 提速约 10%。
"""
def __init__(self, d_model: int, n_heads: int):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
# 合并的 query-key 投影, 每个 head 一列
self.W_A = nn.Parameter(torch.randn(n_heads, d_model) * 0.02)
# value 投影
self.W_V = nn.Linear(d_model, d_model, bias=False)
# 输出投影
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: [B, N, D]
return: [B, N, D]
"""
B, N, D = x.shape
H, Dh = self.n_heads, self.d_head
# 1) Z = X @ W_A: 对每个 head 投影到标量
# x: [B, N, D] @ W_A^T: [D, H] -> [B, N, H]
Z = x @ self.W_A.t() # [B, N, H]
# softmax 沿 N 维度
Z_star = F.softmax(Z, dim=1) # [B, N, H]
# 2) V = X @ W_V: 标准值投影
V = self.W_V(x) # [B, N, D]
V = V.view(B, N, H, Dh).transpose(1, 2) # [B, H, N, Dh]
# 3) 构造循环矩阵并与 V 相乘
# 利用圆卷积 = 用 roll 构造移位 + 加权求和
# O_[:, :, i, :] = sum_m Z_star[:, (i-m) mod N, :] * V[:, :, m, :]
# 高效实现: 用 unfold 或 gather
out_heads = []
for h in range(H):
Zh = Z_star[:, :, h] # [B, N]
Vh = V[:, h] # [B, N, Dh]
# 循环移位所有行后加权求和
# shift_k: 沿 N 维 shift k 位
acc = torch.zeros_like(Vh)
for k in range(N):
acc = acc + Zh[:, k:k+1] * torch.roll(Vh, shifts=k, dims=1)
out_heads.append(acc)
out = torch.stack(out_heads, dim=1) # [B, H, N, Dh]
out = out.transpose(1, 2).contiguous().view(B, N, D)
return self.W_O(out)注意上面的内层循环是 但容易理解。下面用 FFT 把它降到 。
5.2 高效 FFT 版本
class CATFFT(nn.Module):
"""
完整 FFT 版本的 CAT,O(N log N) 序列复杂度。
使用 torch.fft.rfft 处理实数输入。
"""
def __init__(self, d_model: int, n_heads: int):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.W_A = nn.Parameter(torch.randn(n_heads, d_model) * 0.02)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: [B, N, D]
return: [B, N, D]
"""
B, N, D = x.shape
H, Dh = self.n_heads, self.d_head
# 1) 投影得到 Z (合并 QK) 和 V
Z = x @ self.W_A.t() # [B, N, H]
Z_star = F.softmax(Z, dim=1) # [B, N, H]
V = self.W_V(x).view(B, N, H, Dh) # [B, N, H, Dh]
# 2) 把 Z_star 复数化以便 broadcast 到 Dh 维
# 方式 A: 用 complex tensor, H 维与 V 的 Dh 维对齐
# 复数乘法 a * b = (a_re*b_re - a_im*b_im) + i(a_re*b_im + a_im*b_re)
# 此处我们把 Z_star 当作实部, 虚部为 0
Z_complex = torch.complex(Z_star, torch.zeros_like(Z_star)) # [B, N, H]
# 3) 沿 N 维做 FFT
# rfft 对最后一维? 不, 我们对 dim=1 (N 维) 做
# torch.fft.fft 在指定 dim 上做, 返回完整 N 个复数
Z_hat = torch.fft.fft(Z_complex, dim=1) # [B, N, H]
V_hat = torch.fft.fft(V, dim=1) # [B, N, H, Dh], complex
# 4) Hadamard 乘积: Z_hat 与 V_hat 沿 N 维逐元素乘
# 需要 broadcast: Z_hat [B, N, H, 1] * V_hat [B, N, H, Dh]
Z_hat_expanded = Z_hat.unsqueeze(-1) # [B, N, H, 1]
Y_hat = Z_hat_expanded * V_hat # [B, N, H, Dh], complex
# 5) IFFT 回到实数
Y = torch.fft.ifft(Y_hat, dim=1).real # [B, N, H, Dh]
# 6) 拼 head + 输出投影
out = Y.contiguous().view(B, N, D)
return self.W_O(out)复杂度分析:
softmax+W_A @ X: FLOPsW_V @ X:FFT(Z_star): 复数 FLOPsFFT(V):- Hadamard:
IFFT:同 FFT
总复杂度 ,对 是次二次。
5.3 与标准注意力的等价性测试
下面脚本测试:在退化情形下(无 softmax 偏移、长度 较小),CAT 的循环实现与”展开 “形式的差异应该仅来自浮点精度。
import torch
import torch.nn.functional as F
from typing import Tuple
def standard_attention_for_test(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
"""
标准 attention, 用于对照。
Q, K, V: [B, N, D]
"""
scale = Q.shape[-1] ** -0.5
return F.softmax(Q @ K.transpose(-2, -1) * scale, dim=-1) @ V
def cat_circulant_attention(Z_star: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
"""
显式构造循环矩阵的 CAT, 用于验证.
Z_star: [B, N, 1]
V: [B, N, D]
"""
B, N, _ = Z_star.shape
D = V.shape[-1]
# 构造循环矩阵, M[b, i, j] = Z_star[b, (j-i) mod N, 0]
idx = (torch.arange(N).unsqueeze(0) - torch.arange(N).unsqueeze(1)) % N # [N, N]
M = Z_star[:, idx, 0] # [B, N, N]
return M @ V # [B, N, D]
def test_equivalence():
torch.manual_seed(0)
B, N, D = 2, 8, 16
# 构造 Q, K, V
X = torch.randn(B, N, D)
W_Q = torch.randn(D, D) * 0.1
W_K = torch.randn(D, D) * 0.1
W_V = torch.randn(D, D) * 0.1
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
# 构造 "circulant 化" 的 Q, K: 让 Q = K (合并) 并降维
# 这时 QK^T 是秩 1 矩阵, 但我们用循环化近似
# 注意: 这只是"概念"测试, 实际 CAT 不直接等于标准 attention
Z = (Q + K) / 2 # 模拟 query+key 合并
Z_star = F.softmax(Z.sum(dim=-1, keepdim=True), dim=1) # [B, N, 1]
out_standard = standard_attention_for_test(Q, K, V)
out_circulant = cat_circulant_attention(Z_star, V)
# 两者并不相等 (结构本就不同), 仅 sanity check shapes
assert out_standard.shape == out_circulant.shape
diff = (out_standard - out_circulant).norm() / out_standard.norm()
print(f"relative L2 diff (expected non-zero): {diff.item():.4f}")
def test_fft_correctness():
"""验证 cat_fft 与 cat_circulant 在数值上等价."""
torch.manual_seed(1)
B, N, D = 1, 16, 32
Z_star = F.softmax(torch.randn(B, N, 1), dim=1)
V = torch.randn(B, N, D)
# 循环矩阵路径
out_circ = cat_circulant_attention(Z_star, V)
# FFT 路径
Z_complex = torch.complex(Z_star, torch.zeros_like(Z_star))
Z_hat = torch.fft.fft(Z_complex, dim=1).unsqueeze(-1) # [B, N, 1, 1]
V_hat = torch.fft.fft(V, dim=1).unsqueeze(2) # [B, N, 1, D]
Y_hat = Z_hat * V_hat
out_fft = torch.fft.ifft(Y_hat, dim=1).real.squeeze(2) # [B, N, D]
err = (out_circ - out_fft).abs().max().item()
print(f"FFT vs Circulant max abs error: {err:.2e} (should be ~1e-6)")
if __name__ == "__main__":
test_equivalence()
test_fft_correctness()预期输出类似:
relative L2 diff (expected non-zero): 0.43xxx
FFT vs Circulant max abs error: 1.19e-07 (should be ~1e-6)
第二个测试应该显示 FFT 路径与显式循环矩阵路径完全一致(仅浮点误差),这是 §2 圆卷积定理的直接推论。
5.4 Triton/CUDA 优化思路
虽然 PyTorch 的 torch.fft.fft 已经调用了 cuFFT,但仍有几个优化空间:
- 融合的 Hadamard + IFFT kernel:把乘积与 IFFT 合并到一个 Triton kernel 中,避免中间结果写入显存。
- 实数 FFT 的 half-precision:在 A100/H100 上,
torch.fft.rfft支持float16输入,但需注意数值范围。 - 多 head 共享一次 FFT:对 而言,多个 head 共享同一份实数序列的 FFT 结果。
- 与 FlashAttention 兼容:CAT 的核心是圆卷积而不是 风格,难以直接复用 FlashAttention 的 tiling。但对于输出 的部分,可以借鉴 FlashAttention-2 的在线 softmax 思想在频域重写。
论文 §C 报告了与 FlashAttention 兼容性的初步结果:在 ViT 上可以混合 CAT 层与标准 attention 层(CAT-Alter),总速度可超过纯 FlashAttention 的 baseline。
5.5 内存占用对比
| 方法 | 中间张量 | 峰值显存 |
|---|---|---|
| Standard Attention | ||
| Linear Attention | ||
| CAT (FFT) |
CAT 完全没有 项,是真正的次二次显存使用。这对超长序列推理尤其关键。
6. 实验结果
6.1 视觉任务(ImageNet-1k)
论文采用 ViT CLIP-B(12 heads)与 CLIP-L(16 heads),分别在 token pooling([CLS] token)与 avg pooling(全局平均)下做 ImageNet-1k 分类,结果如下(节选自论文 Tab. 1):
| Model | Pool | Mechanism | Learnable | Complexity | Acc.↑ |
|---|---|---|---|---|---|
| CLIP-B | token | Attention | 0.574 | ||
| CLIP-B | token | CAT | 0.540 | ||
| CLIP-B | token | CAT-Alter | 0.582 | ||
| CLIP-L | token | Attention | 0.574 | ||
| CLIP-L | token | CAT | 0.559 | ||
| CLIP-L | token | CAT-Alter | 0.593 | ||
| CLIP-B | avg | Attention | 0.638 | ||
| CLIP-B | avg | CAT | 0.649 | ||
| CLIP-B | avg | CAT-Alter | 0.662 | ||
| CLIP-L | avg | Attention | 0.646 | ||
| CLIP-L | avg | CAT | 0.694 | ||
| CLIP-L | avg | CAT-Alter | 0.681 |
关键发现:
- avg pooling > token pooling:在所有方法中,avg pooling 一致优于 [CLS] token pooling。这与近期 ViT 研究的趋势一致(如 SepViT、Patch Merger)。
- CAT 在 avg pooling 下超过标准 attention:CLIP-L 上 CAT (0.694) vs Attention (0.646),提升 +4.8 个百分点。
- CAT-Alter 是最稳健的折中:在 token pooling 下 CAT-Alter 也能反超 attention。
6.2 语言任务(WikiText-103)
论文使用 Transformer-XL(10 heads)与 GPT-2 small(12 heads),分别在 masked 与 causal 语言建模下评估,指标为 word perplexity(越低越好):
| Model | LM Type | Mechanism | Complexity | Word PPL↓ |
|---|---|---|---|---|
| Transformer-XL | masked | Attention | 13.94 | |
| Transformer-XL | masked | CAT | 10.28 | |
| Transformer-XL | masked | CAT-Alter | 8.51 | |
| GPT-2 small | masked | Attention | 9.82 | |
| GPT-2 small | masked | CAT | 8.32 | |
| GPT-2 small | masked | CAT-Alter | 7.54 | |
| Transformer-XL | causal | Attention | 30.82 | |
| Transformer-XL | causal | CAT | 36.71 | |
| Transformer-XL | causal | CAT-Alter | 30.93 | |
| GPT-2 small | causal | Attention | 27.84 | |
| GPT-2 small | causal | CAT | 32.36 | |
| GPT-2 small | causal | CAT-Alter | 27.68 |
关键发现:
- masked LM 中 CAT 显著优于 Attention:Transformer-XL 上 PPL 从 13.94 降到 10.28(-26%),GPT-2 small 上从 9.82 降到 8.32(-15%)。
- causal LM 中纯 CAT 退化:从 30.82 退化到 36.71。原因是强制因果性会破坏循环对称性(见 §7.2 边界讨论),导致需要显式 mask,复杂度退化到 。
- CAT-Alter 在因果 LM 上也能与 attention 持平:这是工程上最实用的方案。
6.3 与 Linear / Sparse Attention 的对比
论文 §5.3 报告了与 Linear Attention(FAVOR+ / Katharopoulos 形式)和稀疏注意力的对比尝试:
- Linear Attention (CLIP-L):作者在 ImageNet-1k 上反复遇到 NaN 损失,无法稳定收敛。这与文献中报道的”kernel-based attention 在大模型上训练不稳定”一致。
- 稀疏注意力 (BigBird / Longformer):作者认为其引入了序列长度依赖的超参数(块大小、窗口宽度),违反 EIT 框架的”极简超参”原则,因此未做主表对比。
这凸显 CAT 的一个核心优势:保持 softmax 形式 → 训练稳定 + 少超参。
6.4 推理速度 Benchmark
论文报告的关键速度数据(ViT CLIP-L,N=256,NVIDIA V100,naive PyTorch):
| 方法 | 单次迭代时间 | 加速比 |
|---|---|---|
| Standard Attention | 1.00×(基线) | 1.00 |
| CAT (Gather) | 0.90× | 1.11× |
| CAT (FFT) | 1.00× | 1.00(无明显加速) |
解读:
- Gather 版本实测加速 10%——这看似违反”FFT 才是 “的论断。原因: 时 FFT 调用与 kernel launch 开销已经接近 matmul 开销,FFT 优势要 才显现。
- 在更大的 或专用 FFT kernel 下,FFT 版本的理论优势会更明显(论文留作未来工作)。
- CAT-Alter 加速介于两者之间(因为一半层是标准 attention)。
7. 理论分析
7.1 表达力:能否逼近任意 attention?
CAT 的表达力受限于两个结构性约束:
- 标量投影 只能生成 个标量权重,而标准 attention 的 是 的双线性。
- 循环对称性强制每一行是首行的循环移位,无法表达”非平移不变”的注意力模式。
因此 CAT 不能逼近任意 attention。形式上:
而标准 attention 矩阵的秩可以高达 。这意味着 CAT 的”全局 softmax”是由循环结构约束的全局性,而非任意配对权重。
但实际效果可与 Attention 相当甚至更好,原因可能在于:
- 训练数据中”长程关联” 本身常表现为移位不变模式(如位置编码、卷积特征)。
- 循环结构作为强归纳偏置,相当于正则化,减少过拟合。
- 与 注意力谱理论 中的”rank collapse”现象不同,CAT 通过循环结构主动控制秩。
7.2 局部性:圆卷积天然编码环形局部结构
圆卷积最直接的属性是局部平移等变性(local shift-equivariance):
这意味着如果输入序列发生循环平移,CAT 的输出也按相同方式平移——这一点与离散平移群 的作用自然兼容。
在早期 Transformer 层的注意力图常表现出”局部条带 + 少量全局”模式(BERT 早期层可视化,Clark et al. 2019)。CAT 通过循环结构直接编码这种局部性,无需学习。这是它在小数据 / 早期层表现优异的原因之一。
7.3 与正交变换的连接(DFT 矩阵)
DFT 矩阵 是酉矩阵,因此:
循环矩阵在 下对角化:
这说明 CAT 在频域做的是对角缩放(Hadamard 乘积的频域形式),是一种正交保持的线性变换。这与 SVD 与谱方法 中的”在正交基下做对角操作”思路一脉相承——CAT 把注意力矩阵的对角化显式构造为 DFT。
7.4 不同序列长度下的缩放行为
对长度为 的序列,CAT 的缩放关系为:
| 组件 | 复杂度 |
|---|---|
| 投影() | |
| softmax | |
| FFT(每列) | |
| IFFT(每列) | |
| 总 |
注意 与 是分离的:CAT 在固定 下对 是 。
与各类方法的渐近对比:
| 方法 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| Standard Attention | ||
| Linear Attention | ||
| Sparse (Longformer) | ( = 窗口) | |
| State Space (Mamba) | ||
| CAT | ||
| CAT-Alter(混合) |
当 极大()时,CAT 的 比 Linear / SSM 的 多一个 因子,但换来精确的 softmax 形式——这是经典的时间-精度折中。
7.5 与 Kernel Method 的统一视角
注意力的”kernel view”把 视作一个 Mercer 核:
Linear Attention 用随机 Fourier 特征 近似 :
CAT 的视角更结构化:它把 限制为循环形式(即 ),用循环对称性换取显式实现。
可以认为 CAT 是 kernel approximation with structural prior on shift-invariance。这种结构先验在数据本身具有平移结构时是合适的(图像、语音、时间序列),但在 NLP 等更抽象的领域可能成为限制——这或许是 CAT 在 WikiText-103 causal LM 上退化的原因。
8. 相关工作
CAT 处于次二次注意力研究的核心位置。下表对主流方法做横向比较:
| 方法 | 年份 | 复杂度 | Softmax 形式 | 全局性 | 参数量 |
|---|---|---|---|---|---|
| Standard Attention | 2017 | 精确 | 全局 | ||
| Sparse Transformer | 2019 | 精确(局部) | 局部 + 少量全局 | ||
| Longformer | 2020 | 精确(局部) | 局部 + 全局 token | ||
| BigBird | 2020 | 精确(局部) | 局部 + 随机 + 全局 | ||
| Linear Attention | 2020 | 近似 | 全局(核近似) | ||
| Performer (FAVOR+) | 2021 | 近似 | 全局(核近似) | ||
| Hyena | 2023 | 无 | 全局(卷积) | ||
| RetNet | 2023 | 隐式 | 全局 | ||
| Mamba | 2023 | 无 | 全局 | ||
| Monarch Mixer | 2023 | 无 | 块结构 | 块稀疏 | |
| CAT (本篇) | 2025 | 精确 | 全局 |
8.1 FAVOR+ (Performer)
Choromanski et al. (2021) 用正/负随机 Fourier 特征近似 softmax 核:
实现 复杂度。CAT 不做核近似,保留精确 softmax,是另一条路径。
8.2 Hyena Operator
Poli et al. (2023) 提出的 Hyena 用长卷积 + 门控替代注意力,本质是”无 softmax 的次二次模型”。与 CAT 的对比:
- Hyena 完全无 softmax,通过隐式参数化学习长程依赖。
- CAT 保留 softmax 但用循环结构约束矩阵形式。
Hyena 的复杂度 与 CAT 相同,但表达力假设不同:Hyena 假设数据有强卷积结构,CAT 假设数据有强平移不变性。
8.3 Monarch Mixer
Dao et al. (2023) 的 Monarch Mixer 用块对角 + FFT 的混合矩阵做 token 混合,是另一种结构化矩阵路径。与 CAT 的循环矩阵相比,Monarch 更灵活(块大小可调)但没有 softmax 形式。
8.4 与本仓库其他文献的交叉
CAT 的思想在多个维度上与本仓库已有内容呼应:
- 线性注意力机制理论:与 Linear Attention 的核方法路径形成对照,可视作”保留 softmax 的次二次路径”。
- 线性注意力新变体(2024-2025):2024-2025 年的最新进展(KDA、ZeroS、RACE、Log-Linear)大多沿”Linear + 改进核”路径;CAT 走的是截然不同的”循环 + FFT”路径。
- 注意力矩阵低秩压缩:低秩方法通过 SVD 截断减少参数,CAT 通过循环对称性减少参数;两者都关心”用更少参数表达 ”。
- SVD 与谱方法:CAT 在 DFT 基下做对角化是 SVD 思路的变体。
- Flash Attention:与 FlashAttention 的”减少 IO 次数”思路正交,CAT 是”减少理论 FLOPs”。
- 注意力变体比较:在更宏观的注意力机制谱系中,CAT 属于”结构化矩阵 + 精确 softmax”派。
8.5 近期相关进展
- Spectformer (2023):用频域混合做全局 token 混合,与 CAT 有相似的动机。
- FourierFormer (2022):用 Fourier 变换参数化 attention,与 CAT 共享 FFT 思想但保留 形式。
- GFNet (2021):在频域做 token 混合,无 softmax。
9. 实践指南
9.1 何时选用 CAT
适合 CAT 的场景:
- ✅ 图像任务(ViT 类):patch 序列天然有”局部 + 弱周期”结构,CAT 的循环归纳偏置吻合。
- ✅ 中等长度序列():Gather 版本就有 10% 加速。
- ✅ 超长序列推理: 显存占用是关键优势。
- ✅ Masked LM:实验显示 CAT 在 masked 设定下显著优于标准 attention。
- ✅ 显存受限训练:相比 Linear Attention 没有 NaN 问题。
不适合 CAT 的场景:
- ❌ 纯自回归 LM(causal-only):循环对称性被破坏,需要显式 mask 重新退化为 。
- ❌ 需要任意配对交互的任务:如跨句指代、复杂结构推理。
- ❌ 极小模型():参数效率优势不显著,归纳偏置反而成为限制。
- ❌ 对绝对位置敏感的任务:CAT 偏向相对/平移结构。
9.2 常见陷阱
- 初始化: 初始化过大会导致 softmax 饱和, 近似 one-hot,圆卷积退化为”取一个特定移位的 V”。建议小初始化()。
- 学习率:CAT 报告在原论文设置下无需特殊调整。但若混入预训练权重, 的初始梯度可能很小,需要 warmup。
- 批大小:FFT 在大 batch + 中等 下非常高效;如果 batch=1,可能需要把多序列打包。
- 混合精度:
torch.fft.fft在float16下可能损失精度,建议Z_star保持float32,只在 IFFT 阶段转回float16。 - Padding 长度:若序列需要 pad 到 batch 内最大长度,循环卷积会”卷”到 pad 区域。最佳做法是按真实长度 FFT 后再 mask。
9.3 与预训练权重的兼容性
CAT 的参数空间与标准 attention 部分重叠:
- :可直接复用预训练的 value 投影。
- :新参数,需从头训练。可用 初始化( 是单位向量)来接近预训练分布。
- :可直接复用预训练的输出投影。
微调流程建议:
- 加载预训练标准 attention 权重到 、。
- 用 的均值初始化 。
- 用较小学习率(如基础 LR 的 0.1×)微调 。
- 全模型用正常 LR 微调 1-2 epoch。
9.4 训练稳定性建议
- 梯度裁剪:CAT 的频域操作对大幅度梯度敏感,建议
clip_grad_norm_(parameters, 1.0)。 - 学习率 warmup:5-10% 步数的 warmup 能稳定早期 的学习。
- Layer-scale:若与 [Touvron et al. 2021] 风格的可学习 layer-scale 结合,建议初始化为 0.1 而非 1.0,因为循环结构天然放大输出幅度。
- Dropout:在 上加 dropout 比在 上加更稀疏(因为 个标量 vs 个元素)。建议保持 dropout 概率 0.1。
9.5 推荐的工程设置
| 组件 | 推荐值 |
|---|---|
| 初始化 | |
| 初始化 | Xavier / Kaiming |
| 学习率(CAT-only) | 与 baseline 相同 |
| 学习率(CAT-Alter) | 与 baseline 相同 |
| FFT 长度 | (效率)或 (精确) |
| Dropout() | 0.1 |
| 梯度裁剪 | 1.0 |
| Optimizer | AdamW() |
10. 局限性、争议与开放问题
10.1 论文自陈局限
论文 §7 明确指出:
- 极长序列尚未验证:CAT 在 8K+ 序列上的稳定性和精度仍是开放问题。
- 硬件优化的 FFT kernel 尚未实现:当前 naive PyTorch 实现在 下 FFT 版本没有明显加速。
- 因果性破坏循环结构:causal LM 下纯 CAT 退化为 。
- 表达力受限:循环结构本身是对注意力矩阵的强约束。
10.2 社区视角的争议
- “循环 = 卷积 = 失去灵活性”:部分研究者认为循环结构过于限制,难以表达现代 LLM 需要的复杂关系。
- “10% 加速不够颠覆”:相比 Performer、Mamba 等 方法,CAT 的 在工程上没有压倒性优势。
- “不兼容 KV cache”:循环结构使得标准的增量推理 (KV cache) 难以直接应用;论文作者通过重写 Z 的偏移来部分支持,但增加了复杂度。
10.3 开放问题
- 混合架构的最优比例:CAT-Alter 在论文中显示混合 50% 是好的,但更系统的搜索(如 Hymba 的 head-level 混合)可能更好。
- 与 SSM 的统一:CAT 保留了 softmax,SSM 没有——能否在频域统一两者的优点?
- 多尺度 CAT:当前 CAT 用单一循环核 。能否做多尺度(多个 拼接,类似 multi-head 但每个 head 不同尺度)?
- 可解释性:循环注意力矩阵的特征值就是 (DFT 后),这给了频谱可解释性——能否用于模型分析?
11. 总结
CAT 提供了一个在保持 softmax 形式的前提下实现 自注意力的优雅方案。其核心思想——把注意力矩阵重写为循环矩阵、利用圆卷积定理在频域计算——是结构化矩阵方法与深度学习结合的典范。
核心要点:
- 数学洞察:QK^T 注意力可重写为圆卷积 。
- 算法核心:用 再 IFFT,复杂度 。
- EIT 框架:次二次 + 精确 softmax + 少超参。
- 实验优势:在 ImageNet avg pooling、WikiText masked LM 上优于标准 attention。
- 工程友好:naive PyTorch 即可实现,10% 加速无特殊 kernel 要求。
在 注意力变体谱系 中,CAT 占据一个独特位置:它不是”更激进的近似”(如 Linear / Sparse),而是”更结构化的精确”——用对称性换效率。
未来的研究可能沿两个方向展开:
- 算法侧:与 FlashAttention、SSM、KV cache 的深度融合。
- 应用侧:在长视频理解、基因组学、代码生成等长序列任务上的实际部署。
12. 参考文献
Footnotes
-
Vaswani et al. Attention is All You Need. NeurIPS 2017. 标准自注意力的原始定义,本文 §1.1 的公式直接沿用。 ↩
-
Tay et al. Efficient Transformers: A Survey. ACM Computing Surveys 2022. 系统综述了 注意力在长序列上的瓶颈及各类次二次方案。 ↩
-
Yamada, Y. CAT: Circular-Convolutional Attention for Sub-Quadratic Transformers. NeurIPS 2025, arXiv:2504.06704. 本文核心引用,所有”循环化”与”圆卷积”的数学细节均来自该论文 §4。 ↩ ↩2
-
Yamada 2025, §2. EIT(Engineering-Isomorphic Transformer)框架的四条公理:softmax 保持、次二次、参数高效、极简超参。 ↩
-
Cooley & Tukey. An Algorithm for the Machine Calculation of Complex Fourier Series. Mathematics of Computation 1965. FFT 的原始论文,本文 §2.2 的复杂度结论基于此。 ↩