PyTorch Autograd实现

PyTorch的torch.autograd是深度学习领域最广泛使用的自动微分引擎之一。本专题深入分析其内部架构、核心组件与执行流程。

1. PyTorch Autograd架构概览

1.1 整体架构

┌─────────────────────────────────────────────────────────┐
│                    Python API Layer                      │
│          torch.autograd.backward() / grad()              │
└─────────────────────────────────────────────────────────┘
                           │
                           ▼
┌─────────────────────────────────────────────────────────┐
│                    Python/C++ Bridge                     │
│              torch/csrc/autograd/python_anomaly_mode.cpp │
└─────────────────────────────────────────────────────────┘
                           │
                           ▼
┌─────────────────────────────────────────────────────────┐
│                    C++ Autograd Core                     │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐   │
│  │  Function   │  │   Engine    │  │   Node     │   │
│  │   (基类)    │  │  (执行器)   │  │  (任务)    │   │
│  └─────────────┘  └─────────────┘  └─────────────┘   │
└─────────────────────────────────────────────────────────┘
                           │
                           ▼
┌─────────────────────────────────────────────────────────┐
│                   CUDA/CPU Backend                      │
│            autograd_variable.cpp / autograd_cuda.cpp     │
└─────────────────────────────────────────────────────────┘

1.2 核心组件

组件位置功能
torch.autogradPython顶层API
VariableC++梯度上下文管理
FunctionPython/C++前向/反向计算
EngineC++图执行引擎
NodeTaskC++单节点任务
GraphTaskC++整图任务

2. Function类与backward机制

2.1 torch.autograd.Function

Function是定义自定义微分操作的核心基类:

