注意力机制的线性代数视角

引言

Transformer架构的核心是缩放点积注意力(Scaled Dot-Product Attention)机制:

从线性代数的视角审视这个公式,我们可以揭示其深层数学结构,理解为什么注意力机制如此强大,以及如何对其进行优化和分析。


1. Q/K/V 投影的几何意义

1.1 线性变换作为基变换

Query、Key、Value矩阵通过可学习的投影矩阵得到:

几何解释

  • :输入序列的矩阵表示(个token,每个维)
  • :三个不同的线性变换
  • 每个变换定义了输入空间的一个新的”坐标系”
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Q/K/V 投影矩阵
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # 输出投影
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, N, D = x.shape
        
        # Q/K/V 投影
        Q = self.W_q(x)  # (B, N, D)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # 分割多头
        Q = Q.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        
        # 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        
        # 合并多头并输出投影
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)

1.2 子空间解释

Query空间:定义了”需要查找什么”——每个query向量表示当前位置想要获取的信息类型

Key空间:定义了”提供什么”——每个key向量表示该位置包含的信息特征

Value空间:定义了”实际内容”——存储该位置的实际信息

投影矩阵 控制了如何将输入转换为可比较的表示,而 控制了如何聚合信息。


2. 点积相似度的代数结构

2.1 注意力分数矩阵

其中 ,所以 是一个 Gram矩阵(内积矩阵)。

Gram矩阵的性质

  1. 对称性(当 时)
  2. 半正定性
  3. 谱特性 的特征值非负

2.2 缩放因子的数学动机

问题:当 很大时,点积的值往往很大,导致 softmax 进入饱和区域(梯度几乎为零)。

解决方案:除以

直观理解

  • 假设 的各分量独立同分布,均值为0,方差为1
  • 的均值为0,方差为
  • 除以 后,方差恢复为1
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    缩放点积注意力的数学实现
    
    Q: (batch, n_heads, seq_len, d_k)
    K: (batch, n_heads, seq_len, d_k)
    V: (batch, n_heads, seq_len, d_v)
    """
    d_k = Q.shape[-1]
    
    # 计算点积并缩放
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # 应用掩码(如果有)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # softmax 归一化
    attn_weights = F.softmax(scores, dim=-1)
    
    # 加权求和
    output = torch.matmul(attn_weights, V)
    
    return output, attn_weights

2.3 余弦相似度视角

点积注意力可以理解为一种缩放的余弦相似度:

如果我们假设 (通过归一化实现),则:

这解释了为什么注意力机制本质上是在寻找方向相似的向量。


3. softmax 归一化的信息论解释

3.1 softmax 的概率解释

注意力权重可以解释为条件概率分布:

这意味着输出是value向量的加权平均,权重由注意力分布决定。

3.2 最大熵原理

softmax 函数满足最大熵原理:给定注意力分数的均值和方差,softmax 生成的信息量最大的概率分布。

数学推导:在约束 下,最大化熵 的解正是 softmax。

3.3 温度参数

通用形式的 softmax 包含温度参数

温度 效果类比
退化为硬最大(one-hot)确定性选择
标准 softmax平衡探索与利用
均匀分布完全随机
def softmax_with_temperature(scores, temperature=1.0):
    """带温度的 softmax"""
    return F.softmax(scores / temperature, dim=-1)

4. 注意力矩阵的谱特性

4.1 注意力矩阵的性质

注意力矩阵 具有以下数学性质:

  1. 行随机性:每行和为1(
  2. 非负性:所有元素非负
  3. 谱半径:最大特征值 = 1(由 Perron-Frobenius 定理保证)

4.2 幂迭代与主导特征向量

当多次应用注意力(如多层堆叠)时,行为类似于幂迭代法:

收敛性质

  • 如果 是不可约的(非周期性),则 收敛到主特征向量
  • 主特征向量对应于 Markov 链的稳态分布

4.3 秩崩溃问题

当注意力矩阵的秩降低时,会出现”秩崩溃”(Rank Collapse)问题:

def analyze_attention_rank(A, eps=1e-6):
    """分析注意力矩阵的有效秩"""
    # 计算 SVD
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    
    # 有效秩
    effective_rank = (S > eps).sum().item()
    
    # 谱熵(衡量分布均匀程度)
    S_norm = S / S.sum()
    spectral_entropy = -(S_norm * torch.log(S_norm + eps)).sum().item()
    
    return {
        'effective_rank': effective_rank,
        'spectral_entropy': spectral_entropy,
        'condition_number': S[0] / (S[-1] + eps)
    }

秩崩溃的影响:当注意力权重几乎均匀或几乎 one-hot 时,网络的表示能力严重受限。


5. 注意力作为信息路由

5.1 矩阵形式的统一视角

完整注意力计算可以表示为:

信息流解释

  • :计算query-key匹配度
  • softmax:归一化为概率分布
  • :待聚合的信息内容

5.2 输出作为值的加权平均

对于第 个输出位置:

这本质上是 的列向量的线性组合,权重由注意力分布决定。

5.3 自注意力的特殊情况

时,我们得到自注意力:

这相当于对输入序列本身进行自适应滤波


6. 多头注意力的矩阵结构

6.1 并行头结构

假设有 个注意力头,每个头的维度为

class MultiHeadAttentionMatrixForm(nn.Module):
    """
    多头注意力的矩阵形式实现
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 堆叠的投影矩阵 (用于高效计算)
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, N, D = x.shape
        
        # 一次性计算 Q, K, V
        QKV = self.W_qkv(x)
        Q, K, V = QKV.split(D, dim=-1)
        
        # 重塑为多头形式
        Q = Q.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        
        # 计算注意力(批量矩阵乘法)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        
        # 合并多头
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)

