Circular-Convolutional Attention(CAT)

核心论文:Yoshihiro Yamada. CAT: Circular-Convolutional Attention for Sub-Quadratic Transformers. NeurIPS 2025.
arXiv2504.06704 | OpenReviewq7hoTSbV1t
作者机构:Preferred Networks, Tokyo, Japan

CAT(Circular-convolutional ATtention,循环卷积注意力)是一种基于快速傅里叶变换(FFT)的次二次(sub-quadratic)自注意力机制。它通过将注意力图重新表述为循环矩阵(circulant matrix),并利用频域卷积定理将计算复杂度从标准注意力的 降低至 ,同时完整保留全局 softmax 加权结构。本文将系统介绍其数学基础、推导细节、PyTorch 实现、实验结果与在 各类注意力变体 中的定位。


1. 概述与动机

1.1 标准自注意力的 瓶颈

给定输入序列 为序列长度, 为特征维度),标准自注意力(Self-Attention)执行:

其中 是 query/key 维度,softmax 沿行归一化,使得 的注意力矩阵每行求和为 1

注意力的计算瓶颈在两个地方

  1. 空间复杂度 的注意力矩阵 占用 内存。
  2. 时间复杂度:两次矩阵乘法()均需 FLOPs。

达到 8K、32K 甚至更长(如长文档、视频、基因组)时, 在显存和算力上都成为不可承受的负担。这一瓶颈已被多种应用广泛证实。2

1.2 现有次二次注意力方法谱系

为了打破 屏障,研究者从不同角度切入,形成了三类主流范式:

范式代表方法核心思路复杂度代价
稀疏注意力Longformer、BigBird仅在局部窗口 + 少量全局 token 上计算引入块大小等超参;破坏全局加权
核近似 / 线性注意力Linear Attention、Performer (FAVOR+)用核函数 把 softmax 拆为 丢失精确 softmax 形式,训练不稳定
状态空间 / RNN 类S4、Mamba用结构化状态空间替换注意力与注意力机制解耦,需重新设计训练栈

每一类都做出了自己的妥协:稀疏方法丢失全局性,核方法丢失 softmax 形式,SSM 类则与注意力彻底分道扬镳。CAT 提出了一个关键问题:

能否在保留全局 softmax 加权结构的同时,把复杂度压到

1.3 CAT 的核心洞察

CAT 的答案源于一个反直觉的观察:标准的 注意力可以等价改写为一种圆卷积,只要 是同一序列的不同线性投影。进一步,若我们只用一个标量投影 (输出维度为 1)来产生循环注意力核,则整个注意力矩阵成为一个循环矩阵(circulant matrix):

而循环矩阵与向量的乘法正是一维圆卷积,由圆卷积定理可知可在 时间内通过 FFT/IFFT 完成。这便是 CAT 的核心数学洞察。3

下图为论文中从标准注意力到 CAT 的两种实现的对照:

┌─────────────────────┐    ┌──────────────────────┐    ┌──────────────────────┐
│  Self-Attention     │    │  CAT (Gather)        │    │  CAT (FFT)           │
│  Q,K,V Linear       │    │  Linear W_A, W_V     │    │  Linear W_A, W_V     │
│  Matmul QK^T        │ => │  Softmax(Z)          │ => │  FFT(Z*), FFT(V)     │
│  Softmax            │    │  Roll V via gather   │    │  Hadamard 乘         │
│  Matmul · V         │    │  (O(N^2) 但 cache 友好) │  │  IFFT                │
│  ──────────         │    │                      │    │  O(N log N)          │
│  NxN 注意力图       │    │  N 长度循环核         │    │  N 长度循环核         │
└─────────────────────┘    └──────────────────────┘    └──────────────────────┘

1.4 EIT 框架:工程同构 Transformer

CAT 论文还提出一个更上层的概念——Engineering-Isomorphic Transformers(EITs,工程同构 Transformer):满足以下四个条件的次二次注意力机制:

  1. Softmax 保持:输出可写成 ,与标准注意力形式同构。
  2. 次二次复杂度:严格小于
  3. 参数效率:可学习参数量与标准多头注意力相当或更少。
  4. 极简超参:不引入随序列长度变化的超参(如块大小、稀疏模式)。

CAT 是 EIT 框架下的一种具体实现,保留 softmax 同时复杂度降到 4


2. 数学基础:循环卷积

2.1 离散圆卷积的定义

为两个长度为 的序列,离散圆卷积(circular convolution)定义为:

关键点是模运算 :这使得序列首尾相连、形成环面(torus)拓扑。下标越界时自动回绕到序列开头。

可以等价地写成矩阵-向量乘法形式:

其中 是由 的循环移位生成的 循环矩阵。

2.2 圆卷积定理(Circular Convolution Theorem)

圆卷积定理是 CAT 算法的理论引擎。它陈述:时域的圆卷积等于频域的逐元素乘积

表示离散傅里叶变换(DFT), 为其逆变换。则:

其中 为 Hadamard(逐元素)乘积。等价地:

