神经网络架构的统一矩阵框架

引言

深度学习领域涌现了大量专用架构:CNN处理图像、RNN处理序列、Transformer处理长距离依赖。从线性代数的视角,这些看似不同的架构实际上都是矩阵乘法的不同变体,它们之间的差异主要体现在矩阵的结构特性上。

核心观点:所有主流神经网络架构都可以用稀疏矩阵乘法来统一描述。卷积是Toeplitz矩阵乘法,循环是下三角矩阵乘法,注意力是全连接矩阵乘法。1


1. 从卷积到矩阵乘法

1.1 卷积的矩阵表示

二维卷积可以通过Toeplitz矩阵(块Toeplitz矩阵)表示为矩阵乘法:

对于输入 和卷积核

其中 是一个块Toeplitz矩阵 的向量化形式。

1.2 im2col:卷积到矩阵乘法的转换

深度学习框架(如cuDNN)将卷积转换为矩阵乘法来利用GPU的矩阵运算优化:

def im2col(X, kernel_size, stride, padding):
    """
    im2col:将卷积转换为矩阵乘法
    
    思想:将每个滑动窗口展开为矩阵的一列
    """
    B, C, H, W = X.shape
    
    # 填充
    if padding > 0:
        X_padded = F.pad(X, (padding, padding, padding, padding))
    else:
        X_padded = X
    
    # 计算输出尺寸
    H_out = (H + 2*padding - kernel_size) // stride + 1
    W_out = (W + 2*padding - kernel_size) // stride + 1
    
    # 提取滑动窗口
    cols = F.unfold(X_padded, kernel_size, stride=stride)
    cols = cols.transpose(1, 2)  # (B, H_out*W_out, C*k*k)
    
    return cols, H_out, W_out
 
 
def conv_as_matmul(X, weight, stride=1, padding=1):
    """
    将卷积实现为矩阵乘法
    """
    B, C, H, W = X.shape
    k, _, C_out, _ = weight.shape
    
    # im2col转换
    cols, H_out, W_out = im2col(X, k, stride, padding)
    
    # 展平卷积核
    weight_flat = weight.view(C_out, -1).T  # (C*k*k, C_out)
    
    # 矩阵乘法
    out = cols @ weight_flat  # (B, H_out*W_out, C_out)
    
    # 重塑为图像格式
    out = out.view(B, H_out, W_out, C_out).permute(0, 3, 1, 2)
    
    return out

1.3 Toeplitz矩阵结构

def create_toeplitz_matrix(kernel, input_size):
    """
    创建Toeplitz矩阵
    
    每一行是卷积核的循环移位
    """
    k = len(kernel)
    n = input_size
    
    # 一维卷积的Toeplitz矩阵
    T = torch.zeros(n, n)
    
    for i in range(n):
        for j in range(k):
            if 0 <= i - j < n:
                T[i, i-j] = kernel[j]
    
    return T
 
 
# 二维情况:块Toeplitz矩阵
def create_block_toeplitz_2d(kernel_2d, H, W):
    """
    创建二维Toeplitz矩阵(简化示意)
    """
    k_h, k_w = kernel_2d.shape
    n = H * W
    
    T = torch.zeros(n, n)
    
    for i in range(H):
        for j in range(W):
            for di in range(k_h):
                for dj in range(k_w):
                    ni = i * W + j
                    nj = (i - di) * W + (j - dj)
                    if 0 <= nj < n and 0 <= (i - di) < H:
                        T[ni, nj] = kernel_2d[di, dj]
    
    return T

2. 循环神经网络与下三角矩阵

2.1 RNN的矩阵形式

标准RNN的前向传播可以表示为:

将所有时间步堆叠:

这是一个块下三角矩阵

def rnn_as_matrix(X, W_hh, W_xh, b, h0=None):
    """
    将RNN表示为矩阵乘法
    """
    T, B, D_in = X.shape
    D_h = W_hh.shape[0]
    
    # 初始化
    if h0 is None:
        h = torch.zeros(B, D_h)
    
    outputs = []
    
    for t in range(T):
        h = torch.tanh(X[t] @ W_xh.T + h @ W_hh.T + b)
        outputs.append(h)
    
    return torch.stack(outputs)
 
 
def rnn_matrix_form(X, W_hh, W_xh, b):
    """
    RNN的显式矩阵形式
    
    H = T @ X
    其中 T 是块下三角矩阵
    """
    T, B, D_in = X.shape
    D_h = W_hh.shape[0]
    
    # 构建块下三角矩阵(稀疏形式)
    # 实际计算时使用循环而非显式构建
    
    # 时间步展开
    H_cols = []
    for t in range(T):
        # 从X的前t个时间步和历史hidden state计算
        h_t = torch.zeros(B, D_h)
        for s in range(t + 1):
            # 贡献 = W_xh @ x_s
            h_t = h_t + X[s] @ W_xh.T
        H_cols.append(h_t)
    
    return torch.stack(H_cols, dim=1)

