高阶导数理论

高阶导数在许多机器学习应用中至关重要,包括牛顿优化、MAML元学习、神经网络曲率分析等。然而,高阶导数的计算面临独特的挑战。

1. 高阶导数的挑战

1.1 朴素嵌套微分的指数爆炸

直接嵌套自动微分会导致计算复杂度指数增长:

# 计算k阶导数
kth_deriv = f
for _ in range(k):
    kth_deriv = grad(kth_deriv)
 
# 复杂度分析:
# - n阶导数需要O(n)次梯度计算
# - 每次计算与前向计算成本相当
# - 总复杂度: O(n²) 相对于前向计算

问题:对于深度嵌套(如50层网络),计算量变得不可行。

1.2 Faà di Bruno公式

Faà di Bruno公式描述了复合函数的高阶导数:

定理:设 ,则:

其中 Bell多项式

1.3 Bell多项式定义

第一类Bell多项式

其中求和遍历所有满足 的非负整数序列。

1.4 Faà di Bruno的简化形式

时,公式简化为:

示例:计算

2. Taylor模式自动微分

2.1 核心思想

Taylor模式自动微分通过追踪Taylor系数而非单个导数来高效计算高阶导数:

def taylor_jet(f, x0, d):
    """
    计算f在x0处的Taylor展开到阶d
    返回 (f(x0), f'(x0), f''(x0), ..., f^(d)(x0))
    """
    # Jet操作:同时计算所有阶的导数
    pass

2.2 Jet操作的数学形式

设输入的Taylor展开为:

函数 的输出为:

其中

2.3 链式法则的Taylor形式

对于复合

这可以通过动态规划在 时间内计算所有

2.4 Taylor模式实现

import jax.numpy as jnp
from jax import jit
 
def taylor_expand(f, x0, d):
    """使用JAX计算Taylor系数"""
    from jax import jacfwd, jacrev
    
    # 准备输入的多项式表示
    # [x0, 1, 0, 0, ...] 表示 x(t) = x0 + 1*t
    coeffs = [x0] + [jnp.zeros_like(x0) for _ in range(d)]
    coeffs[1] = jnp.ones_like(x0)
    
    # 使用jacfwd计算高阶项
    # 这实际上计算了所有阶的导数
    def compute_series(x):
        from jax import hessian, jacobian
        
        results = [f(x)]
        fp = jacfwd(f)(x)  # 一阶
        results.append(fp)
        
        # 高阶项通过迭代计算
        for i in range(2, d+1):
            # 这里需要更复杂的实现
            pass
        
        return results
    
    return compute_series(x0)

3. Hessian矩阵计算

3.1 Hessian矩阵定义

对于函数 ,Hessian矩阵为:

3.2 JAX中的Hessian计算

import jax.numpy as jnp
from jax import grad, jacfwd, jacrev, hessian
 
def f(x):
    return jnp.sum(x ** 3) + jnp.prod(x)
 
x = jnp.array([1.0, 2.0, 3.0])
 
# 方法1:嵌套梯度
h1 = jacfwd(grad(f))(x)
 
# 方法2:使用hessian(推荐)
h2 = hessian(f)(x)
 
# 验证两者等价
assert jnp.allclose(h1, h2)
print(f"Hessian:\n{h2}")

3.3 有效Hessian计算

完整Hessian矩阵的存储和计算都是 ,但许多应用只需要Hessian-向量积

from jax import jvp
 
def hvp(f, x, v):
    """
    计算 Hessian @ v = ∇²f(x) @ v
    复杂度:O(n) 而非 O(n²)
    """
    return jvp(grad(f), (x,), (v,))[1]
 
x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])
 
hvp_val = hvp(lambda x: jnp.sum(x**3), x, v)
print(f"Hessian @ v: {hvp_val}")

3.4 Krylov子空间方法

Hessian-向量积使得无需显式计算Hessian即可进行优化:

def conjugate_gradient(H_func, b, x0, max_iter=10):
    """
    共轭梯度法求解 Hx = b
    仅使用Hessian-向量积
    """
    r = b - H_func(x0) @ x0
    p = r.copy()
    rsold = jnp.dot(r, r)
    
    for _ in range(max_iter):
        Ap = H_func(x0) @ p
        alpha = rsold / (jnp.dot(p, Ap) + 1e-10)
        x0 = x0 + alpha * p
        r = r - alpha * Ap
        rsnew = jnp.dot(r, r)
        
        if jnp.sqrt(rsnew) < 1e-8:
            break
        
        beta = rsnew / rsold
        p = r + beta * p
        rsold = rsnew
    
    return x0

4. 应用场景

4.1 牛顿法优化

牛顿法使用Hessian进行二阶优化:

def newton_step(f, x, damping=1e-5):
    """牛顿法一步迭代"""
    grad_f = grad(f)(x)
    hvp_f = lambda v: hvp(f, x, v)
    
    # 使用共轭梯度求解 Hx = -g
    h_inv_g = conjugate_gradient(lambda v: hvp_f(v), -grad_f, x)
    
    return x + h_inv_g + damping * grad_f

4.2 MAML元学习

MAML需要梯度对梯度的计算:

def maml_inner_step(theta, alpha=0.01):
    """MAML内循环"""
    def loss_fn(params):
        return compute_train_loss(params)
    
    # 一阶梯度
    grads = grad(loss_fn)(theta)
    
    # 梯度步进后的参数
    theta_star = theta - alpha * grads
    
    # 外循环:关于theta的梯度
    def outer_loss(theta):
        return compute_val_loss(theta - alpha * grad(loss_fn)(theta))
    
    outer_grads = grad(outer_loss)(theta)
    
    return outer_grads

4.3 曲率分析

神经网络训练动力学的研究需要曲率信息:

def compute_curvature(model, x, y):
    """计算损失景观的曲率"""
    def loss(params):
        return compute_loss(model, params, x, y)
    
    # Hessian的特征值分解
    H = hessian(loss)(model.parameters())
    
    # 特征值表示曲率方向
    eigvals, eigvecs = jnp.linalg.eigh(H)
    
    return {
        'positive_ratio': jnp.mean(eigvals > 0),
        'condition_number': jnp.max(jnp.abs(eigvals)) / (jnp.min(jnp.abs(eigvals)) + 1e-10),
        'top_curvature': eigvals[-1],
        'bottom_curvature': eigvals[0]
    }

4.4 神经ODE与伴随方法

神经ODE使用高阶导数来求解常微分方程:

from jax.experimental.ode import odeint
 
def neural_ode(params, x0, t):
    """神经ODE前向传播"""
    def dxdt(x, t):
        return mlp(x, params)
    
    return odeint(dxdt, x0, t)
 
# 伴随方法计算ODE梯度(避免存储完整轨迹)
from jax.adjoint import odeint as adj_odeint
 
def neural_ode_gradient(params, x0, t, y_target):
    """使用伴随方法的梯度计算"""
    def loss(params):
        x = adj_odeint(lambda x, t: mlp(x, params), x0, t)
        return jnp.mean((x - y_target) ** 2)
    
    return grad(loss)(params)

4.5 物理信息网络(PINN)

PINN中高阶导数用于编码物理定律:

def pinn_loss(params, x, t):
    """PINN损失:数据 + PDE约束"""
    u = pinn_model(params, x, t)
    
    # 自动计算高阶导数
    u_t = grad(lambda x: pinn_model(params, x, t[:, 1]))(x)
    u_xx = grad(grad(lambda x: pinn_model(params, x, t[:, 1])))(x)
    
    # 热方程 PDE: u_t = α * u_xx
    pde_residual = u_t - alpha * u_xx
    
    return jnp.mean((u - u_data) ** 2) + jnp.mean(pde_residual ** 2)

5. JAX高阶导数实现

5.1 多重嵌套grad

# k阶导数
def kth_grad(f, k):
    def fn(x):
        result = f
        for _ in range(k):
            result = grad(result)
        return result
    return fn
 
# 计算三阶导数
f3 = kth_grad(lambda x: jnp.sin(x) + x**2, 3)
print(f3(jnp.array(0.5)))  # -cos(0.5) + 0

5.2 jacfwd与jacrev组合

# Hessian = jacfwd(jacrev(f)) 或 jacfwd(grad(f))
H1 = jacfwd(jacrev(f))(x)
H2 = jacfwd(grad(f))(x)
 
# 选择依据:
# - jacfwd(jacrev(...)): 先反向后正向,适合参数少、输出多
# - jacfwd(grad(...)): 直接使用grad,适合标量函数

5.3 高效的梯度-梯度积

def grad_grad(f):
    """计算二阶梯度的函数"""
    def fn(x):
        return jacfwd(grad(f))(x)
    return fn
 
# 使用
H = grad_grad(lambda x: jnp.sum(x**3))(jnp.array([1., 2., 3.]))
print(H)
# [[6. 0. 0.]
#  [0. 12. 0.]
#  [0. 0. 18.]]

6. 数值稳定性

6.1 数值溢出

高阶导数容易数值溢出:

import jax.numpy as jnp
from jax import grad
 
def f(x):
    return jnp.exp(x)
 
# 高阶导数:e^x快速增长
for k in range(1, 10):
    deriv = grad(f, k) if k > 1 else grad(f)
    x = 10.0
    try:
        val = deriv(x)
        print(f"{k}阶导数: {val}")
    except Exception as e:
        print(f"{k}阶导数: 溢出!")

6.2 阻尼技巧

def stable_hessian(f, x, damping=1e-5):
    """带阻尼的稳定Hessian"""
    H = hessian(f)(x)
    
    # 添加阻尼确保正定
    d = H.shape[0]
    return H + damping * jnp.eye(d)

6.3 尺度归一化

def normalized_hessian(f, x):
    """归一化Hessian避免数值问题"""
    H = hessian(f)(x)
    
    # 计算Hessian的范数
    H_norm = jnp.linalg.norm(H)
    
    # 避免数值溢出
    return H / (H_norm + 1e-10)

7. 总结

7.1 高阶导数核心要点

  1. 指数爆炸问题:朴素嵌套微分导致复杂度指数增长
  2. Faà di Bruno公式:复合函数高阶导数的闭合形式
  3. Taylor模式:通过追踪Taylor系数高效计算
  4. Hessian-向量积 复杂度避免 存储
  5. 广泛应用:牛顿优化、MAML、神经ODE、PINN

7.2 相关专题

参考资料