矩阵微积分与深度学习
深度学习的核心是优化海量参数以最小化损失函数。理解矩阵微积分是掌握反向传播算法和神经网络训练机制的关键。本章将从矩阵微分基础出发,系统推导神经网络中的梯度计算。
矩阵微分基础
符号约定
在神经网络中,我们使用以下约定:
| 符号 | 含义 | 形状 |
|---|---|---|
| 输入矩阵 | ||
| 权重矩阵 | ||
| 输出矩阵 | ||
| 损失函数 | 标量 |
梯度定义
对于标量函数 ,梯度定义为:
注意:梯度与原矩阵形状相同。
雅可比矩阵
对于向量函数 ,雅可比矩阵为:
链式法则
深度学习中最核心的工具是链式法则。设 ,,则:
其中乘积为矩阵乘法。
神经网络层求导
全连接层
全连接层前向传播:
对输入 求导:
对权重 求导:
对偏置 求导:
其中 沿行维度求和。
逐元素激活函数
设 ,其中 逐元素作用。
其中 为哈达玛积(元素对应乘积)。
Softmax层
对于多分类问题,Softmax输出为:
交叉熵损失下的梯度:
设真实标签为 (one-hot编码),交叉熵损失为 ,则:
这个简洁的结论是深度学习中最重要的公式之一。1
矩阵乘法的梯度汇总
对于损失函数 ,以下矩阵恒等式在反向传播中反复出现:
| 前向传播 | ||
|---|---|---|
其中 是下游梯度。
反向传播算法矩阵形式
两层全连接网络
考虑最简单的两层网络:
反向传播推导:
Step 1:输出层梯度
Step 2:第二层权重梯度
Step 3:传播到隐藏层
Step 4:第一层权重梯度
批量处理的矩阵形式
设批量大小为 ,输入维度为 ,隐藏维度为 ,输出维度为 :
反向传播时,所有样本的梯度可以批量计算,充分利用GPU的并行能力。
PyTorch实现
import torch
import torch.nn as nn
class TwoLayerNet(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 手动验证梯度计算
model = TwoLayerNet(784, 256, 10)
x = torch.randn(32, 784) # batch_size=32
y = torch.randn(32, 10)
output = model(x)
loss = torch.mean((output - y) ** 2) # MSE损失
loss.backward()
# 检查梯度
for name, param in model.named_parameters():
print(f"{name}: grad.shape = {param.grad.shape}")自动微分原理
计算图
自动微分的核心是构建计算图(Computational Graph)。每个节点表示一个变量,每条边表示一个运算。
前向传播:
反向传播:
PyTorch autograd机制
PyTorch使用反向模式自动微分(Reverse Mode AD),适合计算图输入多、输出少的场景(如神经网络的单个损失值)。
import torch
# 创建需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)
# 构建计算图
z = (x * 2 + y ** 2).sum()
z.backward()
# 自动得到梯度
print(x.grad) # dz/dx = 2
print(y.grad) # dz/dy = 2*y = [8, 10, 12]JAX vs PyTorch
| 特性 | PyTorch | JAX |
|---|---|---|
| 梯度函数 | loss.backward() | jax.grad(loss_fn) |
| 函数式 | 命令式 | 函数式 |
| 即时编译 | torch.compile | jax.jit |
| 纯函数 | 不要求 | 严格要求 |
# JAX版本
import jax
import jax.numpy as jnp
def loss_fn(params, x, y):
w1, b1, w2, b2 = params
h = jnp.maximum(0, x @ w1 + b1) # ReLU
return jnp.mean((h @ w2 + b2 - y) ** 2)
# 自动梯度
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params, x, y)雅可比矩阵的PyTorch计算
# 计算完整雅可比矩阵
x = torch.randn(3, requires_grad=True)
y = x ** 2 # 向量输出
# 方法1:逐元素backward
jacobian = torch.zeros(3, 3)
for i in range(3):
y[i].backward(retain_graph=True)
jacobian[i] = x.grad
x.grad.zero_()
# 方法2:使用torch.autograd.functional.jacobian
from torch.autograd.functional import jacobian
jacobian = jacobian(lambda x: x ** 2, x)梯度问题与数值稳定性
梯度消失与爆炸
在深层网络中,梯度是多个雅可比矩阵的连乘:
梯度消失:当 时,梯度指数衰减。
梯度爆炸:当 时,梯度指数增长。
梯度裁剪
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)数值精度问题
# 避免log(0)导致-inf
log_prob = torch.log_softmax(logits, dim=-1)
# 等价于在softmax内部进行数值稳定化处理高级主题:二阶优化
Hessian矩阵
Hessian矩阵是损失函数的二阶导数:
牛顿法更新:
自然梯度
自然梯度考虑了参数空间黎曼几何:
其中 是Fisher信息矩阵。
参考
Footnotes
-
Goodfellow, Bengio, Courville. “Deep Learning”. MIT Press, 2016. Chapter 6. ↩