2.2 LSTM的门控机制

LSTM引入了门控机制,其矩阵表示更为复杂:

2.3 矩阵视角的洞察

def analyze_rnn_matrix_properties(W_hh, W_xh):
    """
    分析RNN权重矩阵的性质
    """
    # 谱半径
    eigvals = torch.linalg.eigvals(W_hh)
    spectral_radius = torch.max(torch.abs(eigvals)).item()
    
    # 稳定性判据:谱半径 < 1 时,RNN动态系统稳定
    is_stable = spectral_radius < 1.0
    
    return {
        'spectral_radius': spectral_radius,
        'is_stable': is_stable,
        'eigenvalue_distribution': eigvals.abs().sort().values
    }

3. Transformer与注意力矩阵

3.1 注意力的矩阵表示

标准Transformer注意力的计算:

从矩阵角度,这可以视为:

  1. 查询-键匹配(对称矩阵)
  2. 归一化(行随机矩阵)
  3. 值聚合(加权聚合)
def attention_matrix_form(Q, K, V):
    """
    注意力的矩阵形式
    """
    # 相似度矩阵
    S = Q @ K.transpose(-2, -1)
    
    # 归一化(softmax)
    A = F.softmax(S / math.sqrt(Q.shape[-1]), dim=-1)
    
    # 加权聚合
    O = A @ V
    
    return O, A, S

3.2 注意力矩阵的性质

def analyze_attention_matrix(attn_weights):
    """
    分析注意力矩阵的数学性质
    """
    # 假设 attn_weights: (seq, seq)
    
    # 1. 行随机性:每行和为1
    row_sums = attn_weights.sum(dim=-1)
    is_row_stochastic = torch.allclose(row_sums, torch.ones_like(row_sums))
    
    # 2. 谱性质:最大特征值 = 1
    eigvals = torch.linalg.eigvalsh(attn_weights)
    lambda_max = eigvals[-1].item()
    
    # 3. 熵:衡量注意力的"锐利"程度
    eps = 1e-10
    entropy = -(attn_weights * torch.log(attn_weights + eps)).sum(dim=-1).mean()
    
    # 4. 有效秩
    U, S, _ = torch.linalg.svd(attn_weights, full_matrices=False)
    S_norm = S / S.sum()
    effective_rank = torch.exp(-(S_norm * torch.log(S_norm + eps)).sum())
    
    return {
        'is_row_stochastic': is_row_stochastic,
        'lambda_max': lambda_max,
        'entropy': entropy.item(),
        'effective_rank': effective_rank.item()
    }

4. 统一框架:矩阵视角下的三大架构

4.1 架构对比

架构矩阵类型连接模式复杂度
CNNToeplitz/块Toeplitz局部连接
RNN块下三角时间递归
Transformer全连接(动态加权)全局注意力

4.2 统一表示

定义一个统一的操作形式:

其中 是掩码矩阵,不同架构对应不同的

class UnifiedArchitecture(nn.Module):
    """
    统一架构框架
    """
    def __init__(self, architecture_type, d_model, d_hidden=None):
        super().__init__()
        self.architecture_type = architecture_type
        self.d_model = d_model
        
        if architecture_type == 'cnn':
            # 卷积核大小
            self.kernel_size = d_hidden or 3
            self.conv = nn.Conv1d(d_model, d_model, self.kernel_size, padding=self.kernel_size//2)
        elif architecture_type == 'rnn':
            self.rnn = nn.RNN(d_model, d_model, batch_first=True)
        elif architecture_type == 'transformer':
            self.attn = nn.MultiheadAttention(d_model, 8, batch_first=True)
        
        self.proj = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        if self.architecture_type == 'cnn':
            # x: (batch, seq, d_model)
            x = x.transpose(1, 2)  # (batch, d_model, seq)
            x = self.conv(x)
            x = x.transpose(1, 2)  # (batch, seq, d_model)
        
        elif self.architecture_type == 'rnn':
            x, _ = self.rnn(x)
        
        elif self.architecture_type == 'transformer':
            # 自注意力
            attn_out, _ = self.attn(x, x, x, attn_mask=mask)
            x = attn_out
        
        return self.proj(x)

4.3 从稀疏到全连接

不同架构可以被视为不同稀疏程度的注意力

稀疏程度:  低 ────────────────────────────────────> 高
            │                                        │
CNN:        [1,1,1,0,0,0,0,0]                      局部连接
            [0,1,1,1,0,0,0,0]                      Toeplitz结构
            [0,0,1,1,1,0,0,0]
            
RNN:        [1,0,0,0,0,0,0,0]                      递归连接
            [1,1,0,0,0,0,0,0]                      下三角结构
            [1,1,1,0,0,0,0,0]
            [1,1,1,1,0,0,0,0]
            
Transformer:[1,1,1,1,1,1,1,1]                      全连接
            [1,1,1,1,1,1,1,1]                      (动态权重)
            [1,1,1,1,1,1,1,1]

5. 混合架构的矩阵视角

5.1 Hybrid CNN-Transformer

class HybridCNNTransformer(nn.Module):
    """
    CNN + Transformer 混合架构
    
    CNN处理局部特征 → Transformer建模全局依赖
    """
    def __init__(self, d_model, n_heads, cnn_kernel=3):
        super().__init__()
        
        # 局部特征提取
        self.conv = nn.Sequential(
            nn.Conv1d(d_model, d_model, cnn_kernel, padding=cnn_kernel//2),
            nn.GELU()
        )
        
        # 全局建模
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True),
            num_layers=6
        )
        
        # 维度对齐
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x):
        # CNN: 局部特征
        x_local = x.transpose(1, 2)  # (B, D, N)
        x_local = self.conv(x_local)
        x_local = x_local.transpose(1, 2)  # (B, N, D)
        
        # Transformer: 全局依赖
        x_global = self.transformer(x_local)
        
        # 残差连接
        return self.norm(x + x_global)

