计算图表示与执行
计算图(Computational Graph)是自动微分的核心数据结构,它将任意复合函数表示为有向无环图(DAG),使得导数的自动计算成为可能。
1. 计算图基础
1.1 图的数学定义
有向无环图(DAG):
- :节点集合(Variables)
- :有向边集合
图的基本性质:
- 无环性:不存在从节点返回自身的路径
- 拓扑序:存在节点排序使得所有边指向正方向
1.2 节点类型
| 节点类型 | 描述 | 示例 |
|---|---|---|
| 叶子节点 | 无前驱,代表输入或参数 | x, W, b |
| 操作节点 | 有前驱和后继,代表基本运算 | matmul, relu |
| 根节点 | 无后继,代表最终输出 | loss |
1.3 边类型
| 边类型 | 方向 | 携带信息 |
|---|---|---|
| 前向边 | 父节点 → 子节点 | 数据值 |
| 反向边 | 子节点 → 父节点 | 梯度值 |
2. 图构建过程
2.1 动态图构建(PyTorch风格)
PyTorch使用动态图:每次前向传播时实时构建计算图。
import torch
# 每次执行都重新构建图
x = torch.tensor([1.0], requires_grad=True)
w = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([3.0], requires_grad=True)
# 图构建过程
z = w * x # 创建乘法节点
z = z + b # 创建加法节点
y = torch.relu(z) # 创建ReLU节点
# 图结构
"""
x w b
\ / /
\ / /
v v v
matmul → add → relu → y
"""2.2 静态图构建(TensorFlow/JAX风格)
静态图在执行前完成构建和优化:
# JAX风格
import jax
def f(x, w, b):
z = w * x
z = z + b
y = jax.nn.relu(z)
return y
# JIT编译时构建图
f_jit = jax.jit(f)
# 此时构建计算图并优化2.3 追踪模式
追踪(Tracing)记录实际执行的计算:
def trace_function(f, *args):
"""简化追踪实现"""
tape = []
class TapeContext:
def record(self, op, inputs, output):
tape.append((op, inputs, output))
ctx = TapeContext()
output = f(*args) # 实际执行
return output, tape追踪的限制:
def f(x):
if x > 0: # 条件分支
return x ** 2
else:
return x ** 3
# 追踪只记录实际走过的路径
# 可能丢失else分支的信息3. 拓扑排序与执行顺序
3.1 拓扑排序算法
Kahn算法(Kahn’s Algorithm):
from collections import deque
def topological_sort(graph):
"""
Kahn算法:基于入度排序
graph: {node: [successors]}
"""
in_degree = {node: 0 for node in graph}
# 计算入度
for node in graph:
for succ in graph[node]:
in_degree[succ] += 1
# 入度为0的节点入队
queue = deque([node for node in graph if in_degree[node] == 0])
result = []
while queue:
node = queue.popleft()
result.append(node)
for succ in graph[node]:
in_degree[succ] -= 1
if in_degree[succ] == 0:
queue.append(succ)
return result3.2 拓扑编号
拓扑编号(Topological Number):
def compute_topological_nr(graph, node):
"""计算节点的拓扑编号"""
if node not in graph or not graph[node]:
return 0
return max(compute_topological_nr(graph, succ) + 1
for succ in graph[node])优化应用:
// 仅当当前节点的拓扑编号小于需要的批次号时才执行
if (node.topological_nr_ < exec_info.needed_batch_nr) {
continue; // 跳过不必要的计算
}3.3 执行顺序约束
前向计算: x1 → x2 → x3 → x4 → y
↑
反向计算: y ← ∂y/∂y ← x4 ← x3 ← x2 ← x1
(1) (2) (3) (4) (5)
反向执行必须满足的约束:
- 只有当所有后继节点的梯度计算完成后,才能计算当前节点的梯度
- 这正是拓扑排序的逆序保证的
4. 梯度累积机制
4.1 多路径梯度合并
当多个后继节点依赖同一父节点时:
┌──────────┐
x ────►│ │
│ Split │
│ ├──► x1 ──► ...
│ │
│ ├──► x2 ──► ...
└──────────┘
梯度计算:
4.2 原地累积实现
class AccumulateGrad:
def __init__(self, variable):
self.variable = variable
self.grad = None
def accumulate(self, grad):
if self.grad is None:
self.grad = grad
else:
self.grad = self.grad + grad # 原地累积4.3 累积语义示例
import torch
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2 # 路径1
z = x ** 3 # 路径2
out = y + z # y和z汇聚
# 反向传播
out.backward()
# 梯度: dy/dx + dz/dx = 2x + 3x² = 2 + 3 = 5
print(x.grad) # tensor([5.])5. 图优化
5.1 公共子表达式消除
# 优化前
a = x + y
b = x + y # 重复计算
c = a * b
# 优化后(消除公共子表达式)
t = x + y
a = t
b = t
c = t * t5.2 常量折叠
# 优化前
y = 2.0 * 3.0 * x # 6.0 * x
# 优化后(常量折叠)
y = 6.0 * x5.3 内存优化
# 使用梯度检查点减少内存
from torch.utils.checkpoint import checkpoint
class Model(nn.Module):
def forward(self, x):
# 中间激活不会被保存
x = checkpoint(self.layer1, x)
x = checkpoint(self.layer2, x)
return self.layer3(x)6. 控制流处理
6.1 条件分支
JAX的处理方式:
import jax.numpy as jnp
from jax import lax
def f(x, pred):
# 根据条件选择计算路径
return lax.cond(
pred,
lambda _: x ** 2, # true branch
lambda _: x ** 3, # false branch
operand=None
)
# JAX会为两个分支都生成梯度计算代码6.2 循环
PyTorch的反向传播支持循环:
def rnn_step(x, h, W):
for _ in range(num_layers):
h = jnp.tanh(jnp.dot(W, jnp.concatenate([x, h])))
return h
# JAX支持在反向传播中展开循环6.3 扫描(Scan)
高效处理序列计算:
from jax import lax
def scan_fn(carry, x):
return carry + x, carry * x
# 扫描高效执行序列计算,同时支持反向传播
final, outputs = lax.scan(scan_fn, init=0, xs=data)7. 图的持久化与调试
7.1 图可视化
# 使用torchviz可视化
from torchviz import make_dot
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2 + x ** 3
make_dot(y, params={'x': x}).render("graph")7.2 图检查点
# 调试:打印图信息
def debug_hook(grad_input, grad_output):
print(f"grad_input: {grad_input}")
print(f"grad_output: {grad_output}")
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
y.register_hook(debug_hook)
y.backward()7.3 图克隆
# 克隆图以进行修改而不影响原图
def clone_graph(node):
"""克隆计算图"""
mapping = {}
def clone(node):
if node in mapping:
return mapping[node]
new_node = node.clone()
mapping[node] = new_node
for input_node in node.inputs:
new_input = clone(input_node)
new_node.add_input(new_input)
return new_node
return clone(node)8. 高级话题
8.1 图切断
# 使用detach切断计算图
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
z = y.detach() # z不再连接到x
w = z ** 2
w.backward() # x.grad = None
# 使用stop_gradient(语义等价)
z = torch.stop_gradient(y)8.2 图替换
# 用优化的实现替换子图
class OptimizedConv(nn.Module):
def forward(self, x):
# 可以替换为更高效的融合实现
return torch.ops.aten.conv2d(x, self.weight, self.bias)8.3 函数式图表示
JAX使用纯函数式表示计算图:
from jax import make_jaxpr
def f(x, y):
return x * y + jnp.sum(x)
# JAXPR表示
print(make_jaxpr(f)(jnp.array([1., 2.]), jnp.array([3., 4.])))
# { lambda ; a:i32[2] b:i32[2].
# let c:f32[2] = mul a b
# let d:f32 = reduce_sum[ axes=(0,) ] c
# let e:f32[2] = add c d
# let f:f32[2] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(2,) ] d
# let g:f32[2] = add e f
# let h:f32[2] = add g a
# let i:f32[2] = mul h b
# let j:f32 = reduce_sum[ axes=(0,) ] i
# let k:f32[2] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(2,) ] j
# let l:f32[2] = add i k
# result l }9. 总结
9.1 计算图核心要点
- DAG结构:无环有向图是自动微分的基础
- 动态vs静态:PyTorch动态图 vs JAX/TF静态图
- 拓扑排序:保证前向和反向的正确执行顺序
- 梯度累积:多路径梯度合并机制
- 图优化:公共子表达式消除、常量折叠等