JAX自动微分框架

JAX是Google开发的下一代高性能数值计算框架,其核心设计基于函数式编程范式组合式变换(Composable Transformations)。本专题深入分析JAX的自动微分系统。

1. JAX设计理念

1.1 函数式核心

JAX的核心理念是将所有计算表示为纯函数

import jax.numpy as jnp
from jax import grad
 
# 纯函数:相同输入产生相同输出,无副作用
def loss(params, x, y):
    pred = jnp.dot(x, params['w']) + params['b']
    return jnp.mean((pred - y) ** 2)
 
# grad创建梯度函数
grad_loss = grad(loss)
 
# 相同输入产生相同梯度
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.0)}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
 
grads = grad_loss(params, x, y)

1.2 变换体系架构

┌─────────────────────────────────────────────────────┐
│                    JAX Transformations               │
├─────────────────────────────────────────────────────┤
│                                                     │
│   jit ──► 编译加速(XLA)                          │
│     │                                               │
│     ├── grad ──► 自动微分                           │
│     │        │                                      │
│     │        └── jacfwd ──► 正向模式雅可比          │
│     │        │                                      │
│     │        └── jacrev ──► 反向模式雅可比          │
│     │        │                                      │
│     │        └── hessian ──► Hessian矩阵            │
│     │                                                │
│     ├── vmap ──► 自动向量化                         │
│     │                                                │
│     └── pmap ──► 数据并行                           │
│                                                     │
└─────────────────────────────────────────────────────┘

1.3 变换组合

JAX的强大之处在于变换可以自由组合:

from jax import grad, jit, vmap
 
# 组合:编译 + 向量化 + 梯度
@jit
def batch_grad_loss(params, X, Y):
    return vmap(grad(loss), in_axes=(None, 0, 0))(params, X, Y)
 
# 对整个batch同时计算梯度
grads = batch_grad_loss(params, X_batch, Y_batch)

2. 核心微分变换

2.1 grad:标量梯度

from jax import grad
 
def f(x):
    return jnp.sum(x ** 2)
 
# 计算 ∇f(x)
grad_f = grad(f)
 
x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))  # DeviceArray([2., 4., 6.])
 
# 数学上:df/dx_i = 2x_i

默认行为

  • 返回第一个参数(按位置)的梯度
  • 只适用于标量输出函数

2.2 多参数微分

def f(w, b, x, y):
    pred = jnp.dot(x, w) + b
    return jnp.mean((pred - y) ** 2)
 
# 只对w求梯度
grad_w = grad(f, argnums=0)
 
# 对w和b同时求梯度
grad_wb = grad(f, argnums=(0, 1))
 
# 对所有参数求梯度
grad_all = grad(f, argnums=(0, 1, 2, 3))

2.3 value_and_grad

from jax import value_and_grad
 
# 同时返回函数值和梯度
val_and_grad = value_and_grad(f)
 
loss_value, grads = val_and_grad(w, b, x, y)
print(f"loss: {loss_value}, grad_w norm: {jnp.linalg.norm(grads['w'])}")

3. 雅可比矩阵计算

3.1 jacfwd:正向模式雅可比

from jax import jacfwd
 
def f(x):
    return jnp.array([x[0] ** 2, jnp.sin(x[1]), x[0] * x[1]])
 
# 计算雅可比矩阵(每行是输出对输入的梯度)
J = jacfwd(f)(jnp.array([1.0, jnp.pi / 2]))
 
print(J)
# DeviceArray([[2.        , 0.        ],
#              [0.        , -1.        ],
#              [1.57079633, 1.        ]], dtype=float32)

数学形式

3.2 jacrev:反向模式雅可比

from jax import jacrev
 
# 反向模式计算雅可比
J = jacrev(f)(jnp.array([1.0, jnp.pi / 2]))

3.3 模式选择

特性jacfwd(正向)jacrev(反向)
计算方向沿输入维度沿输出维度
适用场景输出多、输入少输入多、输出少
内存效率适合长计算图适合宽计算图
典型应用神经网络层雅可比损失函数梯度
# JAX自动选择最优模式
# 当输出维度 > 输入维度时,使用jacfwd
# 当输入维度 > 输出维度时,使用jacrev
 