复杂度对比

  • 直接计算: 个点中每个需 次乘加,
  • FFT 实现:DFT/IDFT 各 ,加上逐元素乘积 ,总计

Cooley–Tukey FFT 在 时达到 ,是经典的高效算法。5

2.3 DFT 矩阵与正交性

点 DFT 的矩阵形式是:

是酉矩阵(差一个 归一化因子),。这意味着循环卷积对应的是正交基下的乘法,是数值稳定的线性变换。

更一般地,循环矩阵 可被 对角化:

这是后续在频域高效计算的关键代数恒等式。

2.4 与 群卷积的关系

循环卷积在群论视角下是有限循环群 上的卷积:

这一定义与上述离散形式完全一致。当 时,DFT 把这个群代数 分解为不可约表示(一维特征标 )的直和,圆卷积定理正是这一分解的代数表述。

2.5 与标准(线性)卷积的关系

圆卷积与线性卷积的关系是:

其中 在末尾零填充至长度 。如果序列在两端不”卷起来”,则需要零填充以避免回绕造成的伪影(见 §4.4 边界处理)。


3. CAT 的核心思想

3.1 Self-Attention 的循环重表述

回到自注意力。设 ,CAT 用一个单一的标量投影 把输入映射到一个 维向量:

再施加逐行(实为逐元素)softmax:

关键观察 被视作一个循环矩阵 第一行

恰好是 与每一列 的一维圆卷积的批量形式

3.2 QK^T 与循环卷积的等价性

为什么能用循环矩阵替代 ?我们可以从两个角度理解:

角度 1: 的特例
(即合并 query 和 key 投影),则 的第 元素为:

这是一个对称的双线性形式,但本身仍不是循环卷积。

角度 2:参数共享后的显式循环化
当我们只用一个标量投影 时,“QK^T”被替换为单标量 ,注意力矩阵变成 ——每一行都是第一行的循环移位,这就是循环矩阵。

论文中的数学描述十分精确:

Specifically, we learn a single projection matrix to map the input into a vector . We then apply a row-wise softmax, yielding . This serves as the first row of a circulant matrix , thus representing global pairwise interactions.3

与原始 attention 的等价性条件

  • 序列本身是周期性的(或可视为环形数据,如长视频、循环语料)时,圆卷积准确。
  • 当序列是有限线性(如一段文字)时,需要在序列两端做适当处理,详见 §4.4。

3.3 softmax--V 的频域高效计算

将 softmax 后的循环注意力矩阵乘以 ,结合圆卷积定理得到:

算法伪代码

Algorithm CAT_forward(X, W_A, W_V)
    Input:  X ∈ R^{N×D}, W_A ∈ R^{D×1}, W_V ∈ R^{D×D}
    Output: O ∈ R^{N×D}
 
    1. Z  ← X @ W_A                 // [N, 1]
    2. Z* ← softmax(Z, dim=0)       // [N, 1],  row-wise softmax
    3. V  ← X @ W_V                 // [N, D]
    4. Ỹ  ← FFT(V, dim=0)           // [N, D]  (复数)
    5. Ẑ  ← FFT(Z*, dim=0)          // [N]     (复数)
    6. Ŷ  ← Ẑ ⊙ Ỹ                   // [N, D]  (复数, Hadamard)
    7. O  ← IFFT(Ŷ, dim=0).real     // [N, D]
    8. return O

复杂度分析(对长度 维度):

操作复杂度备注
Z = X @ W_A 投影
softmax(Z)
V = X @ W_V标准值投影
FFT(V)每列独立 FFT
FFT(Z*)
Hadamard
IFFT
是次二次

相比标准注意力的 ,CAT 把对 的依赖从 降到

3.4 多头扩展

CAT 天然支持多头注意力:每头维护独立的 个头的输出拼接后过输出投影:

其中 。这与标准 MHA 形式完全一致,且每头都是


4. 完整数学推导

4.1 从标准注意力到 CAT

。标准多头注意力的单头输出为:

步骤 1:合并 query 与 key。
,并定义:

论文中称此为 Averaged-Key 参数化。我们以更简洁的合并形式()表达。

步骤 2:循环化。
视作循环矩阵 的生成向量。形式上:

步骤 3:频域化。
利用圆卷积定理:

,其中 沿序列维度(dim=0)作用。

4.2 FFT/IFFT 细节

在 PyTorch 中实数序列的 FFT 形式为:

  • torch.fft.rfft(x, dim=0):实数输入,输出 个复数(利用共轭对称性,节省一半存储)。
  • torch.fft.irfft(y, n=N, dim=0):从 个复数恢复 个实数。

对于复数乘法,PyTorch 自动广播。关键的精度保证是:实数输入 + 圆卷积 + 复数 FFT → 频域乘积在数值上与实数域卷积结果在浮点精度内一致(无离散化误差)。

4.3 数值稳定性与 NaN 防护

