自动微分数学基础

自动微分(Automatic Differentiation,AD)是现代深度学习系统的核心技术。与数值微分和符号微分不同,自动微分能够以机器精度(machine precision)计算任意程序的导数,同时保持接近最优的计算效率。

1. 微分方法对比

1.1 三种微分方法的本质区别

方法原理精度计算复杂度表达能力
数值微分有限差分近似 per direction任意可求值函数
符号微分解析表达式操作精确表达式膨胀闭式表达式
自动微分链式法则分解精确最优任意程序

1.2 数值微分的局限性

前向差分近似:

中心差分近似:

问题:当 过小时,数值误差(舍入误差)主导;当 过大时,截断误差主导。最优 的选择困难,且精度受限于

1.3 符号微分的局限性

符号微分通过操作数学表达式计算导数:

# 符号微分示例:d/dx(sin(x) * exp(x))
# = cos(x) * exp(x) + sin(x) * exp(x)
# = exp(x) * (cos(x) + sin(x))

问题

  • 表达式膨胀(expression swell):简单函数可能产生指数级复杂的导数表达式
  • 无法处理分支、循环等程序控制流
  • 对含有中间变量的函数效率低下

1.4 自动微分的核心思想

自动微分通过程序分解将任意复合函数拆分为基本运算的组合,然后应用链式法则:

关键特性

  • 计算精度与符号微分相同(机器精度)
  • 计算效率与最优算法相同
  • 可以处理任意程序结构

2. 数学基础回顾

2.1 雅可比矩阵

对于函数 ,雅可比矩阵 定义为:

其中

特殊情况

  • 时,(梯度向量)
  • 时,(列向量)

2.2 链式法则

标量形式

向量形式(雅可比链式法则):

2.3 链式法则的矩阵视角

,复合关系为 ,则:

3. 正向模式自动微分

3.1 对偶数方法

对偶数(Dual Numbers)是实现正向模式自动微分的数学工具:

其中:

  • :原始值(primal value)
  • :导数(derivative/tangent)

代数运算规则

运算规则
加法
减法
乘法
除法
幂函数
指数
对数
正弦
余弦

示例:计算 处的导数

因此

3.2 正切传播

正切传播(Tangent Propagation)是对偶数方法的另一种视角。通过扩展跟踪的扰动方向 ,可以计算任意方向的导数。

设输入扰动为 ,输出扰动为 ,则:

3.3 雅可比-向量积

正向模式的核心运算是雅可比-向量积(Jacobian-Vector Product,JVP):

其中 是输入扰动方向。

计算复杂度

  • 朴素的雅可比矩阵计算:
  • 雅可比-向量积:

3.4 正向模式效率分析

定理:对于函数 ,正向模式自动微分的计算复杂度为

选择准则

  • (多输出函数,少量输入)时,正向模式高效
  • (单输出函数,大量输入,如神经网络训练)时,正向模式低效
场景推荐模式原因
输入维度 输出维度正向模式一次前向计算即可获得所有导数
输入维度 输出维度反向模式一次反向计算即可获得所有导数
输入维度 输出维度任一均可复杂度相近

4. 反向模式自动微分

4.1 反向模式核心思想

反向模式(Reverse Mode)自动微分通过两阶段算法计算导数:

  1. 前向阶段:构建计算图,存储中间变量
  2. 反向阶段:从输出向输入反向传播梯度

设复合函数 ,定义中间变量:

链式法则展开:

其中 称为伴随变量(adjoint variable)。

4.2 伴随传播公式

对于任意中间变量 ,其伴随变量定义为:

反向传播算法

  1. 初始化
  2. 从输出向输入反向遍历:

4.3 雅可比转置-向量积

反向模式的核心运算是雅可比转置-向量积(Jacobian-Transpose-Vector Product,JTVP):

4.4 反向模式示例

问题:计算 的梯度

前向计算

反向计算




4.5 反向模式效率分析

定理:对于函数 (如神经网络损失函数),反向模式自动微分的计算复杂度为

对比

维度关系正向模式复杂度反向模式复杂度

这就是为什么反向模式(反向传播)是神经网络训练的标准方法。

5. 元素基本函数库

自动微分系统将函数分解为一组元素基本函数(Elementary Primitives)的组合。

5.1 基本算术运算

函数导数
(常数)

5.2 超越函数