5.2 Hybrid RNN-Attention

class HybridRNNAttention(nn.Module):
    """
    RNN + Attention 混合架构
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.rnn = nn.LSTM(d_model, d_model, batch_first=True, bidirectional=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x1, x2):
        # x1, x2: 两个不同的输入序列
        
        # RNN: 编码序列1
        h = self.rnn(x1)[0]
        
        # Cross Attention: 用x2的内容指导x1
        out, _ = self.cross_attn(h, x2, x2)
        
        return self.norm(h + out)

6. 硬件感知的矩阵优化

6.1 矩阵结构与内存访问

def analyze_memory_pattern(architecture_type, seq_len, d_model):
    """
    分析不同架构的内存访问模式
    """
    if architecture_type == 'cnn':
        # 卷积:局部内存访问
        # 内存访问量 ~ O(n * k * d)
        memory_access = seq_len * 3 * d_model  # 输入、权重、输出
        arithmetic_intensity = d_model / 3
    
    elif architecture_type == 'rnn':
        # RNN:时间步顺序访问
        # 内存访问量 ~ O(n * d)
        memory_access = seq_len * 4 * d_model  # 输入、隐藏、权重、输出
        arithmetic_intensity = d_model / 4
    
    elif architecture_type == 'transformer':
        # Transformer:全连接
        # 内存访问量 ~ O(n² * d) (注意力部分)
        memory_access = seq_len**2 * d_model + seq_len * d_model**2
        arithmetic_intensity = seq_len / d_model
    
    return {
        'memory_access': memory_access,
        'arithmetic_intensity': arithmetic_intensity,
        'roofline_efficiency': min(arithmetic_intensity / 100, 1.0)  # 假设peak=100
    }

6.2 稀疏矩阵乘法加速

def sparse_attention_forward(Q, K, V, sparsity_pattern):
    """
    利用稀疏性加速注意力计算
    
    sparsity_pattern: 指示哪些位置需要计算的掩码
    """
    # 只计算非稀疏位置的注意力
    Q_sparse = Q[:, sparsity_pattern]
    K_sparse = K[:, sparsity_pattern]
    
    # 批量稀疏矩阵乘法
    S_sparse = torch.bmm(Q_sparse, K_sparse.transpose(-2, -1))
    
    # 填充回完整矩阵
    S = torch.zeros(Q.shape[0], seq_len, seq_len)
    S[sparsity_pattern] = S_sparse
    
    # softmax和加权求和
    A = F.softmax(S / math.sqrt(Q.shape[-1]), dim=-1)
    O = torch.bmm(A, V)
    
    return O

7. 理论洞察

7.1 表示能力边界

CNN:由于Toeplitz矩阵的局部性,CNN的表达能力受限于感受野大小。

RNN:下三角矩阵结构意味着信息必须通过时间步逐步传递,长距离依赖的表示能力受限。

Transformer:全连接注意力理论上可以建模任意位置间的依赖,但计算复杂度为

7.2 归纳偏置的矩阵解释

架构归纳偏置矩阵结构适用场景
CNN局部性、平移不变性Toeplitz图像、语音
RNN时间序列性、顺序性下三角时间序列、NLP
Transformer完全连接性全连接通用建模

8. 总结

核心要点

  1. CNN = Toeplitz矩阵乘法:稀疏、局部连接、平移不变性
  2. RNN = 下三角矩阵乘法:递归连接、时间依赖
  3. Transformer = 全连接注意力矩阵:动态权重、全局依赖
  4. 混合架构可以组合不同矩阵结构的优势

统一框架

所有神经网络操作都可以视为:

其中 的结构决定了架构的类型和特性。


参考资料


相关链接

Footnotes

  1. Unified Matrix Framework. (2025). A Unified Matrix Framework for Neural Architectures. arXiv:2506.01966.