计算图与自动微分

自动微分(Automatic Differentiation, AD)是现代深度学习框架的核心技术,使得梯度计算变得自动化。本文档介绍计算图的概念和主流框架的实现方式。

计算图基础

什么是计算图

计算图(Computational Graph)是一种表示数学表达式的方法:

  • 节点(Node):表示变量(标量、向量、矩阵)或操作
  • 边(Edge):表示数据依赖关系

示例:

        x ──┬───────────────────┐
            │                   │
            ├──────> + ──────────┼───────> * ────> f
            │                   │
        y ──┘                   │
                                │
        x ──────────────────────┴───> sin ────┘

计算图与梯度

计算图使得我们可以系统地应用链式法则,自动计算任意复杂表达式的梯度。

自动微分的数学原理

四种微分模式对比

模式原理复杂度适用场景
数值微分调试、简单函数
符号微分解析表达式推导表达式膨胀闭式推导
前向模式AD沿输入方向传播少量输入,多输出
反向模式AD沿输出方向传播大量输入,少量输出

反向模式AD(反向传播的核心)

对于函数 (如神经网络损失函数),反向模式只需 次前向传播 + 次反向传播即可得到全部梯度的雅可比矩阵。

设计算图节点 的值为 ,我们有:

反向传播算法

  1. 从输出节点开始,初始化
  2. 按拓扑逆序遍历节点
  3. 对每个节点应用:

示例推导

计算 的梯度。

前向传播

反向传播

验证:

PyTorch autograd机制

核心概念

PyTorch使用动态计算图(Dynamic Computational Graph):

  • 图在每次前向传播时动态构建
  • torch.Tensor支持requires_grad属性追踪计算历史
  • backward()触发反向传播

基本使用

import torch
 
# 创建需要梯度的张量
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
 
# 构建计算图
z = x * y + x.sum()  # z = x*y + sum(x)
print(f"z = {z}")  # z = tensor([4., 10.], grad_fn=<AddBackward0>)
 
# 反向传播
z.sum().backward()  # 对标量调用,或使用 z.sum().backward()
 
# 查看梯度
print(f"x.grad = {x.grad}")  # x.grad = tensor([4., 5.])
print(f"y.grad = {y.grad}")  # y.grad = tensor([1., 2.])

计算图结构

# 查看计算图的详细信息
print(z.grad_fn)  # <AddBackward0 at 0x...>
print(z.grad_fn.next_functions)  # 子节点的grad_fn
 
# Function类
print(x.grad_fn)  # None (叶子节点)

非标量张量的反向传播

对于非标量输出,需要指定梯度初始值:

# z是向量(2,),不是标量
z = x * y  # z = [x0*y0, x1*y1]
 
# 需要传入相同形状的梯度
z.backward(torch.ones_like(z))  # 或 z.sum().backward()
 
# 或者使用Jacobian手动计算
from torch.autograd.functional import jacobian
J = jacobian(lambda x: x * y, x)

停止梯度追踪

# 方法1: detach()
z = x * y
z_detached = z.detach()
 
# 方法2: with torch.no_grad():
with torch.no_grad():
    z_no_grad = x * y
 
# 方法3: .requires_grad_(False)
z.requires_grad_(False)

高阶导数

# 计算二阶导数
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x ** 3
 
# 一阶导数:dy/dx = 3*x^2
grad1 = torch.autograd.grad(y.sum(), x, create_graph=True)
print(f"一阶导数: {grad1}")  # tensor([12., 27.])
 
# 二阶导数:d²y/dx² = 6*x
grad2 = torch.autograd.grad(grad1[0].sum(), x)
print(f"二阶导数: {grad2}")  # tensor([12., 18.])

训练循环中的autograd

import torch.nn as nn
 
model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 1)
)
 
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
 
for epoch in range(100):
    # 前向传播(动态构建计算图)
    X_batch, y_batch = get_batch()
    predictions = model(X_batch)
    loss = criterion(predictions, y_batch)
    
    # 反向传播
    optimizer.zero_grad()  # 清零梯度
    loss.backward()         # 计算梯度
    
    # 参数更新
    optimizer.step()        # 更新参数

JAX的自动微分

