计算图表示与执行

计算图(Computational Graph)是自动微分的核心数据结构,它将任意复合函数表示为有向无环图(DAG),使得导数的自动计算成为可能。

1. 计算图基础

1.1 图的数学定义

有向无环图(DAG)

  • :节点集合(Variables)
  • :有向边集合

图的基本性质

  • 无环性:不存在从节点返回自身的路径
  • 拓扑序:存在节点排序使得所有边指向正方向

1.2 节点类型

节点类型描述示例
叶子节点无前驱,代表输入或参数x, W, b
操作节点有前驱和后继,代表基本运算matmul, relu
根节点无后继,代表最终输出loss

1.3 边类型

边类型方向携带信息
前向边父节点 → 子节点数据值
反向边子节点 → 父节点梯度值

2. 图构建过程

2.1 动态图构建(PyTorch风格)

PyTorch使用动态图:每次前向传播时实时构建计算图。

import torch
 
# 每次执行都重新构建图
x = torch.tensor([1.0], requires_grad=True)
w = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([3.0], requires_grad=True)
 
# 图构建过程
z = w * x      # 创建乘法节点
z = z + b      # 创建加法节点
y = torch.relu(z)  # 创建ReLU节点
 
# 图结构
"""
       x         w         b
        \       /          /
         \     /          /
          v   v          v
           matmul  →  add  →  relu  →  y
"""

2.2 静态图构建(TensorFlow/JAX风格)

静态图在执行前完成构建和优化:

# JAX风格
import jax
 
def f(x, w, b):
    z = w * x
    z = z + b
    y = jax.nn.relu(z)
    return y
 
# JIT编译时构建图
f_jit = jax.jit(f)
# 此时构建计算图并优化

2.3 追踪模式

追踪(Tracing)记录实际执行的计算:

def trace_function(f, *args):
    """简化追踪实现"""
    tape = []
    
    class TapeContext:
        def record(self, op, inputs, output):
            tape.append((op, inputs, output))
    
    ctx = TapeContext()
    output = f(*args)  # 实际执行
    return output, tape

追踪的限制

def f(x):
    if x > 0:  # 条件分支
        return x ** 2
    else:
        return x ** 3
 
# 追踪只记录实际走过的路径
# 可能丢失else分支的信息

3. 拓扑排序与执行顺序

3.1 拓扑排序算法

Kahn算法(Kahn’s Algorithm)

from collections import deque
 
def topological_sort(graph):
    """
    Kahn算法:基于入度排序
    graph: {node: [successors]}
    """
    in_degree = {node: 0 for node in graph}
    
    # 计算入度
    for node in graph:
        for succ in graph[node]:
            in_degree[succ] += 1
    
    # 入度为0的节点入队
    queue = deque([node for node in graph if in_degree[node] == 0])
    result = []
    
    while queue:
        node = queue.popleft()
        result.append(node)
        
        for succ in graph[node]:
            in_degree[succ] -= 1
            if in_degree[succ] == 0:
                queue.append(succ)
    
    return result

3.2 拓扑编号

拓扑编号(Topological Number)

def compute_topological_nr(graph, node):
    """计算节点的拓扑编号"""
    if node not in graph or not graph[node]:
        return 0
    return max(compute_topological_nr(graph, succ) + 1 
               for succ in graph[node])

优化应用

// 仅当当前节点的拓扑编号小于需要的批次号时才执行
if (node.topological_nr_ < exec_info.needed_batch_nr) {
    continue;  // 跳过不必要的计算
}

3.3 执行顺序约束

前向计算:      x1 → x2 → x3 → x4 → y
                       ↑
反向计算:      y ← ∂y/∂y ← x4 ← x3 ← x2 ← x1
              (1)        (2)   (3)   (4)   (5)

反向执行必须满足的约束

  • 只有当所有后继节点的梯度计算完成后,才能计算当前节点的梯度
  • 这正是拓扑排序的逆序保证的

4. 梯度累积机制

4.1 多路径梯度合并

当多个后继节点依赖同一父节点时:

           ┌──────────┐
    x ────►│          │
           │  Split   │
           │          ├──► x1 ──► ...
           │          │
           │          ├──► x2 ──► ...
           └──────────┘

梯度计算:

4.2 原地累积实现

class AccumulateGrad:
    def __init__(self, variable):
        self.variable = variable
        self.grad = None
    
    def accumulate(self, grad):
        if self.grad is None:
            self.grad = grad
        else:
            self.grad = self.grad + grad  # 原地累积

4.3 累积语义示例