# 也可以显式指定
from jax import jacobian  # 自动选择
from jax import jacfwd as jacobian_forward
from jax import jacrev as jacobian_reverse

4. Hessian矩阵计算

4.1 基本用法

from jax import grad, jacfwd, jacrev
 
def f(x):
    return jnp.sum(x ** 3)
 
# 方法1:嵌套grad
grad_grad_f = grad(grad(f))
H1 = jacfwd(grad(f))(x)
 
# 方法2:使用hessian(推荐)
from jax import hessian
H = hessian(f)(x)
 
print(H)
# 二阶导数:d²/dx_i dx_j (x³) = 6x_i * δ_ij
# 对角元素:6x_i
# 非对角元素:0

4.2 Hessian-向量积

from jax import jvp
 
def hvp(f, x, v):
    """Hessian-向量积:H(x) @ v"""
    return jvp(grad(f), (x,), (v,))[1]
 
# 更高效:不需要显式计算完整Hessian
x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])
hvp_val = hvp(lambda x: jnp.sum(x**3), x, v)
print(hvp_val)  # DeviceArray([6., 0., 0.], dtype=float32)

5. JAXPR:JAX的程序表示

5.1 什么是JAXPR

JAXPR( JAX Program Representation)是JAX的中间表示,描述计算图:

from jax import make_jaxpr
 
def f(x, y):
    return x * y + jnp.sum(x)
 
# 生成JAXPR
jaxpr = make_jaxpr(f)(jnp.array([1., 2.]), jnp.array([3., 4.]))
print(jaxpr)

输出:

{ lambda ; a:f32[2] b:f32[2].
  let c:f32[2] = mul a b
  let d:f32 = reduce_sum[ axes=(0,) ] c
  let e:f32[2] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(2,) ] d
  let f:f32[2] = add c e
  let g:f32[2] = add f a
  result g
}

5.2 JAXPR结构

JAXPR
├── literals: []              # 常量
├── in_shapes: [...]          # 输入形状
├── out_shapes: [...]         # 输出形状
└── eqns: [                   # 方程列表
    {
        eqn_name: "mul",
        operands: [a, b],
        primitive: "mul",
        output: c
    },
    ...
]

5.3 追踪与雾化

追踪(Traced)

  • 实际执行时,用具体值替换参数
  • JAXPR记录执行的操作

雾化(Opaque)

  • jit编译时,输入被”雾化”(用抽象值替换)
  • 只保留操作的结构信息
@jit
def f(x):
    return x ** 2 + jnp.sin(x)
 
# 即使在jit内,JAX也追踪计算
# 可以通过make_jaxpr查看
print(make_jaxpr(f)(jnp.array(1.0)))

6. 高级微分功能

6.1 条件微分

from jax import lax
 
def f(x, cond):
    return lax.cond(
        cond,
        lambda _: x ** 2,      # true branch
        lambda _: x ** 3,      # false branch
        operand=None
    )
 
# JAX支持对条件分支进行微分
grad_f = grad(f)
x = 2.0
print(grad_f(x, True))   # 4.0 (d/dx x² = 2x = 4)
print(grad_f(x, False))  # 12.0 (d/dx x³ = 3x² = 12)

6.2 循环微分

from jax import lax
 
def scan_fn(carry, x):
    """扫描函数:用于循环"""
    new_carry = carry + x * 2
    return new_carry, carry * x
 
# lax.scan高效处理循环
final, outputs = lax.scan(scan_fn, init=0.0, xs=jnp.arange(5))
 
# 自动支持反向传播
grad_scan = grad(lambda params: lax.scan(
    lambda c, x: (c + params * x, params * x),
    0.0, jnp.arange(5)
)[1].sum())

6.3 PyTree支持

from jax import tree_util
 
params = {
    'dense1': {'w': jnp.zeros((10, 20)), 'b': jnp.zeros(10)},
    'dense2': {'w': jnp.zeros((20, 5)), 'b': jnp.zeros(5)}
}
 
# grad支持PyTree
grad_fn = grad(loss)
 