实现中需要处理以下数值问题:

  1. Softmax 溢出:当 很大时 会溢出。标准做法是减去最大值
  2. FFT 中的 与 NaN:若输入含 NaN,FFT 会传播 NaN。务必在前向计算前做 torch.nan_to_num 或断言。
  3. rfft 输入为复数时的实部/虚部拼接:论文实现中常将 视为”复数”形式以减少一次 IFFT。一种实现是把 沿最后一维复制到 的”实部”,或者直接对 单独做 rfft。
  4. 复数张量类型:使用 torch.complex64(与 float32 配套)或 torch.complex128。在 GPU 上 complex64 的实部/虚部分别占用 float32 内存,但 cuFFT 需要专门的 kernel。

4.4 边界处理

循环卷积假设序列在两端首尾相连。当处理线性序列时,需要补零(zero-padding):

  • 各填充至 长度:
  • FFT 后做 Hadamard 乘积,IFFT 取前 项即为线性卷积(无回绕)。

但这会使长度变 ,FFT 成本增至 。实践中,论文作者发现直接圆卷积(不做补零)反而效果更好——可能因为:

  • 真实序列(如图像 patch、词序列)天然存在弱周期性
  • 强制首尾不连接会引入边界伪影。
  • 论文 §5.2 的实验显示圆卷积在 ViT avg pooling 下优于标准 attention。

下表总结了边界处理策略的取舍:

策略复杂度边界行为适用场景
纯圆卷积首尾相连、可能回绕图像 patch、视频、循环语料
零填充到 等价线性卷积严格线性序列
因果移位强制因果自回归 LM(见 §7)

4.5 与 Linear Attention 的对比数学

Linear Attention 用核函数 近似为 。计算顺序倒置后:

复杂度 ,对 线性。但代价是:

  • 仅为 softmax 的低秩近似(FAVOR+ 用正/负随机特征),无法精确表达 softmax 形式。
  • 在长序列上训练不稳定、精度下降。

CAT 与 Linear Attention 的核心区别:

维度Linear AttentionCAT
复杂度
Softmax 形式仅近似精确保持
矩阵结构循环矩阵
训练稳定性在大模型上常 NaN报告稳定
全局感受野全局(核近似)全局(精确)

可以视 CAT 为保持 softmax 的次二次折中——比 Linear Attention 多一个 因子,但换来精确性。

4.6 与 Low-Rank 注意力压缩的关系

本仓库的 注意力矩阵低秩压缩方法假设注意力矩阵 存在低秩结构 。CAT 的视角正交:CAT 不是近似 ,而是直接构造一个结构化(循环)

联系 SVD 与谱方法 角度看,循环矩阵的特征分解:

意味着 CAT 在 DFT 基下只有 个自由度(即 个复数),远少于一般 矩阵的 个自由度。这与低秩方法在”减少注意力参数”上有共同目标,但路径不同:低秩通过 SVD 截断,CAT 通过结构对称性。


5. 实现细节

5.1 基础循环卷积 attention 实现

下面是一个最简的 PyTorch 实现,用于教学目的。Gather 版本通过显式构造循环矩阵、避免大矩阵乘法而实测比 naive matmul 快约 10%。

import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
def circular_shift_matrix(z_star: torch.Tensor) -> torch.Tensor:
    """
    根据首行 z_star 构造 N×N 循环矩阵。
    Input:  z_star [N, 1]
    Output: M      [N, N]  M[i, j] = z_star[(j - i) mod N]
    """
    N = z_star.shape[0]
    idx = (torch.arange(N, device=z_star.device).unsqueeze(0)
           - torch.arange(N, device=z_star.device).unsqueeze(1)) % N
    return z_star[idx]  # [N, N, 1] -> squeeze 最后一维
 
 
class CATGather(nn.Module):
    """
    Gather 版本的 CAT (O(N^2) 但 cache 友好)。
    论文报告在 ViT CLIP-L 上对标准 attention 提速约 10%。
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # 合并的 query-key 投影, 每个 head 一列
        self.W_A = nn.Parameter(torch.randn(n_heads, d_model) * 0.02)
        # value 投影
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        # 输出投影
        self.W_O = nn.Linear(d_model, d_model, bias=False)
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [B, N, D]
        return: [B, N, D]
        """
        B, N, D = x.shape
        H, Dh = self.n_heads, self.d_head
 
        # 1) Z = X @ W_A: 对每个 head 投影到标量
        # x: [B, N, D] @ W_A^T: [D, H] -> [B, N, H]
        Z = x @ self.W_A.t()  # [B, N, H]
        # softmax 沿 N 维度
        Z_star = F.softmax(Z, dim=1)  # [B, N, H]
 
        # 2) V = X @ W_V: 标准值投影
        V = self.W_V(x)  # [B, N, D]
        V = V.view(B, N, H, Dh).transpose(1, 2)  # [B, H, N, Dh]
 
        # 3) 构造循环矩阵并与 V 相乘
        #    利用圆卷积 = 用 roll 构造移位 + 加权求和
        #    O_[:, :, i, :] = sum_m Z_star[:, (i-m) mod N, :] * V[:, :, m, :]
        #    高效实现: 用 unfold 或 gather
        out_heads = []
        for h in range(H):
            Zh = Z_star[:, :, h]        # [B, N]
            Vh = V[:, h]                 # [B, N, Dh]
            # 循环移位所有行后加权求和
            # shift_k: 沿 N 维 shift k 位
            acc = torch.zeros_like(Vh)
            for k in range(N):
                acc = acc + Zh[:, k:k+1] * torch.roll(Vh, shifts=k, dims=1)
            out_heads.append(acc)
        out = torch.stack(out_heads, dim=1)  # [B, H, N, Dh]
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_O(out)

