JAX自动微分框架
JAX是Google开发的下一代高性能数值计算框架,其核心设计基于函数式编程范式和组合式变换(Composable Transformations)。本专题深入分析JAX的自动微分系统。
1. JAX设计理念
1.1 函数式核心
JAX的核心理念是将所有计算表示为纯函数:
import jax.numpy as jnp
from jax import grad
# 纯函数:相同输入产生相同输出,无副作用
def loss(params, x, y):
pred = jnp.dot(x, params['w']) + params['b']
return jnp.mean((pred - y) ** 2)
# grad创建梯度函数
grad_loss = grad(loss)
# 相同输入产生相同梯度
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.0)}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
grads = grad_loss(params, x, y)1.2 变换体系架构
┌─────────────────────────────────────────────────────┐
│ JAX Transformations │
├─────────────────────────────────────────────────────┤
│ │
│ jit ──► 编译加速(XLA) │
│ │ │
│ ├── grad ──► 自动微分 │
│ │ │ │
│ │ └── jacfwd ──► 正向模式雅可比 │
│ │ │ │
│ │ └── jacrev ──► 反向模式雅可比 │
│ │ │ │
│ │ └── hessian ──► Hessian矩阵 │
│ │ │
│ ├── vmap ──► 自动向量化 │
│ │ │
│ └── pmap ──► 数据并行 │
│ │
└─────────────────────────────────────────────────────┘
1.3 变换组合
JAX的强大之处在于变换可以自由组合:
from jax import grad, jit, vmap
# 组合:编译 + 向量化 + 梯度
@jit
def batch_grad_loss(params, X, Y):
return vmap(grad(loss), in_axes=(None, 0, 0))(params, X, Y)
# 对整个batch同时计算梯度
grads = batch_grad_loss(params, X_batch, Y_batch)2. 核心微分变换
2.1 grad:标量梯度
from jax import grad
def f(x):
return jnp.sum(x ** 2)
# 计算 ∇f(x)
grad_f = grad(f)
x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x)) # DeviceArray([2., 4., 6.])
# 数学上:df/dx_i = 2x_i默认行为:
- 返回第一个参数(按位置)的梯度
- 只适用于标量输出函数
2.2 多参数微分
def f(w, b, x, y):
pred = jnp.dot(x, w) + b
return jnp.mean((pred - y) ** 2)
# 只对w求梯度
grad_w = grad(f, argnums=0)
# 对w和b同时求梯度
grad_wb = grad(f, argnums=(0, 1))
# 对所有参数求梯度
grad_all = grad(f, argnums=(0, 1, 2, 3))2.3 value_and_grad
from jax import value_and_grad
# 同时返回函数值和梯度
val_and_grad = value_and_grad(f)
loss_value, grads = val_and_grad(w, b, x, y)
print(f"loss: {loss_value}, grad_w norm: {jnp.linalg.norm(grads['w'])}")3. 雅可比矩阵计算
3.1 jacfwd:正向模式雅可比
from jax import jacfwd
def f(x):
return jnp.array([x[0] ** 2, jnp.sin(x[1]), x[0] * x[1]])
# 计算雅可比矩阵(每行是输出对输入的梯度)
J = jacfwd(f)(jnp.array([1.0, jnp.pi / 2]))
print(J)
# DeviceArray([[2. , 0. ],
# [0. , -1. ],
# [1.57079633, 1. ]], dtype=float32)数学形式:
3.2 jacrev:反向模式雅可比
from jax import jacrev
# 反向模式计算雅可比
J = jacrev(f)(jnp.array([1.0, jnp.pi / 2]))3.3 模式选择
| 特性 | jacfwd(正向) | jacrev(反向) |
|---|---|---|
| 计算方向 | 沿输入维度 | 沿输出维度 |
| 适用场景 | 输出多、输入少 | 输入多、输出少 |
| 内存效率 | 适合长计算图 | 适合宽计算图 |
| 典型应用 | 神经网络层雅可比 | 损失函数梯度 |
# JAX自动选择最优模式
# 当输出维度 > 输入维度时,使用jacfwd
# 当输入维度 > 输出维度时,使用jacrev
# 也可以显式指定
from jax import jacobian # 自动选择
from jax import jacfwd as jacobian_forward
from jax import jacrev as jacobian_reverse4. Hessian矩阵计算
4.1 基本用法
from jax import grad, jacfwd, jacrev
def f(x):
return jnp.sum(x ** 3)
# 方法1:嵌套grad
grad_grad_f = grad(grad(f))
H1 = jacfwd(grad(f))(x)
# 方法2:使用hessian(推荐)
from jax import hessian
H = hessian(f)(x)
print(H)
# 二阶导数:d²/dx_i dx_j (x³) = 6x_i * δ_ij
# 对角元素:6x_i
# 非对角元素:04.2 Hessian-向量积
from jax import jvp
def hvp(f, x, v):
"""Hessian-向量积:H(x) @ v"""
return jvp(grad(f), (x,), (v,))[1]
# 更高效:不需要显式计算完整Hessian
x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])
hvp_val = hvp(lambda x: jnp.sum(x**3), x, v)
print(hvp_val) # DeviceArray([6., 0., 0.], dtype=float32)5. JAXPR:JAX的程序表示
5.1 什么是JAXPR
JAXPR( JAX Program Representation)是JAX的中间表示,描述计算图:
from jax import make_jaxpr
def f(x, y):
return x * y + jnp.sum(x)
# 生成JAXPR
jaxpr = make_jaxpr(f)(jnp.array([1., 2.]), jnp.array([3., 4.]))
print(jaxpr)输出:
{ lambda ; a:f32[2] b:f32[2].
let c:f32[2] = mul a b
let d:f32 = reduce_sum[ axes=(0,) ] c
let e:f32[2] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(2,) ] d
let f:f32[2] = add c e
let g:f32[2] = add f a
result g
}
5.2 JAXPR结构
JAXPR
├── literals: [] # 常量
├── in_shapes: [...] # 输入形状
├── out_shapes: [...] # 输出形状
└── eqns: [ # 方程列表
{
eqn_name: "mul",
operands: [a, b],
primitive: "mul",
output: c
},
...
]
5.3 追踪与雾化
追踪(Traced):
- 实际执行时,用具体值替换参数
- JAXPR记录执行的操作
雾化(Opaque):
- 在
jit编译时,输入被”雾化”(用抽象值替换) - 只保留操作的结构信息
@jit
def f(x):
return x ** 2 + jnp.sin(x)
# 即使在jit内,JAX也追踪计算
# 可以通过make_jaxpr查看
print(make_jaxpr(f)(jnp.array(1.0)))6. 高级微分功能
6.1 条件微分
from jax import lax
def f(x, cond):
return lax.cond(
cond,
lambda _: x ** 2, # true branch
lambda _: x ** 3, # false branch
operand=None
)
# JAX支持对条件分支进行微分
grad_f = grad(f)
x = 2.0
print(grad_f(x, True)) # 4.0 (d/dx x² = 2x = 4)
print(grad_f(x, False)) # 12.0 (d/dx x³ = 3x² = 12)6.2 循环微分
from jax import lax
def scan_fn(carry, x):
"""扫描函数:用于循环"""
new_carry = carry + x * 2
return new_carry, carry * x
# lax.scan高效处理循环
final, outputs = lax.scan(scan_fn, init=0.0, xs=jnp.arange(5))
# 自动支持反向传播
grad_scan = grad(lambda params: lax.scan(
lambda c, x: (c + params * x, params * x),
0.0, jnp.arange(5)
)[1].sum())6.3 PyTree支持
from jax import tree_util
params = {
'dense1': {'w': jnp.zeros((10, 20)), 'b': jnp.zeros(10)},
'dense2': {'w': jnp.zeros((20, 5)), 'b': jnp.zeros(5)}
}
# grad支持PyTree
grad_fn = grad(loss)
# 返回相同结构的梯度
grads = grad_fn(params, x, y)
print(tree_util.tree_map(jnp.shape, grads))
# {'dense1': {'w': (10, 20), 'b': (10,)},
# 'dense2': {'w': (20, 5), 'b': (5,)}}7. JAX vs PyTorch对比
7.1 编程范式
| 特性 | JAX | PyTorch |
|---|---|---|
| 范式 | 函数式 | 命令式 |
| 计算图 | 静态(jit内) | 动态 |
| 状态管理 | 无(纯函数) | 有(nn.Module) |
| 更新参数 | 返回新数组 | 原地修改 |
7.2 API对比
# PyTorch
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2 + torch.sin(x)
y.backward()
# JAX
x = jnp.array([1.0])
y = x ** 2 + jnp.sin(x)
grad_y = grad(lambda x: x ** 2 + jnp.sin(x))(x)7.3 性能对比
# JAX:jit编译后极快
@jit
def train_step(params, batch):
def loss_fn(p):
return compute_loss(p, batch)
grads = grad(loss_fn)(params)
return optax.apply_updates(params, grads)
# PyTorch:eager模式灵活
for batch in dataloader:
optimizer.zero_grad()
loss = compute_loss(model, batch)
loss.backward()
optimizer.step()7.4 内存模型
| 特性 | JAX | PyTorch |
|---|---|---|
| 内存分配 | 函数式(不可变) | 命令式(可变) |
| 原地操作 | 不支持 | 支持 |
| 内存效率 | 需手动管理 | 自动管理 |
8. 高级用法
8.1 自定义梯度
from jax import custom_grad
@custom_grad
def my_relu(x):
"""自定义ReLU及其梯度"""
def grad_fn(grad_output):
return grad_output * (x > 0).astype(x.dtype)
return jnp.maximum(0, x), grad_fn
# 使用自定义梯度
x = jnp.array([-1.0, 2.0, -3.0])
y = my_relu(x)
dy = grad(my_relu)(x)
print(f"y: {y}, dy: {dy}")
# y: [0. 2. 0.], dy: [0. 1. 0.]8.2 阻尼/稳定化
def stable_grad(f, eps=1e-8):
"""添加数值稳定性的梯度函数"""
def grad_fn(*args):
return grad(f)(*args)
return grad_fn
# 在Hessian计算中添加阻尼
def hessian_with_damping(f, damping=1e-5):
def hess_fn(x):
return hessian(f)(x) + damping * jnp.eye(len(x))
return hess_fn8.3 梯度检查
from jax import grad
from jax.test_util import check_grads
def f(x, y):
return jnp.sum(x * y)
# 数值梯度校验
check_grads(f, args=(jnp.array([1., 2.]), jnp.array([3., 4.])),
order=2) # 检查到二阶导数9. 实用技巧
9.1 JIT编译与微分
@jit
def compiled_grad_loss(params, x, y):
return grad(loss)(params, x, y)
# 预热
compiled_grad_loss(params, x, y)
# 实际使用
for batch in dataloader:
grads = compiled_grad_loss(params, batch.x, batch.y)
params = optax.apply_updates(params, grads)9.2 vmap + grad组合
# 一次性对整个batch计算梯度
batched_grad_fn = vmap(grad(loss), in_axes=(None, 0, 0))
grads = batched_grad_fn(params, X_batch, Y_batch)
# grads['w'] 形状: (batch_size, input_dim, output_dim)9.3 性能优化
# 1. 避免在jit内创建新数组
@jit
def bad_func(x):
return x + jnp.array([1.0]) # 每次创建新数组
@jit
def good_func(x):
one = jnp.ones_like(x) # 在jit外创建
return x + one
# 2. 使用静态参数
@partial(jit, static_argnums=(2,))
def func(x, y, dim):
return jnp.sum(x, axis=dim)10. 总结
10.1 JAX核心要点
- 函数式范式:纯函数、无副作用
- 组合变换:jit、grad、vmap可自由组合
- JAXPR表示:清晰的中间表示
- 自动选择:jacfwd/jacrev自动选择最优模式
- PyTree支持:灵活处理嵌套参数结构