class Function:
    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        前向计算:执行实际运算
        返回计算结果
        """
        # 保存用于反向的信息到ctx
        ctx.save_for_backward(*tensors_to_save)
        ctx.some_metadata = some_value
        return output
    
    @staticmethod
    def backward(ctx, *grad_outputs):
        """
        反向计算:计算梯度
        返回每个输入的梯度
        """
        saved_tensors = ctx.saved_tensors
        # 使用保存的信息计算梯度
        grad_input = ... 
        return grad_input

2.2 梯度保存机制

ctx.save_for_backward()

class Sigmoid(Function):
    @staticmethod
    def forward(ctx, input):
        output = torch.sigmoid(input)
        # 保存反向需要的张量
        ctx.save_for_backward(output)
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # Sigmoid导数: σ(x)(1-σ(x))
        grad_input = grad_output * output * (1 - output)
        return grad_input

ctx.mark_dirty() 和 ctx.mark_non_differentiable()

class InplaceReLU(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.mark_dirty(input)  # 输入会被原地修改
        return input.clamp(min=0)
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()
 
class Detach(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.mark_non_differentiable(input)  # 标记不可微分
        return input.detach()

2.3 Python层自动生成backward

对于大多数操作,PyTorch自动生成backward函数:

# 实际流程
class Sigmoid:
    @staticmethod
    def forward(ctx, input):
        output = torch.sigmoid(input)
        ctx.save_for_backward(output)
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return grad_output * output * (1 - output)

3. 计算图执行流程

3.1 backward()的完整调用链

# Python层入口
tensor.backward()


torch.autograd.backward(tensors, grad_tensors, ...)


Variable._execution_engine.run_backward(
    tensors, grad_tensors_, retain_graph, create_graph, inputs
)


C++ Engine.execute()


GraphTask.execute()


NodeTask.execute() (for each gradient function)


AccumulateGrad.update_grad()

3.2 C++ Engine的核心实现

Engine.execute() 的伪代码:

void Engine::execute(const edge_list& roots,
                     const variable_list& grad_outputs,
                     bool retain_graph,
                     bool create_graph,
                     bool input_edge_requires_grad) {
    
    // 1. 创建GraphTask
    std::shared_ptr<GraphTask> graph_task = 
        std::make_shared<GraphTask>(
            retain_graph, create_graph, input_edge_requires_grad
        );
    
    // 2. 初始化根节点
    for (const auto& root : roots) {
        auto& func = root.function;
        // 设置输出梯度
        graph_task->add_next_edge(root);
    }
    
    // 3. 执行图遍历
    graph_task->execute();
}
 
void GraphTask::execute() {
    // 4. 拓扑排序执行
    while (! heap.empty()) {
        // 获取拓扑编号最小的节点
        auto node = heap.top();
        heap.pop();
        
        // 执行节点
        node->run_function();
        
        // 将梯度写入输入节点
        for (const auto& edge : node->next_edges()) {
            if (edge.is_valid()) {
                graph_task->add_next_edge(edge);
            }
        }
    }
}

3.3 梯度累积机制

AccumulateGrad是特殊的Function节点,负责累积叶子节点的梯度:

class AccumulateGrad : public Function {
    variable gradient;
    
    void apply(const variable& grad) {
        if (gradient.defined()) {
            // 多个后继节点贡献的梯度累积
            gradient = gradient + grad;
        } else {
            gradient = grad;
        }
    }
};

为什么需要累积

     forward computation
    ┌──────────────────┐
    │                  │
    ▼                  │
   out ────────┬───────┘
                │
                ▼ backward
    ┌──────────────┬──────────────┐
    ▼              ▼              ▼
   ReLU         Sigmoid        Softmax
    │              │              │
    ▼              ▼              ▼
    x1 ◄───────── x ◄─────────── x
                (same input)

的梯度来自三个路径的贡献,需要累积。

4. PyTorch Autograd优化技术

4.1 拓扑编号优化

topological_nr_(拓扑编号)记录从当前节点到根节点的最长路径长度:

// 拓扑编号的作用:跳过不必要的检查
if (node->topological_nr() < exec_info.needed_batch_nr) {
    // 这个节点的结果不需要用于梯度计算
    continue;
}

4.2 选择性梯度计算

# 只计算特定输入的梯度
loss.backward()
x_grad = x.grad  # 正确
 
# 或者使用torch.autograd.grad()
grads = torch.autograd.grad(
    outputs=loss,
    inputs=[w1, b1, w2, b2],  # 只对这些参数求梯度
    retain_graph=True
)

4.3 梯度累积(用于大batch训练)

# 分多个micro-batch累积梯度
for batch in dataloader:
    loss = model(batch)
    loss.backward()  # 梯度累积到grad中
    # 不调用optimizer.step()
 
optimizer.step()  # 一次性更新
optimizer.zero_grad()

4.4 图执行优化

# 避免不必要的图构建
with torch.no_grad():
    # 此区域内的计算不构建计算图
    predictions = model(inputs)
 
# 或者使用.detach()
predictions = model(inputs).detach()

5. 原地操作处理

5.1 原地操作的危险性

x = torch.tensor([1.0], requires_grad=True)
y = x * 2
x.add_(1)  # 原地修改x
# 警告:原地操作可能破坏计算图

5.2 mark_dirty机制

class MyInplace(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.mark_dirty(input)  # 标记input会被修改
        input.add_(1)
        return input
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # 输入的梯度直接来自输出

5.3 版本号追踪

PyTorch通过版本号检测原地修改:

x = torch.tensor([1.0], requires_grad=True)
print(x._version)  # 0
 
y = x * 2
print(x._version)  # 仍然是0
 
x.add_(1)  # 原地修改
print(x._version)  # 1(版本号增加)
 
# 如果x被用于计算且版本不匹配,会报错

6. 高级功能

6.1 自定义Function完整示例

import torch
from torch.autograd import Function
 
class CustomLinear(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        
        # 需要输入梯度
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        
        # 需要权重梯度
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        
        # 需要偏置梯度
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        
        return grad_input, grad_weight, grad_bias
 
# 使用
linear = CustomLinear.apply
output = linear(x, weight, bias)

6.2 函数模式grad()

# grad()返回梯度,不修改叶子节点
x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
 
def func(x, y):
    return x ** 2 + y ** 2
 
# 计算∂func/∂x at (x, y)
grad_x = torch.autograd.grad(
    outputs=func(x, y),
    inputs=x
)
 
# 计算∂func/∂(x, y)
grad_xy = torch.autograd.grad(
    outputs=func(x, y),
    inputs=[x, y]
)

6.3 高阶导数

x = torch.tensor([2.0], requires_grad=True)
y = x ** 3
 
# 一阶导数: 3x² = 12
grad1 = torch.autograd.grad(y, x, create_graph=True)[0]
 
# 二阶导数: 6x = 12
grad2 = torch.autograd.grad(grad1, x)[0]
 
print(f"一阶导数: {grad1.item()}")  # 12.0
print(f"二阶导数: {grad2.item()}")  # 12.0

6.4 梯度检查(数值校验)

from torch.autograd import gradcheck
 
# 创建验证函数
def my_func(x):
    return x ** 2 + x.sin()
 
# 梯度检查
x = torch.randn(3, 3, requires_grad=True, dtype=torch.double)
print(gradcheck(my_func, x))  # True if pass

7. 常见陷阱与调试

7.1 梯度不计算

# 问题:没有指定requires_grad
x = torch.tensor([1.0])  # requires_grad=False
y = x ** 2
y.backward()  # 错误:x没有梯度
 
# 解决
x = torch.tensor([1.0], requires_grad=True)

7.2 图被提前释放

# 问题:叶子节点被用于后续计算
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
z = y.detach()  # 切断与x的连接
z.backward()  # x没有收到梯度
 
# 解决:使用detach()时注意
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
z = y.detach()  # 正确使用detach
# 如果需要z但也想保留梯度,使用z = y + 0

7.3 retain_graph问题

# 问题:多次反向传播需要retain_graph
loss.backward()  # 第一次
loss.backward()  # 错误:图已被释放
 
# 解决
loss.backward(retain_graph=True)
loss.backward()

7.4 NaN/Inf梯度诊断

# 使用anomaly模式检测
with torch.autograd.set_detect_anomaly(True):
    loss = model(input)
    loss.backward()
 
# 检查梯度
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: max={param.grad.max()}, min={param.grad.min()}")
        if torch.isnan(param.grad).any():
            print(f"  NaN detected in {name}")

8. 性能分析

8.1 Profiling自动微分

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    with_stack=True,
) as prof:
    loss = model(input)
    loss.backward()
 
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

8.2 Autograd开销分析

操作类型Autograd开销占比
矩阵乘法~5%
激活函数~10%
内存操作~15%
反向图遍历~70%

8.3 减少Autograd开销

  1. 使用torch.no_grad() 进行推理
  2. 梯度检查点减少内存(见相关专题)
  3. 使用.detach() 切断不必要的梯度追踪
  4. 避免在计算图中使用Python控制流

9. 总结

9.1 PyTorch Autograd核心要点

  1. 动态计算图:每次前向传播实时构建
  2. Function基类:定义自定义微分操作的标准接口
  3. C++ Engine:高性能图执行引擎
  4. 梯度累积:支持多个后继节点的梯度合并
  5. 版本追踪:检测原地操作的潜在问题

9.2 相关专题

参考资料