注意上面的内层循环是 容易理解。下面用 FFT 把它降到

5.2 高效 FFT 版本

class CATFFT(nn.Module):
    """
    完整 FFT 版本的 CAT,O(N log N) 序列复杂度。
    使用 torch.fft.rfft 处理实数输入。
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_A = nn.Parameter(torch.randn(n_heads, d_model) * 0.02)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [B, N, D]
        return: [B, N, D]
        """
        B, N, D = x.shape
        H, Dh = self.n_heads, self.d_head
 
        # 1) 投影得到 Z (合并 QK) 和 V
        Z = x @ self.W_A.t()           # [B, N, H]
        Z_star = F.softmax(Z, dim=1)   # [B, N, H]
        V = self.W_V(x).view(B, N, H, Dh)  # [B, N, H, Dh]
 
        # 2) 把 Z_star 复数化以便 broadcast 到 Dh 维
        #    方式 A: 用 complex tensor, H 维与 V 的 Dh 维对齐
        #    复数乘法 a * b = (a_re*b_re - a_im*b_im) + i(a_re*b_im + a_im*b_re)
        #    此处我们把 Z_star 当作实部, 虚部为 0
        Z_complex = torch.complex(Z_star, torch.zeros_like(Z_star))  # [B, N, H]
 
        # 3) 沿 N 维做 FFT
        #    rfft 对最后一维? 不, 我们对 dim=1 (N 维) 做
        #    torch.fft.fft 在指定 dim 上做, 返回完整 N 个复数
        Z_hat = torch.fft.fft(Z_complex, dim=1)      # [B, N, H]
        V_hat = torch.fft.fft(V, dim=1)              # [B, N, H, Dh], complex
 
        # 4) Hadamard 乘积: Z_hat 与 V_hat 沿 N 维逐元素乘
        #    需要 broadcast: Z_hat [B, N, H, 1] * V_hat [B, N, H, Dh]
        Z_hat_expanded = Z_hat.unsqueeze(-1)         # [B, N, H, 1]
        Y_hat = Z_hat_expanded * V_hat               # [B, N, H, Dh], complex
 
        # 5) IFFT 回到实数
        Y = torch.fft.ifft(Y_hat, dim=1).real        # [B, N, H, Dh]
 
        # 6) 拼 head + 输出投影
        out = Y.contiguous().view(B, N, D)
        return self.W_O(out)

复杂度分析

  • softmax + W_A @ X FLOPs
  • W_V @ X
  • FFT(Z_star) 复数 FLOPs
  • FFT(V)
  • Hadamard:
  • IFFT:同 FFT

总复杂度 ,对 是次二次。

5.3 与标准注意力的等价性测试

下面脚本测试:在退化情形下(无 softmax 偏移、长度 较小),CAT 的循环实现与”展开 “形式的差异应该仅来自浮点精度。

