自动微分数学基础
自动微分(Automatic Differentiation,AD)是现代深度学习系统的核心技术。与数值微分和符号微分不同,自动微分能够以机器精度(machine precision)计算任意程序的导数,同时保持接近最优的计算效率。
1. 微分方法对比
1.1 三种微分方法的本质区别
| 方法 | 原理 | 精度 | 计算复杂度 | 表达能力 |
|---|---|---|---|---|
| 数值微分 | 有限差分近似 | per direction | 任意可求值函数 | |
| 符号微分 | 解析表达式操作 | 精确 | 表达式膨胀 | 闭式表达式 |
| 自动微分 | 链式法则分解 | 精确 | 最优 | 任意程序 |
1.2 数值微分的局限性
前向差分近似:
中心差分近似:
问题:当 过小时,数值误差(舍入误差)主导;当 过大时,截断误差主导。最优 的选择困难,且精度受限于 或 。
1.3 符号微分的局限性
符号微分通过操作数学表达式计算导数:
# 符号微分示例:d/dx(sin(x) * exp(x))
# = cos(x) * exp(x) + sin(x) * exp(x)
# = exp(x) * (cos(x) + sin(x))问题:
- 表达式膨胀(expression swell):简单函数可能产生指数级复杂的导数表达式
- 无法处理分支、循环等程序控制流
- 对含有中间变量的函数效率低下
1.4 自动微分的核心思想
自动微分通过程序分解将任意复合函数拆分为基本运算的组合,然后应用链式法则:
关键特性:
- 计算精度与符号微分相同(机器精度)
- 计算效率与最优算法相同
- 可以处理任意程序结构
2. 数学基础回顾
2.1 雅可比矩阵
对于函数 ,雅可比矩阵 定义为:
其中 ,。
特殊情况:
- 当 时,(梯度向量)
- 当 时,(列向量)
2.2 链式法则
标量形式:
向量形式(雅可比链式法则):
2.3 链式法则的矩阵视角
设 ,,,,复合关系为 ,则:
3. 正向模式自动微分
3.1 对偶数方法
对偶数(Dual Numbers)是实现正向模式自动微分的数学工具:
其中:
- :原始值(primal value)
- :导数(derivative/tangent)
代数运算规则:
| 运算 | 规则 |
|---|---|
| 加法 | |
| 减法 | |
| 乘法 | |
| 除法 | |
| 幂函数 | |
| 指数 | |
| 对数 | |
| 正弦 | |
| 余弦 |
示例:计算 在 处的导数
因此 ,。
3.2 正切传播
正切传播(Tangent Propagation)是对偶数方法的另一种视角。通过扩展跟踪的扰动方向 ,可以计算任意方向的导数。
设输入扰动为 ,输出扰动为 ,则:
3.3 雅可比-向量积
正向模式的核心运算是雅可比-向量积(Jacobian-Vector Product,JVP):
其中 是输入扰动方向。
计算复杂度:
- 朴素的雅可比矩阵计算:
- 雅可比-向量积:
3.4 正向模式效率分析
定理:对于函数 ,正向模式自动微分的计算复杂度为 。
选择准则:
- 当 (多输出函数,少量输入)时,正向模式高效
- 当 (单输出函数,大量输入,如神经网络训练)时,正向模式低效
| 场景 | 推荐模式 | 原因 |
|---|---|---|
| 输入维度 输出维度 | 正向模式 | 一次前向计算即可获得所有导数 |
| 输入维度 输出维度 | 反向模式 | 一次反向计算即可获得所有导数 |
| 输入维度 输出维度 | 任一均可 | 复杂度相近 |
4. 反向模式自动微分
4.1 反向模式核心思想
反向模式(Reverse Mode)自动微分通过两阶段算法计算导数:
- 前向阶段:构建计算图,存储中间变量
- 反向阶段:从输出向输入反向传播梯度
设复合函数 ,定义中间变量:
链式法则展开:
其中 称为伴随变量(adjoint variable)。
4.2 伴随传播公式
对于任意中间变量 ,其伴随变量定义为:
反向传播算法:
- 初始化
- 从输出向输入反向遍历:
4.3 雅可比转置-向量积
反向模式的核心运算是雅可比转置-向量积(Jacobian-Transpose-Vector Product,JTVP):
4.4 反向模式示例
问题:计算 的梯度 。
前向计算:
反向计算:
4.5 反向模式效率分析
定理:对于函数 (如神经网络损失函数),反向模式自动微分的计算复杂度为 。
对比:
| 维度关系 | 正向模式复杂度 | 反向模式复杂度 |
|---|---|---|
| ✅ | ||
| ✅ |
这就是为什么反向模式(反向传播)是神经网络训练的标准方法。
5. 元素基本函数库
自动微分系统将函数分解为一组元素基本函数(Elementary Primitives)的组合。
5.1 基本算术运算
| 函数 | 导数 |
|---|---|
| (常数) | |
5.2 超越函数
| 函数 | 导数 |
|---|---|
5.3 神经网络常用函数
| 函数 | 正向 | 反向 |
|---|---|---|
| Sigmoid | ||
| if else | if else | |
| if else | if else | |
| 详见Softmax导数 |
6. 实现方法分类
6.1 操作符重载(Operator Overloading)
通过重载基本运算的数据类型,在运行时构建计算图。
优点:
- 易于实现,用户透明
- 支持动态计算图
- 可以处理任意Python代码
缺点:
- 需要特殊的微分数据类型(如对偶数)
- 运行时开销
- 某些Python特性难以支持
代表框架:PyTorch Autograd,MxNet Gluon
6.2 源码转换(Source Transformation)
在编译或解释前将微分代码转换为显式梯度代码。
优点:
- 无运行时开销
- 可进行图优化
- 支持任意精度
缺点:
- 实现复杂
- 需要处理完整的语言子集
- 调试困难
代表框架:JAX(函数式变换),TensorFlow XLA(编译优化)
6.3 追踪/雾化(Tracing)
执行函数一次,记录操作序列,然后构建计算图。
优点:
- 实现相对简单
- 可与源码转换结合
缺点:
- 仅记录实际执行的路径
- 可能丢失控制流信息
代表框架:PyTorch torch.jit.trace,JAX jax.make_jaxpr
7. 实践示例
7.1 PyTorch自动微分
import torch
# 创建requires_grad=True的张量
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
# 定义计算
z = x ** 2 + torch.sin(y) + x * y
# 计算梯度
z.sum().backward()
print(f"z = {z}")
print(f"dz/dx = {x.grad}") # ∂z/∂x = 2x + y
print(f"dz/dy = {y.grad}") # ∂z/∂y = cos(y) + x7.2 JAX函数式微分
import jax
import jax.numpy as jnp
def f(x, y):
return jnp.sum(x ** 2 + jnp.sin(y) + x * y)
# 标量梯度
grad_f = jax.grad(f)
x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
grads = grad_f(x, y)
print(f"∂f/∂x = {grads}") # 返回对第一个参数的梯度7.3 高阶导数
import jax.numpy as jnp
import jax
def f(x):
return x ** 3 + 2 * x ** 2
# 一阶导数: 3x² + 4x
grad_f = jax.grad(f)
# 二阶导数: 6x + 4
hess_f = jax.jacfwd(jax.grad(f))
# 或使用 jax.hessian (等价于 jacfwd(jacrev))
hess_f_v2 = jax.hessian(f)
x = 2.0
print(f"f'(x) = {grad_f(x)}") # 20.0
print(f"f''(x) = {hess_f(x)}") # 16.08. 数学正确性保证
8.1 自动微分正确性定理
定理:对于任意由基本初等函数组成的复合函数,自动微分算法计算的导数等于解析导数(假设浮点运算精确)。
证明概要:
- 基本初等函数的导数定义正确
- 链式法则适用于雅可比矩阵乘法
- 归纳法证明整个复合函数的导数正确
8.2 数值稳定性
自动微分使用与前向计算相同的浮点运算,因此:
- 无截断误差(与符号微分相同)
- 存在舍入误差(但与前向计算量级相同)
- 精度受限于浮点表示(通常为 数量级)
8.3 梯度校验
使用数值微分验证自动微分实现的正确性:
def gradient_check(f, x, eps=1e-7):
"""使用中心差分验证梯度"""
grad_numerical = []
grad_autodiff = jax.grad(f)(x)
for i in range(len(x)):
x_plus = x.at[i].add(eps)
x_minus = x.at[i].sub(eps)
grad_numerical.append((f(x_plus) - f(x_minus)) / (2 * eps))
return jnp.allclose(grad_numerical, grad_autodiff, rtol=1e-5)9. 总结
9.1 核心要点
- 自动微分通过链式法则将复合函数分解为基本运算,实现精确高效的导数计算
- 正向模式适用于输入维度小、输出维度大的场景
- 反向模式适用于输入维度大(神经网络参数)、输出维度小(单损失值)的场景
- 对偶数提供了一种优雅的正向模式实现方式
- 自动微分的精度是机器精度,不存在数值微分的截断误差
9.2 后续内容
- 反向模式详解:深入对比两种模式的数学原理
- 计算图表示与执行:计算图的构建与优化
- PyTorch Autograd实现:框架内部机制
- JAX自动微分框架:函数式微分变换