6.2 多头注意力的表达能力

理论分析1

  • 单头注意力是秩-1近似的运算
  • 多头注意力可以表达更丰富的注意力模式
  • 不同头可以学习关注不同的语义关系

6.3 注意力头合并与冗余

实践中经常发现注意力头之间存在冗余:

def analyze_head_redundancy(attention_weights, n_heads):
    """
    分析多头之间的冗余程度
    使用相关系数衡量相似性
    """
    # attention_weights: (batch, heads, seq, seq)
    B, H, N, _ = attention_weights.shape
    
    # 计算头之间的相关性
    correlations = torch.zeros(H, H)
    for i in range(H):
        for j in range(H):
            # 重塑为向量并计算皮尔逊相关系数
            a = attention_weights[:, i].reshape(-1)
            b = attention_weights[:, j].reshape(-1)
            correlations[i, j] = F.cosine_similarity(
                a.unsqueeze(0), b.unsqueeze(0)
            )
    
    return correlations

7. 线性注意力与核方法

7.1 核函数视角

标准注意力的计算复杂度为 (序列长度的平方)。线性注意力通过核方法近似实现

使用核函数

7.2 线性注意力的递归形式

class LinearAttention(nn.Module):
    """
    线性注意力的实现(Performer/Linear Transformers)
    使用随机特征映射近似 softmax
    """
    def __init__(self, d_model, n_heads, n_features=256):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 随机特征映射
        self.W_rp = nn.Linear(self.d_k, n_features, bias=False)
        
        # Q/K/V 投影
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, N, D = x.shape
        
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k)
        
        # 随机特征映射
        Q_prime = F.relu(self.W_rp(Q))
        K_prime = F.relu(self.W_rp(K))
        
        # 累积计算(O(n) 复杂度)
        kv = torch.einsum('bnhd,bnhm->bhdm', K_prime, V)
        Z = K_prime.sum(dim=2)  # (B, H, D)
        
        # 最终输出
        out = torch.einsum('bnhd,bhdm->bnhm', Q_prime, kv)
        out = out / (Z.unsqueeze(1) + 1e-6)
        
        return out.view(B, N, D)

8. 实践:注意力可视化与分析

8.1 注意力权重可视化

import matplotlib.pyplot as plt
import seaborn as sns
 
def visualize_attention(attn_weights, tokens=None):
    """
    可视化注意力权重矩阵
    
    attn_weights: (seq_len, seq_len) 或 (n_heads, seq_len, seq_len)
    tokens: token 列表用于标签
    """
    if attn_weights.dim() == 3:
        # 多头:绘制所有头
        n_heads = attn_weights.shape[0]
        fig, axes = plt.subplots(1, n_heads, figsize=(4*n_heads, 4))
        for i, ax in enumerate(axes):
            sns.heatmap(attn_weights[i].cpu().numpy(), 
                       ax=ax, cmap='viridis', 
                       xticklabels=tokens, yticklabels=tokens)
            ax.set_title(f'Head {i+1}')
    else:
        # 单头
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(attn_weights.cpu().numpy(), 
                   ax=ax, cmap='viridis',
                   xticklabels=tokens, yticklabels=tokens)
    
    plt.tight_layout()
    return fig

8.2 注意力模式分析

def classify_attention_pattern(attn_weights):
    """
    分类注意力模式类型
    """
    # 计算对角线注意力
    diag_attn = attn_weights.diagonal(dim1=-2, dim2=-1).mean(-1)
    
    # 计算全局注意力(排除对角线)
    mask = torch.ones_like(attn_weights)
    mask.fill_diagonal_(0)
    global_attn = (attn_weights * mask).sum(-1) / (attn_weights.shape[-1] - 1)
    
    # 分类
    if diag_attn.mean() > 0.5:
        return "token-local"
    elif global_attn.mean() > 0.3:
        return "global"
    else:
        return "hierarchical"

9. 总结

核心要点

  1. Q/K/V 投影是三个不同的基变换,决定了注意力的匹配方式和信息表示
  2. 点积相似度本质上是余弦相似度的变体,捕捉向量间的方向关系
  3. softmax 归一化满足最大熵原理,生成信息量最大的概率分布
  4. 注意力矩阵是行随机矩阵,其谱特性影响网络的稳定性和表达能力
  5. 多头注意力通过并行组合多个注意力头,增强了模型表达能力

数学核心公式


参考资料


相关链接

Footnotes

  1. Vaswani, A., et al. (2017). Attention Is All You Need. NeurIPS 2017.