import torch
import torch.nn.functional as F
from typing import Tuple
 
 
def standard_attention_for_test(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """
    标准 attention, 用于对照。
    Q, K, V: [B, N, D]
    """
    scale = Q.shape[-1] ** -0.5
    return F.softmax(Q @ K.transpose(-2, -1) * scale, dim=-1) @ V
 
 
def cat_circulant_attention(Z_star: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """
    显式构造循环矩阵的 CAT, 用于验证.
    Z_star: [B, N, 1]
    V:      [B, N, D]
    """
    B, N, _ = Z_star.shape
    D = V.shape[-1]
    # 构造循环矩阵, M[b, i, j] = Z_star[b, (j-i) mod N, 0]
    idx = (torch.arange(N).unsqueeze(0) - torch.arange(N).unsqueeze(1)) % N  # [N, N]
    M = Z_star[:, idx, 0]  # [B, N, N]
    return M @ V  # [B, N, D]
 
 
def test_equivalence():
    torch.manual_seed(0)
    B, N, D = 2, 8, 16
    # 构造 Q, K, V
    X = torch.randn(B, N, D)
    W_Q = torch.randn(D, D) * 0.1
    W_K = torch.randn(D, D) * 0.1
    W_V = torch.randn(D, D) * 0.1
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
 
    # 构造 "circulant 化" 的 Q, K: 让 Q = K (合并) 并降维
    # 这时 QK^T 是秩 1 矩阵, 但我们用循环化近似
    # 注意: 这只是"概念"测试, 实际 CAT 不直接等于标准 attention
    Z = (Q + K) / 2  # 模拟 query+key 合并
    Z_star = F.softmax(Z.sum(dim=-1, keepdim=True), dim=1)  # [B, N, 1]
 
    out_standard = standard_attention_for_test(Q, K, V)
    out_circulant = cat_circulant_attention(Z_star, V)
    # 两者并不相等 (结构本就不同), 仅 sanity check shapes
    assert out_standard.shape == out_circulant.shape
    diff = (out_standard - out_circulant).norm() / out_standard.norm()
    print(f"relative L2 diff (expected non-zero): {diff.item():.4f}")
 
 
def test_fft_correctness():
    """验证 cat_fft 与 cat_circulant 在数值上等价."""
    torch.manual_seed(1)
    B, N, D = 1, 16, 32
    Z_star = F.softmax(torch.randn(B, N, 1), dim=1)
    V = torch.randn(B, N, D)
 
    # 循环矩阵路径
    out_circ = cat_circulant_attention(Z_star, V)
 
    # FFT 路径
    Z_complex = torch.complex(Z_star, torch.zeros_like(Z_star))
    Z_hat = torch.fft.fft(Z_complex, dim=1).unsqueeze(-1)  # [B, N, 1, 1]
    V_hat = torch.fft.fft(V, dim=1).unsqueeze(2)            # [B, N, 1, D]
    Y_hat = Z_hat * V_hat
    out_fft = torch.fft.ifft(Y_hat, dim=1).real.squeeze(2)  # [B, N, D]
 
    err = (out_circ - out_fft).abs().max().item()
    print(f"FFT vs Circulant max abs error: {err:.2e}  (should be ~1e-6)")
 
 
if __name__ == "__main__":
    test_equivalence()
    test_fft_correctness()

预期输出类似:

relative L2 diff (expected non-zero): 0.43xxx
FFT vs Circulant max abs error: 1.19e-07  (should be ~1e-6)

第二个测试应该显示 FFT 路径与显式循环矩阵路径完全一致(仅浮点误差),这是 §2 圆卷积定理的直接推论。

5.4 Triton/CUDA 优化思路

虽然 PyTorch 的 torch.fft.fft 已经调用了 cuFFT,但仍有几个优化空间:

  1. 融合的 Hadamard + IFFT kernel:把乘积与 IFFT 合并到一个 Triton kernel 中,避免中间结果写入显存。
  2. 实数 FFT 的 half-precision:在 A100/H100 上,torch.fft.rfft 支持 float16 输入,但需注意数值范围。
  3. 多 head 共享一次 FFT:对 而言,多个 head 共享同一份实数序列的 FFT 结果。
  4. 与 FlashAttention 兼容:CAT 的核心是圆卷积而不是 风格,难以直接复用 FlashAttention 的 tiling。但对于输出 的部分,可以借鉴 FlashAttention-2 的在线 softmax 思想在频域重写。

论文 §C 报告了与 FlashAttention 兼容性的初步结果:在 ViT 上可以混合 CAT 层与标准 attention 层(CAT-Alter),总速度可超过纯 FlashAttention 的 baseline。

5.5 内存占用对比

方法中间张量峰值显存
Standard Attention
Linear Attention
CAT (FFT)

CAT 完全没有 项,是真正的次二次显存使用。这对超长序列推理尤其关键。


6. 实验结果

6.1 视觉任务(ImageNet-1k)

论文采用 ViT CLIP-B(12 heads)与 CLIP-L(16 heads),分别在 token pooling([CLS] token)与 avg pooling(全局平均)下做 ImageNet-1k 分类,结果如下(节选自论文 Tab. 1):

ModelPoolMechanismLearnableComplexityAcc.↑
CLIP-BtokenAttention0.574
CLIP-BtokenCAT0.540
CLIP-BtokenCAT-Alter0.582
CLIP-LtokenAttention0.574
CLIP-LtokenCAT0.559
CLIP-LtokenCAT-Alter0.593
CLIP-BavgAttention0.638
CLIP-BavgCAT0.649
CLIP-BavgCAT-Alter0.662
CLIP-LavgAttention0.646
CLIP-LavgCAT0.694
CLIP-LavgCAT-Alter0.681

关键发现

  1. avg pooling > token pooling:在所有方法中,avg pooling 一致优于 [CLS] token pooling。这与近期 ViT 研究的趋势一致(如 SepViT、Patch Merger)。
  2. CAT 在 avg pooling 下超过标准 attention:CLIP-L 上 CAT (0.694) vs Attention (0.646),提升 +4.8 个百分点
  3. CAT-Alter 是最稳健的折中:在 token pooling 下 CAT-Alter 也能反超 attention。

6.2 语言任务(WikiText-103)

论文使用 Transformer-XL(10 heads)与 GPT-2 small(12 heads),分别在 masked 与 causal 语言建模下评估,指标为 word perplexity(越低越好):

ModelLM TypeMechanismComplexityWord PPL↓
Transformer-XLmaskedAttention13.94
Transformer-XLmaskedCAT10.28
Transformer-XLmaskedCAT-Alter8.51
GPT-2 smallmaskedAttention9.82
GPT-2 smallmaskedCAT8.32
GPT-2 smallmaskedCAT-Alter7.54
Transformer-XLcausalAttention30.82
Transformer-XLcausalCAT36.71
Transformer-XLcausalCAT-Alter30.93
GPT-2 smallcausalAttention27.84
GPT-2 smallcausalCAT32.36
GPT-2 smallcausalCAT-Alter27.68

关键发现

  1. masked LM 中 CAT 显著优于 Attention:Transformer-XL 上 PPL 从 13.94 降到 10.28(-26%),GPT-2 small 上从 9.82 降到 8.32(-15%)。
  2. causal LM 中纯 CAT 退化:从 30.82 退化到 36.71。原因是强制因果性会破坏循环对称性(见 §7.2 边界讨论),导致需要显式 mask,复杂度退化到
  3. CAT-Alter 在因果 LM 上也能与 attention 持平:这是工程上最实用的方案。

6.3 与 Linear / Sparse Attention 的对比

论文 §5.3 报告了与 Linear Attention(FAVOR+ / Katharopoulos 形式)和稀疏注意力的对比尝试:

  • Linear Attention (CLIP-L):作者在 ImageNet-1k 上反复遇到 NaN 损失,无法稳定收敛。这与文献中报道的”kernel-based attention 在大模型上训练不稳定”一致。
  • 稀疏注意力 (BigBird / Longformer):作者认为其引入了序列长度依赖的超参数(块大小、窗口宽度),违反 EIT 框架的”极简超参”原则,因此未做主表对比。

这凸显 CAT 的一个核心优势:保持 softmax 形式 → 训练稳定 + 少超参

6.4 推理速度 Benchmark

论文报告的关键速度数据(ViT CLIP-L,N=256,NVIDIA V100,naive PyTorch):

方法单次迭代时间加速比
Standard Attention1.00×(基线)1.00
CAT (Gather)0.90×1.11×
CAT (FFT)1.00×1.00(无明显加速)

解读

  • Gather 版本实测加速 10%——这看似违反”FFT 才是 “的论断。原因: 时 FFT 调用与 kernel launch 开销已经接近 matmul 开销,FFT 优势要 才显现。
  • 在更大的 或专用 FFT kernel 下,FFT 版本的理论优势会更明显(论文留作未来工作)。
  • CAT-Alter 加速介于两者之间(因为一半层是标准 attention)。

7. 理论分析

7.1 表达力:能否逼近任意 attention?

CAT 的表达力受限于两个结构性约束:

  1. 标量投影 只能生成 个标量权重,而标准 attention 的 的双线性。
  2. 循环对称性强制每一行是首行的循环移位,无法表达”非平移不变”的注意力模式。

因此 CAT 不能逼近任意 attention。形式上:

而标准 attention 矩阵的秩可以高达 这意味着 CAT 的”全局 softmax”是由循环结构约束的全局性,而非任意配对权重。

但实际效果可与 Attention 相当甚至更好,原因可能在于:

  • 训练数据中”长程关联” 本身常表现为移位不变模式(如位置编码、卷积特征)。
  • 循环结构作为强归纳偏置,相当于正则化,减少过拟合。
  • 注意力谱理论 中的”rank collapse”现象不同,CAT 通过循环结构主动控制秩。

7.2 局部性:圆卷积天然编码环形局部结构

圆卷积最直接的属性是局部平移等变性(local shift-equivariance):

这意味着如果输入序列发生循环平移,CAT 的输出也按相同方式平移——这一点与离散平移群 的作用自然兼容。

早期 Transformer 层的注意力图常表现出”局部条带 + 少量全局”模式(BERT 早期层可视化,Clark et al. 2019)。CAT 通过循环结构直接编码这种局部性,无需学习。这是它在小数据 / 早期层表现优异的原因之一。

7.3 与正交变换的连接(DFT 矩阵)

DFT 矩阵 是酉矩阵,因此:

循环矩阵在 下对角化:

这说明 CAT 在频域做的是对角缩放(Hadamard 乘积的频域形式),是一种正交保持的线性变换。这与 SVD 与谱方法 中的”在正交基下做对角操作”思路一脉相承——CAT 把注意力矩阵的对角化显式构造为 DFT。

7.4 不同序列长度下的缩放行为

对长度为 的序列,CAT 的缩放关系为:

组件复杂度
投影(
softmax
FFT(每列)
IFFT(每列)

注意 是分离的:CAT 在固定 下对

与各类方法的渐近对比

方法时间复杂度空间复杂度
Standard Attention
Linear Attention
Sparse (Longformer) = 窗口)
State Space (Mamba)
CAT
CAT-Alter(混合)

极大()时,CAT 的 比 Linear / SSM 的 多一个 因子,但换来精确的 softmax 形式——这是经典的时间-精度折中。

7.5 与 Kernel Method 的统一视角

注意力的”kernel view”把 视作一个 Mercer 核:

Linear Attention 用随机 Fourier 特征 近似

CAT 的视角更结构化:它把 限制为循环形式(即 ),用循环对称性换取显式实现。

可以认为 CAT 是 kernel approximation with structural prior on shift-invariance。这种结构先验在数据本身具有平移结构时是合适的(图像、语音、时间序列),但在 NLP 等更抽象的领域可能成为限制——这或许是 CAT 在 WikiText-103 causal LM 上退化的原因。


8. 相关工作

CAT 处于次二次注意力研究的核心位置。下表对主流方法做横向比较:

方法年份复杂度Softmax 形式全局性参数量
Standard Attention2017精确全局
Sparse Transformer2019精确(局部)局部 + 少量全局
Longformer2020精确(局部)局部 + 全局 token
BigBird2020精确(局部)局部 + 随机 + 全局
Linear Attention2020近似全局(核近似)
Performer (FAVOR+)2021近似全局(核近似)
Hyena2023全局(卷积)
RetNet2023隐式全局
Mamba2023全局
Monarch Mixer2023块结构块稀疏
CAT (本篇)2025精确全局

8.1 FAVOR+ (Performer)

Choromanski et al. (2021) 用正/负随机 Fourier 特征近似 softmax 核:

实现 复杂度。CAT 不做核近似,保留精确 softmax,是另一条路径。

8.2 Hyena Operator

Poli et al. (2023) 提出的 Hyena 用长卷积 + 门控替代注意力,本质是”无 softmax 的次二次模型”。与 CAT 的对比:

  • Hyena 完全无 softmax,通过隐式参数化学习长程依赖。
  • CAT 保留 softmax 但用循环结构约束矩阵形式。

Hyena 的复杂度 与 CAT 相同,但表达力假设不同:Hyena 假设数据有强卷积结构,CAT 假设数据有强平移不变性。

8.3 Monarch Mixer

Dao et al. (2023) 的 Monarch Mixer 用块对角 + FFT 的混合矩阵做 token 混合,是另一种结构化矩阵路径。与 CAT 的循环矩阵相比,Monarch 更灵活(块大小可调)但没有 softmax 形式

8.4 与本仓库其他文献的交叉

CAT 的思想在多个维度上与本仓库已有内容呼应:

  • 线性注意力机制理论:与 Linear Attention 的核方法路径形成对照,可视作”保留 softmax 的次二次路径”。
  • 线性注意力新变体(2024-2025):2024-2025 年的最新进展(KDA、ZeroS、RACE、Log-Linear)大多沿”Linear + 改进核”路径;CAT 走的是截然不同的”循环 + FFT”路径。
  • 注意力矩阵低秩压缩:低秩方法通过 SVD 截断减少参数,CAT 通过循环对称性减少参数;两者都关心”用更少参数表达 ”。
  • SVD 与谱方法:CAT 在 DFT 基下做对角化是 SVD 思路的变体。
  • Flash Attention:与 FlashAttention 的”减少 IO 次数”思路正交,CAT 是”减少理论 FLOPs”。
  • 注意力变体比较:在更宏观的注意力机制谱系中,CAT 属于”结构化矩阵 + 精确 softmax”派。

8.5 近期相关进展

  • Spectformer (2023):用频域混合做全局 token 混合,与 CAT 有相似的动机。
  • FourierFormer (2022):用 Fourier 变换参数化 attention,与 CAT 共享 FFT 思想但保留 形式。
  • GFNet (2021):在频域做 token 混合,无 softmax。

9. 实践指南

9.1 何时选用 CAT

适合 CAT 的场景

  • 图像任务(ViT 类):patch 序列天然有”局部 + 弱周期”结构,CAT 的循环归纳偏置吻合。
  • 中等长度序列):Gather 版本就有 10% 加速。
  • 超长序列推理 显存占用是关键优势。
  • Masked LM:实验显示 CAT 在 masked 设定下显著优于标准 attention。
  • 显存受限训练:相比 Linear Attention 没有 NaN 问题。

不适合 CAT 的场景

  • 纯自回归 LM(causal-only):循环对称性被破坏,需要显式 mask 重新退化为
  • 需要任意配对交互的任务:如跨句指代、复杂结构推理。
  • 极小模型):参数效率优势不显著,归纳偏置反而成为限制。
  • 对绝对位置敏感的任务:CAT 偏向相对/平移结构。

