注意力机制的线性代数视角
引言
Transformer架构的核心是缩放点积注意力(Scaled Dot-Product Attention)机制:
从线性代数的视角审视这个公式,我们可以揭示其深层数学结构,理解为什么注意力机制如此强大,以及如何对其进行优化和分析。
1. Q/K/V 投影的几何意义
1.1 线性变换作为基变换
Query、Key、Value矩阵通过可学习的投影矩阵得到:
几何解释:
- :输入序列的矩阵表示(个token,每个维)
- :三个不同的线性变换
- 每个变换定义了输入空间的一个新的”坐标系”
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
# Q/K/V 投影矩阵
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 输出投影
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
B, N, D = x.shape
# Q/K/V 投影
Q = self.W_q(x) # (B, N, D)
K = self.W_k(x)
V = self.W_v(x)
# 分割多头
Q = Q.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V)
# 合并多头并输出投影
out = out.transpose(1, 2).contiguous().view(B, N, D)
return self.W_o(out)1.2 子空间解释
Query空间:定义了”需要查找什么”——每个query向量表示当前位置想要获取的信息类型
Key空间:定义了”提供什么”——每个key向量表示该位置包含的信息特征
Value空间:定义了”实际内容”——存储该位置的实际信息
投影矩阵 控制了如何将输入转换为可比较的表示,而 控制了如何聚合信息。
2. 点积相似度的代数结构
2.1 注意力分数矩阵
其中 ,所以 是一个 Gram矩阵(内积矩阵)。
Gram矩阵的性质:
- 对称性:(当 时)
- 半正定性:
- 谱特性: 的特征值非负
2.2 缩放因子的数学动机
问题:当 很大时,点积的值往往很大,导致 softmax 进入饱和区域(梯度几乎为零)。
解决方案:除以
直观理解:
- 假设 的各分量独立同分布,均值为0,方差为1
- 则 的均值为0,方差为
- 除以 后,方差恢复为1
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
缩放点积注意力的数学实现
Q: (batch, n_heads, seq_len, d_k)
K: (batch, n_heads, seq_len, d_k)
V: (batch, n_heads, seq_len, d_v)
"""
d_k = Q.shape[-1]
# 计算点积并缩放
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# 应用掩码(如果有)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# softmax 归一化
attn_weights = F.softmax(scores, dim=-1)
# 加权求和
output = torch.matmul(attn_weights, V)
return output, attn_weights2.3 余弦相似度视角
点积注意力可以理解为一种缩放的余弦相似度:
如果我们假设 (通过归一化实现),则:
这解释了为什么注意力机制本质上是在寻找方向相似的向量。
3. softmax 归一化的信息论解释
3.1 softmax 的概率解释
注意力权重可以解释为条件概率分布:
这意味着输出是value向量的加权平均,权重由注意力分布决定。
3.2 最大熵原理
softmax 函数满足最大熵原理:给定注意力分数的均值和方差,softmax 生成的信息量最大的概率分布。
数学推导:在约束 和 下,最大化熵 的解正是 softmax。
3.3 温度参数
通用形式的 softmax 包含温度参数 :
| 温度 | 效果 | 类比 |
|---|---|---|
| 退化为硬最大(one-hot) | 确定性选择 | |
| 标准 softmax | 平衡探索与利用 | |
| 均匀分布 | 完全随机 |
def softmax_with_temperature(scores, temperature=1.0):
"""带温度的 softmax"""
return F.softmax(scores / temperature, dim=-1)4. 注意力矩阵的谱特性
4.1 注意力矩阵的性质
注意力矩阵 具有以下数学性质:
- 行随机性:每行和为1()
- 非负性:所有元素非负
- 谱半径:最大特征值 = 1(由 Perron-Frobenius 定理保证)
4.2 幂迭代与主导特征向量
当多次应用注意力(如多层堆叠)时,行为类似于幂迭代法:
收敛性质:
- 如果 是不可约的(非周期性),则 收敛到主特征向量
- 主特征向量对应于 Markov 链的稳态分布
4.3 秩崩溃问题
当注意力矩阵的秩降低时,会出现”秩崩溃”(Rank Collapse)问题:
def analyze_attention_rank(A, eps=1e-6):
"""分析注意力矩阵的有效秩"""
# 计算 SVD
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
# 有效秩
effective_rank = (S > eps).sum().item()
# 谱熵(衡量分布均匀程度)
S_norm = S / S.sum()
spectral_entropy = -(S_norm * torch.log(S_norm + eps)).sum().item()
return {
'effective_rank': effective_rank,
'spectral_entropy': spectral_entropy,
'condition_number': S[0] / (S[-1] + eps)
}秩崩溃的影响:当注意力权重几乎均匀或几乎 one-hot 时,网络的表示能力严重受限。
5. 注意力作为信息路由
5.1 矩阵形式的统一视角
完整注意力计算可以表示为:
信息流解释:
- :计算query-key匹配度
- softmax:归一化为概率分布
- :待聚合的信息内容
5.2 输出作为值的加权平均
对于第 个输出位置:
这本质上是 对 的列向量的线性组合,权重由注意力分布决定。
5.3 自注意力的特殊情况
当 时,我们得到自注意力:
这相当于对输入序列本身进行自适应滤波。
6. 多头注意力的矩阵结构
6.1 并行头结构
假设有 个注意力头,每个头的维度为 :
class MultiHeadAttentionMatrixForm(nn.Module):
"""
多头注意力的矩阵形式实现
"""
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 堆叠的投影矩阵 (用于高效计算)
self.W_qkv = nn.Linear(d_model, 3 * d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
B, N, D = x.shape
# 一次性计算 Q, K, V
QKV = self.W_qkv(x)
Q, K, V = QKV.split(D, dim=-1)
# 重塑为多头形式
Q = Q.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
# 计算注意力(批量矩阵乘法)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V)
# 合并多头
out = out.transpose(1, 2).contiguous().view(B, N, D)
return self.W_o(out)6.2 多头注意力的表达能力
理论分析1:
- 单头注意力是秩-1近似的运算
- 多头注意力可以表达更丰富的注意力模式
- 不同头可以学习关注不同的语义关系
6.3 注意力头合并与冗余
实践中经常发现注意力头之间存在冗余:
def analyze_head_redundancy(attention_weights, n_heads):
"""
分析多头之间的冗余程度
使用相关系数衡量相似性
"""
# attention_weights: (batch, heads, seq, seq)
B, H, N, _ = attention_weights.shape
# 计算头之间的相关性
correlations = torch.zeros(H, H)
for i in range(H):
for j in range(H):
# 重塑为向量并计算皮尔逊相关系数
a = attention_weights[:, i].reshape(-1)
b = attention_weights[:, j].reshape(-1)
correlations[i, j] = F.cosine_similarity(
a.unsqueeze(0), b.unsqueeze(0)
)
return correlations7. 线性注意力与核方法
7.1 核函数视角
标准注意力的计算复杂度为 (序列长度的平方)。线性注意力通过核方法近似实现 :
使用核函数 :
7.2 线性注意力的递归形式
class LinearAttention(nn.Module):
"""
线性注意力的实现(Performer/Linear Transformers)
使用随机特征映射近似 softmax
"""
def __init__(self, d_model, n_heads, n_features=256):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 随机特征映射
self.W_rp = nn.Linear(self.d_k, n_features, bias=False)
# Q/K/V 投影
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
def forward(self, x):
B, N, D = x.shape
Q = self.W_q(x).view(B, N, self.n_heads, self.d_k)
K = self.W_k(x).view(B, N, self.n_heads, self.d_k)
V = self.W_v(x).view(B, N, self.n_heads, self.d_k)
# 随机特征映射
Q_prime = F.relu(self.W_rp(Q))
K_prime = F.relu(self.W_rp(K))
# 累积计算(O(n) 复杂度)
kv = torch.einsum('bnhd,bnhm->bhdm', K_prime, V)
Z = K_prime.sum(dim=2) # (B, H, D)
# 最终输出
out = torch.einsum('bnhd,bhdm->bnhm', Q_prime, kv)
out = out / (Z.unsqueeze(1) + 1e-6)
return out.view(B, N, D)8. 实践:注意力可视化与分析
8.1 注意力权重可视化
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attn_weights, tokens=None):
"""
可视化注意力权重矩阵
attn_weights: (seq_len, seq_len) 或 (n_heads, seq_len, seq_len)
tokens: token 列表用于标签
"""
if attn_weights.dim() == 3:
# 多头:绘制所有头
n_heads = attn_weights.shape[0]
fig, axes = plt.subplots(1, n_heads, figsize=(4*n_heads, 4))
for i, ax in enumerate(axes):
sns.heatmap(attn_weights[i].cpu().numpy(),
ax=ax, cmap='viridis',
xticklabels=tokens, yticklabels=tokens)
ax.set_title(f'Head {i+1}')
else:
# 单头
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(attn_weights.cpu().numpy(),
ax=ax, cmap='viridis',
xticklabels=tokens, yticklabels=tokens)
plt.tight_layout()
return fig8.2 注意力模式分析
def classify_attention_pattern(attn_weights):
"""
分类注意力模式类型
"""
# 计算对角线注意力
diag_attn = attn_weights.diagonal(dim1=-2, dim2=-1).mean(-1)
# 计算全局注意力(排除对角线)
mask = torch.ones_like(attn_weights)
mask.fill_diagonal_(0)
global_attn = (attn_weights * mask).sum(-1) / (attn_weights.shape[-1] - 1)
# 分类
if diag_attn.mean() > 0.5:
return "token-local"
elif global_attn.mean() > 0.3:
return "global"
else:
return "hierarchical"9. 总结
核心要点
- Q/K/V 投影是三个不同的基变换,决定了注意力的匹配方式和信息表示
- 点积相似度本质上是余弦相似度的变体,捕捉向量间的方向关系
- softmax 归一化满足最大熵原理,生成信息量最大的概率分布
- 注意力矩阵是行随机矩阵,其谱特性影响网络的稳定性和表达能力
- 多头注意力通过并行组合多个注意力头,增强了模型表达能力
数学核心公式
参考资料
相关链接:
Footnotes
-
Vaswani, A., et al. (2017). Attention Is All You Need. NeurIPS 2017. ↩