PyTorch Autograd实现
PyTorch的torch.autograd是深度学习领域最广泛使用的自动微分引擎之一。本专题深入分析其内部架构、核心组件与执行流程。
1. PyTorch Autograd架构概览
1.1 整体架构
┌─────────────────────────────────────────────────────────┐
│ Python API Layer │
│ torch.autograd.backward() / grad() │
└─────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────┐
│ Python/C++ Bridge │
│ torch/csrc/autograd/python_anomaly_mode.cpp │
└─────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────┐
│ C++ Autograd Core │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Function │ │ Engine │ │ Node │ │
│ │ (基类) │ │ (执行器) │ │ (任务) │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
└─────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────┐
│ CUDA/CPU Backend │
│ autograd_variable.cpp / autograd_cuda.cpp │
└─────────────────────────────────────────────────────────┘
1.2 核心组件
| 组件 | 位置 | 功能 |
|---|---|---|
torch.autograd | Python | 顶层API |
Variable | C++ | 梯度上下文管理 |
Function | Python/C++ | 前向/反向计算 |
Engine | C++ | 图执行引擎 |
NodeTask | C++ | 单节点任务 |
GraphTask | C++ | 整图任务 |
2. Function类与backward机制
2.1 torch.autograd.Function
Function是定义自定义微分操作的核心基类:
class Function:
@staticmethod
def forward(ctx, *args, **kwargs):
"""
前向计算:执行实际运算
返回计算结果
"""
# 保存用于反向的信息到ctx
ctx.save_for_backward(*tensors_to_save)
ctx.some_metadata = some_value
return output
@staticmethod
def backward(ctx, *grad_outputs):
"""
反向计算:计算梯度
返回每个输入的梯度
"""
saved_tensors = ctx.saved_tensors
# 使用保存的信息计算梯度
grad_input = ...
return grad_input2.2 梯度保存机制
ctx.save_for_backward():
class Sigmoid(Function):
@staticmethod
def forward(ctx, input):
output = torch.sigmoid(input)
# 保存反向需要的张量
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
# Sigmoid导数: σ(x)(1-σ(x))
grad_input = grad_output * output * (1 - output)
return grad_inputctx.mark_dirty() 和 ctx.mark_non_differentiable():
class InplaceReLU(Function):
@staticmethod
def forward(ctx, input):
ctx.mark_dirty(input) # 输入会被原地修改
return input.clamp(min=0)
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone()
class Detach(Function):
@staticmethod
def forward(ctx, input):
ctx.mark_non_differentiable(input) # 标记不可微分
return input.detach()2.3 Python层自动生成backward
对于大多数操作,PyTorch自动生成backward函数:
# 实际流程
class Sigmoid:
@staticmethod
def forward(ctx, input):
output = torch.sigmoid(input)
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
return grad_output * output * (1 - output)3. 计算图执行流程
3.1 backward()的完整调用链
# Python层入口
tensor.backward()
│
▼
torch.autograd.backward(tensors, grad_tensors, ...)
│
▼
Variable._execution_engine.run_backward(
tensors, grad_tensors_, retain_graph, create_graph, inputs
)
│
▼
C++ Engine.execute()
│
▼
GraphTask.execute()
│
▼
NodeTask.execute() (for each gradient function)
│
▼
AccumulateGrad.update_grad()3.2 C++ Engine的核心实现
Engine.execute() 的伪代码:
void Engine::execute(const edge_list& roots,
const variable_list& grad_outputs,
bool retain_graph,
bool create_graph,
bool input_edge_requires_grad) {
// 1. 创建GraphTask
std::shared_ptr<GraphTask> graph_task =
std::make_shared<GraphTask>(
retain_graph, create_graph, input_edge_requires_grad
);
// 2. 初始化根节点
for (const auto& root : roots) {
auto& func = root.function;
// 设置输出梯度
graph_task->add_next_edge(root);
}
// 3. 执行图遍历
graph_task->execute();
}
void GraphTask::execute() {
// 4. 拓扑排序执行
while (! heap.empty()) {
// 获取拓扑编号最小的节点
auto node = heap.top();
heap.pop();
// 执行节点
node->run_function();
// 将梯度写入输入节点
for (const auto& edge : node->next_edges()) {
if (edge.is_valid()) {
graph_task->add_next_edge(edge);
}
}
}
}3.3 梯度累积机制
AccumulateGrad是特殊的Function节点,负责累积叶子节点的梯度:
class AccumulateGrad : public Function {
variable gradient;
void apply(const variable& grad) {
if (gradient.defined()) {
// 多个后继节点贡献的梯度累积
gradient = gradient + grad;
} else {
gradient = grad;
}
}
};为什么需要累积:
forward computation
┌──────────────────┐
│ │
▼ │
out ────────┬───────┘
│
▼ backward
┌──────────────┬──────────────┐
▼ ▼ ▼
ReLU Sigmoid Softmax
│ │ │
▼ ▼ ▼
x1 ◄───────── x ◄─────────── x
(same input)
的梯度来自三个路径的贡献,需要累积。
4. PyTorch Autograd优化技术
4.1 拓扑编号优化
topological_nr_(拓扑编号)记录从当前节点到根节点的最长路径长度:
// 拓扑编号的作用:跳过不必要的检查
if (node->topological_nr() < exec_info.needed_batch_nr) {
// 这个节点的结果不需要用于梯度计算
continue;
}4.2 选择性梯度计算
# 只计算特定输入的梯度
loss.backward()
x_grad = x.grad # 正确
# 或者使用torch.autograd.grad()
grads = torch.autograd.grad(
outputs=loss,
inputs=[w1, b1, w2, b2], # 只对这些参数求梯度
retain_graph=True
)4.3 梯度累积(用于大batch训练)
# 分多个micro-batch累积梯度
for batch in dataloader:
loss = model(batch)
loss.backward() # 梯度累积到grad中
# 不调用optimizer.step()
optimizer.step() # 一次性更新
optimizer.zero_grad()4.4 图执行优化
# 避免不必要的图构建
with torch.no_grad():
# 此区域内的计算不构建计算图
predictions = model(inputs)
# 或者使用.detach()
predictions = model(inputs).detach()5. 原地操作处理
5.1 原地操作的危险性
x = torch.tensor([1.0], requires_grad=True)
y = x * 2
x.add_(1) # 原地修改x
# 警告:原地操作可能破坏计算图5.2 mark_dirty机制
class MyInplace(Function):
@staticmethod
def forward(ctx, input):
ctx.mark_dirty(input) # 标记input会被修改
input.add_(1)
return input
@staticmethod
def backward(ctx, grad_output):
return grad_output # 输入的梯度直接来自输出5.3 版本号追踪
PyTorch通过版本号检测原地修改:
x = torch.tensor([1.0], requires_grad=True)
print(x._version) # 0
y = x * 2
print(x._version) # 仍然是0
x.add_(1) # 原地修改
print(x._version) # 1(版本号增加)
# 如果x被用于计算且版本不匹配,会报错6. 高级功能
6.1 自定义Function完整示例
import torch
from torch.autograd import Function
class CustomLinear(Function):
@staticmethod
def forward(ctx, input, weight, bias=None):
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# 需要输入梯度
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
# 需要权重梯度
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
# 需要偏置梯度
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
# 使用
linear = CustomLinear.apply
output = linear(x, weight, bias)6.2 函数模式grad()
# grad()返回梯度,不修改叶子节点
x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
def func(x, y):
return x ** 2 + y ** 2
# 计算∂func/∂x at (x, y)
grad_x = torch.autograd.grad(
outputs=func(x, y),
inputs=x
)
# 计算∂func/∂(x, y)
grad_xy = torch.autograd.grad(
outputs=func(x, y),
inputs=[x, y]
)6.3 高阶导数
x = torch.tensor([2.0], requires_grad=True)
y = x ** 3
# 一阶导数: 3x² = 12
grad1 = torch.autograd.grad(y, x, create_graph=True)[0]
# 二阶导数: 6x = 12
grad2 = torch.autograd.grad(grad1, x)[0]
print(f"一阶导数: {grad1.item()}") # 12.0
print(f"二阶导数: {grad2.item()}") # 12.06.4 梯度检查(数值校验)
from torch.autograd import gradcheck
# 创建验证函数
def my_func(x):
return x ** 2 + x.sin()
# 梯度检查
x = torch.randn(3, 3, requires_grad=True, dtype=torch.double)
print(gradcheck(my_func, x)) # True if pass7. 常见陷阱与调试
7.1 梯度不计算
# 问题:没有指定requires_grad
x = torch.tensor([1.0]) # requires_grad=False
y = x ** 2
y.backward() # 错误:x没有梯度
# 解决
x = torch.tensor([1.0], requires_grad=True)7.2 图被提前释放
# 问题:叶子节点被用于后续计算
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
z = y.detach() # 切断与x的连接
z.backward() # x没有收到梯度
# 解决:使用detach()时注意
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
z = y.detach() # 正确使用detach
# 如果需要z但也想保留梯度,使用z = y + 07.3 retain_graph问题
# 问题:多次反向传播需要retain_graph
loss.backward() # 第一次
loss.backward() # 错误:图已被释放
# 解决
loss.backward(retain_graph=True)
loss.backward()7.4 NaN/Inf梯度诊断
# 使用anomaly模式检测
with torch.autograd.set_detect_anomaly(True):
loss = model(input)
loss.backward()
# 检查梯度
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: max={param.grad.max()}, min={param.grad.min()}")
if torch.isnan(param.grad).any():
print(f" NaN detected in {name}")8. 性能分析
8.1 Profiling自动微分
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
with_stack=True,
) as prof:
loss = model(input)
loss.backward()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))8.2 Autograd开销分析
| 操作类型 | Autograd开销占比 |
|---|---|
| 矩阵乘法 | ~5% |
| 激活函数 | ~10% |
| 内存操作 | ~15% |
| 反向图遍历 | ~70% |
8.3 减少Autograd开销
- 使用
torch.no_grad()进行推理 - 梯度检查点减少内存(见相关专题)
- 使用
.detach()切断不必要的梯度追踪 - 避免在计算图中使用Python控制流
9. 总结
9.1 PyTorch Autograd核心要点
- 动态计算图:每次前向传播实时构建
- Function基类:定义自定义微分操作的标准接口
- C++ Engine:高性能图执行引擎
- 梯度累积:支持多个后继节点的梯度合并
- 版本追踪:检测原地操作的潜在问题