神经网络架构的统一矩阵框架
引言
深度学习领域涌现了大量专用架构:CNN处理图像、RNN处理序列、Transformer处理长距离依赖。从线性代数的视角,这些看似不同的架构实际上都是矩阵乘法的不同变体,它们之间的差异主要体现在矩阵的结构特性上。
核心观点:所有主流神经网络架构都可以用稀疏矩阵乘法来统一描述。卷积是Toeplitz矩阵乘法,循环是下三角矩阵乘法,注意力是全连接矩阵乘法。1
1. 从卷积到矩阵乘法
1.1 卷积的矩阵表示
二维卷积可以通过Toeplitz矩阵(块Toeplitz矩阵)表示为矩阵乘法:
对于输入 和卷积核 :
其中 是一个块Toeplitz矩阵, 是 的向量化形式。
1.2 im2col:卷积到矩阵乘法的转换
深度学习框架(如cuDNN)将卷积转换为矩阵乘法来利用GPU的矩阵运算优化:
def im2col(X, kernel_size, stride, padding):
"""
im2col:将卷积转换为矩阵乘法
思想:将每个滑动窗口展开为矩阵的一列
"""
B, C, H, W = X.shape
# 填充
if padding > 0:
X_padded = F.pad(X, (padding, padding, padding, padding))
else:
X_padded = X
# 计算输出尺寸
H_out = (H + 2*padding - kernel_size) // stride + 1
W_out = (W + 2*padding - kernel_size) // stride + 1
# 提取滑动窗口
cols = F.unfold(X_padded, kernel_size, stride=stride)
cols = cols.transpose(1, 2) # (B, H_out*W_out, C*k*k)
return cols, H_out, W_out
def conv_as_matmul(X, weight, stride=1, padding=1):
"""
将卷积实现为矩阵乘法
"""
B, C, H, W = X.shape
k, _, C_out, _ = weight.shape
# im2col转换
cols, H_out, W_out = im2col(X, k, stride, padding)
# 展平卷积核
weight_flat = weight.view(C_out, -1).T # (C*k*k, C_out)
# 矩阵乘法
out = cols @ weight_flat # (B, H_out*W_out, C_out)
# 重塑为图像格式
out = out.view(B, H_out, W_out, C_out).permute(0, 3, 1, 2)
return out1.3 Toeplitz矩阵结构
def create_toeplitz_matrix(kernel, input_size):
"""
创建Toeplitz矩阵
每一行是卷积核的循环移位
"""
k = len(kernel)
n = input_size
# 一维卷积的Toeplitz矩阵
T = torch.zeros(n, n)
for i in range(n):
for j in range(k):
if 0 <= i - j < n:
T[i, i-j] = kernel[j]
return T
# 二维情况:块Toeplitz矩阵
def create_block_toeplitz_2d(kernel_2d, H, W):
"""
创建二维Toeplitz矩阵(简化示意)
"""
k_h, k_w = kernel_2d.shape
n = H * W
T = torch.zeros(n, n)
for i in range(H):
for j in range(W):
for di in range(k_h):
for dj in range(k_w):
ni = i * W + j
nj = (i - di) * W + (j - dj)
if 0 <= nj < n and 0 <= (i - di) < H:
T[ni, nj] = kernel_2d[di, dj]
return T2. 循环神经网络与下三角矩阵
2.1 RNN的矩阵形式
标准RNN的前向传播可以表示为:
将所有时间步堆叠:
这是一个块下三角矩阵。
def rnn_as_matrix(X, W_hh, W_xh, b, h0=None):
"""
将RNN表示为矩阵乘法
"""
T, B, D_in = X.shape
D_h = W_hh.shape[0]
# 初始化
if h0 is None:
h = torch.zeros(B, D_h)
outputs = []
for t in range(T):
h = torch.tanh(X[t] @ W_xh.T + h @ W_hh.T + b)
outputs.append(h)
return torch.stack(outputs)
def rnn_matrix_form(X, W_hh, W_xh, b):
"""
RNN的显式矩阵形式
H = T @ X
其中 T 是块下三角矩阵
"""
T, B, D_in = X.shape
D_h = W_hh.shape[0]
# 构建块下三角矩阵(稀疏形式)
# 实际计算时使用循环而非显式构建
# 时间步展开
H_cols = []
for t in range(T):
# 从X的前t个时间步和历史hidden state计算
h_t = torch.zeros(B, D_h)
for s in range(t + 1):
# 贡献 = W_xh @ x_s
h_t = h_t + X[s] @ W_xh.T
H_cols.append(h_t)
return torch.stack(H_cols, dim=1)2.2 LSTM的门控机制
LSTM引入了门控机制,其矩阵表示更为复杂:
2.3 矩阵视角的洞察
def analyze_rnn_matrix_properties(W_hh, W_xh):
"""
分析RNN权重矩阵的性质
"""
# 谱半径
eigvals = torch.linalg.eigvals(W_hh)
spectral_radius = torch.max(torch.abs(eigvals)).item()
# 稳定性判据:谱半径 < 1 时,RNN动态系统稳定
is_stable = spectral_radius < 1.0
return {
'spectral_radius': spectral_radius,
'is_stable': is_stable,
'eigenvalue_distribution': eigvals.abs().sort().values
}3. Transformer与注意力矩阵
3.1 注意力的矩阵表示
标准Transformer注意力的计算:
从矩阵角度,这可以视为:
- 查询-键匹配:(对称矩阵)
- 归一化:(行随机矩阵)
- 值聚合:(加权聚合)
def attention_matrix_form(Q, K, V):
"""
注意力的矩阵形式
"""
# 相似度矩阵
S = Q @ K.transpose(-2, -1)
# 归一化(softmax)
A = F.softmax(S / math.sqrt(Q.shape[-1]), dim=-1)
# 加权聚合
O = A @ V
return O, A, S3.2 注意力矩阵的性质
def analyze_attention_matrix(attn_weights):
"""
分析注意力矩阵的数学性质
"""
# 假设 attn_weights: (seq, seq)
# 1. 行随机性:每行和为1
row_sums = attn_weights.sum(dim=-1)
is_row_stochastic = torch.allclose(row_sums, torch.ones_like(row_sums))
# 2. 谱性质:最大特征值 = 1
eigvals = torch.linalg.eigvalsh(attn_weights)
lambda_max = eigvals[-1].item()
# 3. 熵:衡量注意力的"锐利"程度
eps = 1e-10
entropy = -(attn_weights * torch.log(attn_weights + eps)).sum(dim=-1).mean()
# 4. 有效秩
U, S, _ = torch.linalg.svd(attn_weights, full_matrices=False)
S_norm = S / S.sum()
effective_rank = torch.exp(-(S_norm * torch.log(S_norm + eps)).sum())
return {
'is_row_stochastic': is_row_stochastic,
'lambda_max': lambda_max,
'entropy': entropy.item(),
'effective_rank': effective_rank.item()
}4. 统一框架:矩阵视角下的三大架构
4.1 架构对比
| 架构 | 矩阵类型 | 连接模式 | 复杂度 |
|---|---|---|---|
| CNN | Toeplitz/块Toeplitz | 局部连接 | |
| RNN | 块下三角 | 时间递归 | |
| Transformer | 全连接(动态加权) | 全局注意力 |
4.2 统一表示
定义一个统一的操作形式:
其中 是掩码矩阵,不同架构对应不同的 :
class UnifiedArchitecture(nn.Module):
"""
统一架构框架
"""
def __init__(self, architecture_type, d_model, d_hidden=None):
super().__init__()
self.architecture_type = architecture_type
self.d_model = d_model
if architecture_type == 'cnn':
# 卷积核大小
self.kernel_size = d_hidden or 3
self.conv = nn.Conv1d(d_model, d_model, self.kernel_size, padding=self.kernel_size//2)
elif architecture_type == 'rnn':
self.rnn = nn.RNN(d_model, d_model, batch_first=True)
elif architecture_type == 'transformer':
self.attn = nn.MultiheadAttention(d_model, 8, batch_first=True)
self.proj = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
if self.architecture_type == 'cnn':
# x: (batch, seq, d_model)
x = x.transpose(1, 2) # (batch, d_model, seq)
x = self.conv(x)
x = x.transpose(1, 2) # (batch, seq, d_model)
elif self.architecture_type == 'rnn':
x, _ = self.rnn(x)
elif self.architecture_type == 'transformer':
# 自注意力
attn_out, _ = self.attn(x, x, x, attn_mask=mask)
x = attn_out
return self.proj(x)4.3 从稀疏到全连接
不同架构可以被视为不同稀疏程度的注意力:
稀疏程度: 低 ────────────────────────────────────> 高
│ │
CNN: [1,1,1,0,0,0,0,0] 局部连接
[0,1,1,1,0,0,0,0] Toeplitz结构
[0,0,1,1,1,0,0,0]
RNN: [1,0,0,0,0,0,0,0] 递归连接
[1,1,0,0,0,0,0,0] 下三角结构
[1,1,1,0,0,0,0,0]
[1,1,1,1,0,0,0,0]
Transformer:[1,1,1,1,1,1,1,1] 全连接
[1,1,1,1,1,1,1,1] (动态权重)
[1,1,1,1,1,1,1,1]
5. 混合架构的矩阵视角
5.1 Hybrid CNN-Transformer
class HybridCNNTransformer(nn.Module):
"""
CNN + Transformer 混合架构
CNN处理局部特征 → Transformer建模全局依赖
"""
def __init__(self, d_model, n_heads, cnn_kernel=3):
super().__init__()
# 局部特征提取
self.conv = nn.Sequential(
nn.Conv1d(d_model, d_model, cnn_kernel, padding=cnn_kernel//2),
nn.GELU()
)
# 全局建模
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True),
num_layers=6
)
# 维度对齐
self.norm = nn.LayerNorm(d_model)
def forward(self, x):
# CNN: 局部特征
x_local = x.transpose(1, 2) # (B, D, N)
x_local = self.conv(x_local)
x_local = x_local.transpose(1, 2) # (B, N, D)
# Transformer: 全局依赖
x_global = self.transformer(x_local)
# 残差连接
return self.norm(x + x_global)5.2 Hybrid RNN-Attention
class HybridRNNAttention(nn.Module):
"""
RNN + Attention 混合架构
"""
def __init__(self, d_model, n_heads):
super().__init__()
self.rnn = nn.LSTM(d_model, d_model, batch_first=True, bidirectional=True)
self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
self.norm = nn.LayerNorm(d_model)
def forward(self, x1, x2):
# x1, x2: 两个不同的输入序列
# RNN: 编码序列1
h = self.rnn(x1)[0]
# Cross Attention: 用x2的内容指导x1
out, _ = self.cross_attn(h, x2, x2)
return self.norm(h + out)6. 硬件感知的矩阵优化
6.1 矩阵结构与内存访问
def analyze_memory_pattern(architecture_type, seq_len, d_model):
"""
分析不同架构的内存访问模式
"""
if architecture_type == 'cnn':
# 卷积:局部内存访问
# 内存访问量 ~ O(n * k * d)
memory_access = seq_len * 3 * d_model # 输入、权重、输出
arithmetic_intensity = d_model / 3
elif architecture_type == 'rnn':
# RNN:时间步顺序访问
# 内存访问量 ~ O(n * d)
memory_access = seq_len * 4 * d_model # 输入、隐藏、权重、输出
arithmetic_intensity = d_model / 4
elif architecture_type == 'transformer':
# Transformer:全连接
# 内存访问量 ~ O(n² * d) (注意力部分)
memory_access = seq_len**2 * d_model + seq_len * d_model**2
arithmetic_intensity = seq_len / d_model
return {
'memory_access': memory_access,
'arithmetic_intensity': arithmetic_intensity,
'roofline_efficiency': min(arithmetic_intensity / 100, 1.0) # 假设peak=100
}6.2 稀疏矩阵乘法加速
def sparse_attention_forward(Q, K, V, sparsity_pattern):
"""
利用稀疏性加速注意力计算
sparsity_pattern: 指示哪些位置需要计算的掩码
"""
# 只计算非稀疏位置的注意力
Q_sparse = Q[:, sparsity_pattern]
K_sparse = K[:, sparsity_pattern]
# 批量稀疏矩阵乘法
S_sparse = torch.bmm(Q_sparse, K_sparse.transpose(-2, -1))
# 填充回完整矩阵
S = torch.zeros(Q.shape[0], seq_len, seq_len)
S[sparsity_pattern] = S_sparse
# softmax和加权求和
A = F.softmax(S / math.sqrt(Q.shape[-1]), dim=-1)
O = torch.bmm(A, V)
return O7. 理论洞察
7.1 表示能力边界
CNN:由于Toeplitz矩阵的局部性,CNN的表达能力受限于感受野大小。
RNN:下三角矩阵结构意味着信息必须通过时间步逐步传递,长距离依赖的表示能力受限。
Transformer:全连接注意力理论上可以建模任意位置间的依赖,但计算复杂度为 。
7.2 归纳偏置的矩阵解释
| 架构 | 归纳偏置 | 矩阵结构 | 适用场景 |
|---|---|---|---|
| CNN | 局部性、平移不变性 | Toeplitz | 图像、语音 |
| RNN | 时间序列性、顺序性 | 下三角 | 时间序列、NLP |
| Transformer | 完全连接性 | 全连接 | 通用建模 |
8. 总结
核心要点
- CNN = Toeplitz矩阵乘法:稀疏、局部连接、平移不变性
- RNN = 下三角矩阵乘法:递归连接、时间依赖
- Transformer = 全连接注意力矩阵:动态权重、全局依赖
- 混合架构可以组合不同矩阵结构的优势
统一框架
所有神经网络操作都可以视为:
其中 的结构决定了架构的类型和特性。
参考资料
相关链接:
Footnotes
-
Unified Matrix Framework. (2025). A Unified Matrix Framework for Neural Architectures. arXiv:2506.01966. ↩