JAX vs PyTorch

特性PyTorchJAX
计算图动态惰性(Lazy)
编译可选(torch.compile)默认JIT
纯函数
不可变性

核心API

import jax
import jax.numpy as jnp
 
# 定义函数(必须纯函数,无副作用)
def sum_squares(x):
    return jnp.sum(x ** 2)
 
# 一阶导数
grad_sum_squares = jax.grad(sum_squares)
 
x = jnp.array([1.0, 2.0, 3.0])
print(f"函数值: {sum_squares(x)}")       # 14.0
print(f"梯度: {grad_sum_squares(x)}")    # [2., 4., 6.]

value_and_grad

# 同时返回函数值和梯度
def loss_fn(params):
    W, b = params
    return jnp.sum((W @ x + b - y) ** 2)
 
# 返回 (loss, grads)
loss, grads = jax.value_and_grad(loss_fn)((W, b))

JIT编译与微分

# JIT编译的函数也可以求导
@jax.jit
def forward(x, params):
    W, b = params
    return jnp.tanh(W @ x + b)
 
# 编译后的函数求导
grad_forward = jax.jit(jax.grad(forward))
 
# 或直接组合
grad_forward_jit = jax.jit(jax.grad(forward))

Jacobian和Hessian

# Jacobian: 多输出函数的梯度
J = jax.jacfwd(sum_squares)(x)
print(J)  # [2., 4., 6.]
 
# Hessian: 二阶导数矩阵
H = jax.hessian(sum_squares)(x)
# [[2., 0., 0.],
#  [0., 4., 0.],
#  [0., 0., 6.]]

Pytrees:复杂参数结构

# JAX处理嵌套参数结构
params = {
    'linear1': {'W': W1, 'b': b1},
    'linear2': {'W': W2, 'b': b2}
}
 
# grads也是相同的结构
grads = jax.grad(loss_fn)(params)
 
# 更新参数
def sgd_update(params, grads, lr):
    return jax.tree_util.map(
        lambda p, g: p - lr * g,
        params, grads
    )

计算图实现:手动构建

反向模式AD的完整实现

import numpy as np
from typing import List, Callable
 
class Value:
    """计算图中的节点"""
    def __init__(self, data, _children=(), _op=''):
        self.data = np.array(data)
        self.grad = np.zeros_like(self.data, dtype=float)
        self._backward = lambda: None
        self._prev = set(_children)
        self._op = _op
    
    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')
        
        def _backward():
            self.grad += out.grad * 1
            other.grad += out.grad * 1
        out._backward = _backward
        return out
    
    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other), '*')
        
        def _backward():
            self.grad += out.grad * other.data
            other.grad += out.grad * self.data
        out._backward = _backward
        return out
    
    def tanh(self):
        out = Value(np.tanh(self.data), (self,), 'tanh')
        
        def _backward():
            self.grad += out.grad * (1 - np.tanh(self.data)**2)
        out._backward = _backward
        return out
    
    def backward(self):
        # 拓扑排序
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        
        # 反向传播
        self.grad = np.ones_like(self.data)
        for node in reversed(topo):
            node._backward()
 
# 使用示例
x = Value([1.0, 2.0])
w = Value([0.5, -0.5])
b = Value([0.1])
 
z = x * w + b
a = z.tanh()
a.backward()
 
print(f"a = {a.data}")
print(f"x.grad = {x.grad}")

性能对比

计算复杂度

框架前向反向
数值微分
符号微分表达式膨胀表达式膨胀
反向模式AD

PyTorch vs JAX性能

# PyTorch
import torch
import time
 
x = torch.randn(1000, 1000, requires_grad=True)
start = time.time()
for _ in range(100):
    y = x @ x.T
    y.sum().backward()
    x.grad.zero_()
print(f"PyTorch: {time.time()-start:.2f}s")
 
# JAX
import jax
import jax.numpy as jnp
 
x = jnp.array(np.random.randn(1000, 1000))
loss_fn = lambda x: jnp.sum(x @ x.T)
grad_fn = jax.grad(loss_fn)
start = time.time()
for _ in range(100):
    grads = grad_fn(x)
print(f"JAX: {time.time()-start:.2f}s")

参考