9.2 常见陷阱

  1. 初始化 初始化过大会导致 softmax 饱和, 近似 one-hot,圆卷积退化为”取一个特定移位的 V”。建议小初始化()。
  2. 学习率:CAT 报告在原论文设置下无需特殊调整。但若混入预训练权重, 的初始梯度可能很小,需要 warmup。
  3. 批大小:FFT 在大 batch + 中等 下非常高效;如果 batch=1,可能需要把多序列打包。
  4. 混合精度torch.fft.fftfloat16 下可能损失精度,建议 Z_star 保持 float32,只在 IFFT 阶段转回 float16
  5. Padding 长度:若序列需要 pad 到 batch 内最大长度,循环卷积会”卷”到 pad 区域。最佳做法是按真实长度 FFT 后再 mask。

9.3 与预训练权重的兼容性

CAT 的参数空间与标准 attention 部分重叠

  • :可直接复用预训练的 value 投影。
  • :新参数,需从头训练。可用 初始化( 是单位向量)来接近预训练分布。
  • :可直接复用预训练的输出投影。

微调流程建议:

  1. 加载预训练标准 attention 权重到
  2. 的均值初始化
  3. 用较小学习率(如基础 LR 的 0.1×)微调
  4. 全模型用正常 LR 微调 1-2 epoch。