import torch
 
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2  # 路径1
z = x ** 3  # 路径2
out = y + z  # y和z汇聚
 
# 反向传播
out.backward()
 
# 梯度: dy/dx + dz/dx = 2x + 3x² = 2 + 3 = 5
print(x.grad)  # tensor([5.])

5. 图优化

5.1 公共子表达式消除

# 优化前
a = x + y
b = x + y  # 重复计算
c = a * b
 
# 优化后(消除公共子表达式)
t = x + y
a = t
b = t
c = t * t

5.2 常量折叠

# 优化前
y = 2.0 * 3.0 * x  # 6.0 * x
 
# 优化后(常量折叠)
y = 6.0 * x

5.3 内存优化

# 使用梯度检查点减少内存
from torch.utils.checkpoint import checkpoint
 
class Model(nn.Module):
    def forward(self, x):
        # 中间激活不会被保存
        x = checkpoint(self.layer1, x)
        x = checkpoint(self.layer2, x)
        return self.layer3(x)

6. 控制流处理

6.1 条件分支

JAX的处理方式

import jax.numpy as jnp
from jax import lax
 
def f(x, pred):
    # 根据条件选择计算路径
    return lax.cond(
        pred,
        lambda _: x ** 2,      # true branch
        lambda _: x ** 3,      # false branch
        operand=None
    )
 
# JAX会为两个分支都生成梯度计算代码

6.2 循环

PyTorch的反向传播支持循环

def rnn_step(x, h, W):
    for _ in range(num_layers):
        h = jnp.tanh(jnp.dot(W, jnp.concatenate([x, h])))
    return h
 
# JAX支持在反向传播中展开循环

6.3 扫描(Scan)

高效处理序列计算

from jax import lax
 
def scan_fn(carry, x):
    return carry + x, carry * x
 
# 扫描高效执行序列计算,同时支持反向传播
final, outputs = lax.scan(scan_fn, init=0, xs=data)

7. 图的持久化与调试

7.1 图可视化

# 使用torchviz可视化
from torchviz import make_dot
 
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2 + x ** 3
make_dot(y, params={'x': x}).render("graph")

7.2 图检查点

# 调试:打印图信息
def debug_hook(grad_input, grad_output):
    print(f"grad_input: {grad_input}")
    print(f"grad_output: {grad_output}")
 
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
y.register_hook(debug_hook)
y.backward()

7.3 图克隆

# 克隆图以进行修改而不影响原图
def clone_graph(node):
    """克隆计算图"""
    mapping = {}
    
    def clone(node):
        if node in mapping:
            return mapping[node]
        
        new_node = node.clone()
        mapping[node] = new_node
        
        for input_node in node.inputs:
            new_input = clone(input_node)
            new_node.add_input(new_input)
        
        return new_node
    
    return clone(node)

8. 高级话题

8.1 图切断

# 使用detach切断计算图
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
z = y.detach()  # z不再连接到x
w = z ** 2
w.backward()  # x.grad = None
 
# 使用stop_gradient(语义等价)
z = torch.stop_gradient(y)

8.2 图替换

# 用优化的实现替换子图
class OptimizedConv(nn.Module):
    def forward(self, x):
        # 可以替换为更高效的融合实现
        return torch.ops.aten.conv2d(x, self.weight, self.bias)

8.3 函数式图表示

JAX使用纯函数式表示计算图:

from jax import make_jaxpr
 
def f(x, y):
    return x * y + jnp.sum(x)
 
# JAXPR表示
print(make_jaxpr(f)(jnp.array([1., 2.]), jnp.array([3., 4.])))
# { lambda ; a:i32[2] b:i32[2].
#   let c:f32[2] = mul a b
#   let d:f32 = reduce_sum[ axes=(0,) ] c
#   let e:f32[2] = add c d
#   let f:f32[2] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(2,) ] d
#   let g:f32[2] = add e f
#   let h:f32[2] = add g a
#   let i:f32[2] = mul h b
#   let j:f32 = reduce_sum[ axes=(0,) ] i
#   let k:f32[2] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(2,) ] j
#   let l:f32[2] = add i k
#   result l }

9. 总结

9.1 计算图核心要点

  1. DAG结构:无环有向图是自动微分的基础
  2. 动态vs静态:PyTorch动态图 vs JAX/TF静态图
  3. 拓扑排序:保证前向和反向的正确执行顺序
  4. 梯度累积:多路径梯度合并机制
  5. 图优化:公共子表达式消除、常量折叠等

9.2 相关专题

参考资料