# 返回相同结构的梯度
grads = grad_fn(params, x, y)
print(tree_util.tree_map(jnp.shape, grads))
# {'dense1': {'w': (10, 20), 'b': (10,)},
#  'dense2': {'w': (20, 5), 'b': (5,)}}

7. JAX vs PyTorch对比

7.1 编程范式

特性JAXPyTorch
范式函数式命令式
计算图静态(jit内)动态
状态管理无(纯函数)有(nn.Module)
更新参数返回新数组原地修改

7.2 API对比

# PyTorch
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2 + torch.sin(x)
y.backward()
 
# JAX
x = jnp.array([1.0])
y = x ** 2 + jnp.sin(x)
grad_y = grad(lambda x: x ** 2 + jnp.sin(x))(x)

7.3 性能对比

# JAX:jit编译后极快
@jit
def train_step(params, batch):
    def loss_fn(p):
        return compute_loss(p, batch)
    
    grads = grad(loss_fn)(params)
    return optax.apply_updates(params, grads)
 
# PyTorch:eager模式灵活
for batch in dataloader:
    optimizer.zero_grad()
    loss = compute_loss(model, batch)
    loss.backward()
    optimizer.step()

7.4 内存模型

特性JAXPyTorch
内存分配函数式(不可变)命令式(可变)
原地操作不支持支持
内存效率需手动管理自动管理

8. 高级用法

8.1 自定义梯度

from jax import custom_grad
 
@custom_grad
def my_relu(x):
    """自定义ReLU及其梯度"""
    def grad_fn(grad_output):
        return grad_output * (x > 0).astype(x.dtype)
    return jnp.maximum(0, x), grad_fn
 
# 使用自定义梯度
x = jnp.array([-1.0, 2.0, -3.0])
y = my_relu(x)
dy = grad(my_relu)(x)
print(f"y: {y}, dy: {dy}")
# y: [0. 2. 0.], dy: [0. 1. 0.]

8.2 阻尼/稳定化

def stable_grad(f, eps=1e-8):
    """添加数值稳定性的梯度函数"""
    def grad_fn(*args):
        return grad(f)(*args)
    return grad_fn
 
# 在Hessian计算中添加阻尼
def hessian_with_damping(f, damping=1e-5):
    def hess_fn(x):
        return hessian(f)(x) + damping * jnp.eye(len(x))
    return hess_fn

8.3 梯度检查

from jax import grad
from jax.test_util import check_grads
 
def f(x, y):
    return jnp.sum(x * y)
 
# 数值梯度校验
check_grads(f, args=(jnp.array([1., 2.]), jnp.array([3., 4.])), 
            order=2)  # 检查到二阶导数

9. 实用技巧

9.1 JIT编译与微分

@jit
def compiled_grad_loss(params, x, y):
    return grad(loss)(params, x, y)
 
# 预热
compiled_grad_loss(params, x, y)
 
# 实际使用
for batch in dataloader:
    grads = compiled_grad_loss(params, batch.x, batch.y)
    params = optax.apply_updates(params, grads)

9.2 vmap + grad组合

# 一次性对整个batch计算梯度
batched_grad_fn = vmap(grad(loss), in_axes=(None, 0, 0))
 
grads = batched_grad_fn(params, X_batch, Y_batch)
# grads['w'] 形状: (batch_size, input_dim, output_dim)

9.3 性能优化

# 1. 避免在jit内创建新数组
@jit
def bad_func(x):
    return x + jnp.array([1.0])  # 每次创建新数组
 
@jit
def good_func(x):
    one = jnp.ones_like(x)  # 在jit外创建
    return x + one
 
# 2. 使用静态参数
@partial(jit, static_argnums=(2,))
def func(x, y, dim):
    return jnp.sum(x, axis=dim)

10. 总结

10.1 JAX核心要点

  1. 函数式范式:纯函数、无副作用
  2. 组合变换:jit、grad、vmap可自由组合
  3. JAXPR表示:清晰的中间表示
  4. 自动选择:jacfwd/jacrev自动选择最优模式
  5. PyTree支持:灵活处理嵌套参数结构

10.2 相关专题

参考资料