9.4 训练稳定性建议

  • 梯度裁剪:CAT 的频域操作对大幅度梯度敏感,建议 clip_grad_norm_(parameters, 1.0)
  • 学习率 warmup:5-10% 步数的 warmup 能稳定早期 的学习。
  • Layer-scale:若与 [Touvron et al. 2021] 风格的可学习 layer-scale 结合,建议初始化为 0.1 而非 1.0,因为循环结构天然放大输出幅度。
  • Dropout:在 上加 dropout 比在 上加更稀疏(因为 个标量 vs 个元素)。建议保持 dropout 概率 0.1。

9.5 推荐的工程设置

组件推荐值
初始化
初始化 Xavier / Kaiming
学习率(CAT-only)与 baseline 相同
学习率(CAT-Alter)与 baseline 相同
FFT 长度(效率)或 (精确)
Dropout(0.1
梯度裁剪1.0
OptimizerAdamW(

10. 局限性、争议与开放问题

10.1 论文自陈局限

论文 §7 明确指出:

  1. 极长序列尚未验证:CAT 在 8K+ 序列上的稳定性和精度仍是开放问题。
  2. 硬件优化的 FFT kernel 尚未实现:当前 naive PyTorch 实现在 下 FFT 版本没有明显加速。
  3. 因果性破坏循环结构:causal LM 下纯 CAT 退化为
  4. 表达力受限:循环结构本身是对注意力矩阵的强约束。

10.2 社区视角的争议

  • “循环 = 卷积 = 失去灵活性”:部分研究者认为循环结构过于限制,难以表达现代 LLM 需要的复杂关系。
  • “10% 加速不够颠覆”:相比 Performer、Mamba 等 方法,CAT 的 在工程上没有压倒性优势。
  • “不兼容 KV cache”:循环结构使得标准的增量推理 (KV cache) 难以直接应用;论文作者通过重写 Z 的偏移来部分支持,但增加了复杂度。

10.3 开放问题

  1. 混合架构的最优比例:CAT-Alter 在论文中显示混合 50% 是好的,但更系统的搜索(如 Hymba 的 head-level 混合)可能更好。
  2. 与 SSM 的统一:CAT 保留了 softmax,SSM 没有——能否在频域统一两者的优点?
  3. 多尺度 CAT:当前 CAT 用单一循环核 。能否做多尺度(多个 拼接,类似 multi-head 但每个 head 不同尺度)?
  4. 可解释性:循环注意力矩阵的特征值就是 (DFT 后),这给了频谱可解释性——能否用于模型分析?

11. 总结

CAT 提供了一个在保持 softmax 形式的前提下实现 自注意力的优雅方案。其核心思想——把注意力矩阵重写为循环矩阵、利用圆卷积定理在频域计算——是结构化矩阵方法与深度学习结合的典范。

核心要点

  1. 数学洞察:QK^T 注意力可重写为圆卷积
  2. 算法核心:用 再 IFFT,复杂度
  3. EIT 框架:次二次 + 精确 softmax + 少超参。
  4. 实验优势:在 ImageNet avg pooling、WikiText masked LM 上优于标准 attention。
  5. 工程友好:naive PyTorch 即可实现,10% 加速无特殊 kernel 要求。

注意力变体谱系 中,CAT 占据一个独特位置:它不是”更激进的近似”(如 Linear / Sparse),而是”更结构化的精确”——用对称性换效率。

未来的研究可能沿两个方向展开:

  • 算法侧:与 FlashAttention、SSM、KV cache 的深度融合。
  • 应用侧:在长视频理解、基因组学、代码生成等长序列任务上的实际部署。

12. 参考文献

Footnotes

  1. Vaswani et al. Attention is All You Need. NeurIPS 2017. 标准自注意力的原始定义,本文 §1.1 的公式直接沿用。

  2. Tay et al. Efficient Transformers: A Survey. ACM Computing Surveys 2022. 系统综述了 注意力在长序列上的瓶颈及各类次二次方案。

  3. Yamada, Y. CAT: Circular-Convolutional Attention for Sub-Quadratic Transformers. NeurIPS 2025, arXiv:2504.06704. 本文核心引用,所有”循环化”与”圆卷积”的数学细节均来自该论文 §4。 2

  4. Yamada 2025, §2. EIT(Engineering-Isomorphic Transformer)框架的四条公理:softmax 保持、次二次、参数高效、极简超参。

  5. Cooley & Tukey. An Algorithm for the Machine Calculation of Complex Fourier Series. Mathematics of Computation 1965. FFT 的原始论文,本文 §2.2 的复杂度结论基于此。