函数导数

5.3 神经网络常用函数

函数正向反向
Sigmoid
if else if else
if else if else
详见Softmax导数

6. 实现方法分类

6.1 操作符重载(Operator Overloading)

通过重载基本运算的数据类型,在运行时构建计算图。

优点

  • 易于实现,用户透明
  • 支持动态计算图
  • 可以处理任意Python代码

缺点

  • 需要特殊的微分数据类型(如对偶数)
  • 运行时开销
  • 某些Python特性难以支持

代表框架:PyTorch Autograd,MxNet Gluon

6.2 源码转换(Source Transformation)

在编译或解释前将微分代码转换为显式梯度代码。

优点

  • 无运行时开销
  • 可进行图优化
  • 支持任意精度

缺点

  • 实现复杂
  • 需要处理完整的语言子集
  • 调试困难

代表框架:JAX(函数式变换),TensorFlow XLA(编译优化)

6.3 追踪/雾化(Tracing)

执行函数一次,记录操作序列,然后构建计算图。

优点

  • 实现相对简单
  • 可与源码转换结合

缺点

  • 仅记录实际执行的路径
  • 可能丢失控制流信息

代表框架:PyTorch torch.jit.trace,JAX jax.make_jaxpr

7. 实践示例

7.1 PyTorch自动微分

import torch
 
# 创建requires_grad=True的张量
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
 
# 定义计算
z = x ** 2 + torch.sin(y) + x * y
 
# 计算梯度
z.sum().backward()
 
print(f"z = {z}")
print(f"dz/dx = {x.grad}")  # ∂z/∂x = 2x + y
print(f"dz/dy = {y.grad}")  # ∂z/∂y = cos(y) + x

7.2 JAX函数式微分

import jax
import jax.numpy as jnp
 
def f(x, y):
    return jnp.sum(x ** 2 + jnp.sin(y) + x * y)
 
# 标量梯度
grad_f = jax.grad(f)
 
x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
 
grads = grad_f(x, y)
print(f"∂f/∂x = {grads}")  # 返回对第一个参数的梯度

7.3 高阶导数

import jax.numpy as jnp
import jax
 
def f(x):
    return x ** 3 + 2 * x ** 2
 
# 一阶导数: 3x² + 4x
grad_f = jax.grad(f)
 
# 二阶导数: 6x + 4
hess_f = jax.jacfwd(jax.grad(f))
 
# 或使用 jax.hessian (等价于 jacfwd(jacrev))
hess_f_v2 = jax.hessian(f)
 
x = 2.0
print(f"f'(x) = {grad_f(x)}")  # 20.0
print(f"f''(x) = {hess_f(x)}") # 16.0

8. 数学正确性保证

8.1 自动微分正确性定理

定理:对于任意由基本初等函数组成的复合函数,自动微分算法计算的导数等于解析导数(假设浮点运算精确)。

证明概要

  1. 基本初等函数的导数定义正确
  2. 链式法则适用于雅可比矩阵乘法
  3. 归纳法证明整个复合函数的导数正确

8.2 数值稳定性

自动微分使用与前向计算相同的浮点运算,因此:

  • 无截断误差(与符号微分相同)
  • 存在舍入误差(但与前向计算量级相同)
  • 精度受限于浮点表示(通常为 数量级)

8.3 梯度校验

使用数值微分验证自动微分实现的正确性:

def gradient_check(f, x, eps=1e-7):
    """使用中心差分验证梯度"""
    grad_numerical = []
    grad_autodiff = jax.grad(f)(x)
    
    for i in range(len(x)):
        x_plus = x.at[i].add(eps)
        x_minus = x.at[i].sub(eps)
        grad_numerical.append((f(x_plus) - f(x_minus)) / (2 * eps))
    
    return jnp.allclose(grad_numerical, grad_autodiff, rtol=1e-5)

9. 总结

9.1 核心要点

  1. 自动微分通过链式法则将复合函数分解为基本运算,实现精确高效的导数计算
  2. 正向模式适用于输入维度小、输出维度大的场景
  3. 反向模式适用于输入维度大(神经网络参数)、输出维度小(单损失值)的场景
  4. 对偶数提供了一种优雅的正向模式实现方式
  5. 自动微分的精度是机器精度,不存在数值微分的截断误差

9.2 后续内容

参考资料