Surrogate Gradient Learning: Solving the Non-Differentiability Problem in SNNs

Overview

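The spike function of a spiking neural network (SNN) is a hard threshold, $S(V) = \Theta(V - V_{th})$. It outputs 1 when the membrane potential $V$ reaches the threshold $V_{th}$ and 0 otherwise. This function is not differentiable, so conventional backpropagation cannot be applied directly to SNN training. The **surrogate gradient** method solves this by approximating the gradient of the spike function with the derivative of a smooth, differentiable function.[1]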


1. Problem Formulation

1.1 The Gradient Dilemma of the Spike Function

```text
True spike function (non-differentiable)   Surrogate gradient (differentiable)

S(V)                                       σ'(V)
  1 ┤          ┌─────────                      ┤         ╭─╮
    │          │                               │        ╱   ╲
    │          │                               │      ╱       ╲
  0 ┼──────────┘─────────→ V                 0 ┼─────╯────┼────╰──────→ V
               V_th                                      V_th

  dS/dV = 0 almost everywhere                smooth peak centered at V_th
```

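The mathematical problem: the derivative of the step function is a Dirac delta,

$$\frac{\partial S}{\partial V} = \delta(V - V_{th}) = \begin{cases} 0, & V \neq V_{th} \\ \infty, & V = V_{th} \end{cases}$$

so it carries no usable gradient information.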

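1.2 Loss Function and Gradient

Training an SNN means minimizing a loss function over the network parameters $\theta$:

$$\theta^* = \arg\min_{\theta} \mathcal{L}(\theta)$$

Computing the gradient requires the chain rule through the spike output $S$ and the membrane potential $V$:

$$\frac{\partial \mathcal{L}}{\partial \theta} = \frac{\partial \mathcal{L}}{\partial S}\,\frac{\partial S}{\partial V}\,\frac{\partial V}{\partial \theta}$$

The problem is the middle factor. $\partial S / \partial V$ is zero almost everywhere and undefined at the threshold, so the product provides no useful gradient for $\theta$.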
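The problem appears immediately in PyTorch. Here is a minimal sketch with arbitrary values: the comparison that implements the hard threshold returns a tensor that autograd does not track, so no gradient can reach $V$.

```python
import torch

V_th = 1.0
V = torch.tensor([0.2, 0.9, 1.3], requires_grad=True)

# Hard threshold S(V) = Θ(V - V_th). Comparison ops are not differentiable,
# so the result is cut off from the autograd graph.
S = (V >= V_th).float()
print(S)                # tensor([0., 0., 1.])
print(S.requires_grad)  # False

try:
    S.sum().backward()
except RuntimeError as err:
    # S has no grad_fn, so nothing can flow back to V
    print("backward failed:", type(err).__name__)  # backward failed: RuntimeError
```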


2. The Surrogate Gradient Method

2.1 Core Idea

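Keep the hard threshold in the forward pass. In the backward pass, replace the gradient of the spike function $S$ with the derivative of a smooth, differentiable function $\sigma_\alpha$ that approximates the step:

$$\frac{\partial S}{\partial V} \approx \sigma_\alpha'(V - V_{th})$$

Here $\alpha$ is a parameter that controls the shape of the surrogate gradient.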
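The core idea (a hard step in the forward pass, the derivative of a smooth step in the backward pass) can be written compactly in PyTorch with the so-called detach trick. The sketch below assumes an arctan-shaped smooth step, but any smooth step would work the same way.

```python
import math
import torch

def spike_with_surrogate(V, V_th=1.0, alpha=2.0):
    x = V - V_th
    # Smooth step sigma_alpha(x) = arctan(alpha * x) / pi + 1/2. Its derivative,
    # alpha / (pi * (1 + alpha^2 x^2)), is the arctan surrogate gradient.
    smooth = torch.atan(alpha * x) / math.pi + 0.5
    hard = (x >= 0).float()
    # Forward value: smooth + (hard - smooth) = hard, i.e. the exact step.
    # Backward: the detached term is a constant, so d/dV = d(smooth)/dV.
    return smooth + (hard - smooth).detach()

V = torch.tensor([0.0, 1.0, 1.5], requires_grad=True)
S = spike_with_surrogate(V)
S.sum().backward()
print(S.detach())  # tensor([0., 1., 1.])
print(V.grad)      # alpha / (pi * (1 + alpha^2 * (V - V_th)^2)) at each V
```

A custom `torch.autograd.Function` gives the same behavior, as Section 3.1 shows. The detach form is convenient when the smooth step already has a closed form.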

Surrogate gradient schematic: smooth step approximations $\sigma(V)$, such as arctan and sigmoid, rise from 0 to 1 around $V_{th}$, and the backward pass uses their derivatives. Key insight: the surrogate supplies gradient near $V \approx V_{th}$ and decays toward 0 away from the threshold. Only the piecewise-linear surrogate is exactly 0 there.

2.2 Common Surrogate Functions

Arctan surrogate gradient

One of the most widely used surrogate gradient functions:

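```python
import torch

def surrogate_atan(V, V_th, alpha):
    """Arctan surrogate gradient: derivative of arctan(alpha * x) / pi + 1/2, x = V - V_th."""
    x = V - V_th
    return (alpha / torch.pi) / (alpha**2 * x**2 + 1)
```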
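The arctan formula is the derivative of the smooth step $\sigma_\alpha(x) = \frac{1}{\pi}\arctan(\alpha x) + \frac{1}{2}$. The autograd check below confirms this; the grid and $\alpha = 2$ are arbitrary choices.

```python
import math
import torch

def surrogate_atan(V, V_th, alpha):
    """Arctan surrogate gradient (same formula as above)."""
    x = V - V_th
    return (alpha / math.pi) / (alpha**2 * x**2 + 1)

def smooth_step_atan(V, V_th, alpha):
    """sigma_alpha(V - V_th) = arctan(alpha * (V - V_th)) / pi + 1/2, rises from 0 to 1."""
    return torch.atan(alpha * (V - V_th)) / math.pi + 0.5

V = torch.linspace(-2.0, 4.0, steps=13, requires_grad=True)
smooth_step_atan(V, 1.0, 2.0).sum().backward()  # V.grad = sigma_alpha'(V - V_th)

closed_form = surrogate_atan(V.detach(), 1.0, 2.0)
print(torch.allclose(V.grad, closed_form, atol=1e-5))  # True
```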

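Visualization: $\alpha$ sets the shape of the peak centered at $V_{th}$, whose height is $\alpha/\pi$. A larger $\alpha$ (e.g. $\alpha = 2$) gives a taller, narrower peak. A smaller $\alpha$ (e.g. $\alpha = 0.5$) gives a lower, wider one. $\alpha = 1$ lies in between.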
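The formula also quantifies the effect of $\alpha$ directly. The peak height is $\alpha/\pi$. The curve drops to half its peak where $\alpha^2 x^2 = 1$, so the full width at half maximum is $2/\alpha$. This small sketch prints both quantities for the three values of $\alpha$ considered here:

```python
import math

def atan_peak_and_fwhm(alpha):
    """Peak height and full width at half maximum (FWHM) of the arctan surrogate.

    g(x) = (alpha / pi) / (1 + alpha^2 x^2) peaks at x = 0 with height alpha / pi
    and falls to half of that at x = +-1 / alpha.
    """
    return alpha / math.pi, 2.0 / alpha

for alpha in (0.5, 1.0, 2.0):
    peak, fwhm = atan_peak_and_fwhm(alpha)
    print(f"alpha={alpha}: peak={peak:.3f}, FWHM={fwhm:.1f}")
# alpha=0.5: peak=0.159, FWHM=4.0
# alpha=1.0: peak=0.318, FWHM=2.0
# alpha=2.0: peak=0.637, FWHM=1.0
```

The area under the curve is 1 for every $\alpha$, so $\alpha$ only redistributes the gradient. It concentrates gradient near the threshold rather than changing its total amount.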

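Sigmoid surrogate gradient

```python
import torch

def surrogate_sigmoid(V, V_th, beta=1.0):
    """Sigmoid surrogate gradient: derivative of sigmoid((V - V_th) / beta)."""
    x = V - V_th
    sig = torch.sigmoid(x / beta)
    # The 1/beta factor makes this the exact derivative (unit area, peak 1/(4*beta))
    return sig * (1 - sig) / beta
```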
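With the $1/\beta$ factor, this surrogate is exactly the derivative of $\mathrm{sigmoid}((V - V_{th})/\beta)$. Its area is therefore 1 for every $\beta$, and its peak is $1/(4\beta)$. Shrinking $\beta$ sharpens it toward the Dirac delta of Section 1. A numerical check (the integration grid is an arbitrary choice):

```python
import torch

def surrogate_sigmoid(V, V_th, beta=1.0):
    """Derivative of sigmoid((V - V_th) / beta) with respect to V."""
    sig = torch.sigmoid((V - V_th) / beta)
    return sig * (1 - sig) / beta

# Fine grid around V_th = 1.0, wide enough that the tails are negligible
V = torch.linspace(-40.0, 42.0, steps=400001, dtype=torch.float64)

results = {}
for beta in (2.0, 1.0, 0.25):
    g = surrogate_sigmoid(V, 1.0, beta)
    area = torch.trapezoid(g, V).item()  # numerical integral over V
    results[beta] = (g.max().item(), area)
    print(f"beta={beta}: peak={results[beta][0]:.4f}, area={area:.4f}")
```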

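Fast Sigmoid surrogate gradient

```python
import torch

def surrogate_fast_sigmoid(V, V_th, alpha=1.0):
    """Fast sigmoid surrogate gradient: derivative of x / (alpha + |x|), x = V - V_th.

    Needs no exp(), so it is cheap to compute; the peak value is 1/alpha.
    """
    x = torch.abs(V - V_th)
    return alpha / (alpha + x) ** 2
```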
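The squared form $\alpha/(\alpha + \lvert x\rvert)^2$ is the exact derivative of the fast sigmoid $x/(\alpha + \lvert x\rvert)$ with $x = V - V_{th}$. It is also the formula used by the `'fast_sigmoid'` branch of `SurrogateGradientFunction.backward` in Section 3.1. The autograd check below uses test points that avoid $x = 0$, where $\lvert x\rvert$ has a kink; $\alpha = 0.5$ is arbitrary.

```python
import torch

def surrogate_fast_sigmoid(V, V_th, alpha=1.0):
    x = torch.abs(V - V_th)
    return alpha / (alpha + x) ** 2

def fast_sigmoid(V, V_th, alpha=1.0):
    """Smooth step x / (alpha + |x|) with x = V - V_th; ranges over (-1, 1)."""
    x = V - V_th
    return x / (alpha + torch.abs(x))

V = torch.tensor([-1.0, 0.5, 0.9, 1.1, 3.0], requires_grad=True)
fast_sigmoid(V, 1.0, alpha=0.5).sum().backward()
print(torch.allclose(V.grad, surrogate_fast_sigmoid(V.detach(), 1.0, alpha=0.5)))  # True
```

Like arctan, its tails decay as $1/x^2$, so neurons far from threshold still receive a small gradient.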

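Piecewise Linear surrogate gradient

```python
import torch

def surrogate_pwl(V, V_th, alpha=2.0):
    """Piecewise-linear surrogate gradient.

    Derivative of the piecewise-linear step clamp(alpha * x, -1, 1), x = V - V_th:
    a rectangular window of height alpha on |x| <= 1/alpha, exactly 0 elsewhere.
    """
    x = torch.abs(V - V_th)
    return alpha * (x <= 1 / alpha).float()
```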
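The piecewise-linear surrogate is the derivative of the piecewise-linear step $\mathrm{clamp}(\alpha x, -1, 1)$, which PyTorch provides as `F.hardtanh`. That is why it is a rectangular window of height $\alpha$ and half-width $1/\alpha$. The autograd check below avoids the two kink points $\lvert x\rvert = 1/\alpha$:

```python
import torch
import torch.nn.functional as F

def surrogate_pwl(V, V_th, alpha=2.0):
    x = torch.abs(V - V_th)
    return alpha * (x <= 1 / alpha).float()

# Piecewise-linear step: hardtanh(alpha * (V - V_th)) = clamp(alpha * x, -1, 1)
V = torch.tensor([-1.0, 0.7, 1.0, 1.3, 2.5], requires_grad=True)
F.hardtanh(2.0 * (V - 1.0)).sum().backward()
print(V.grad)                               # tensor([0., 2., 2., 2., 0.])
print(surrogate_pwl(V.detach(), 1.0, 2.0))  # tensor([0., 2., 2., 2., 0.])
```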

2.3 Comparison of Surrogate Functions

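| Surrogate | Gradient formula ($x = V - V_{th}$) | Compute cost | Gradient strength (peak) | Gradient width |
|---|---|---|---|---|
| Arctan | $\frac{\alpha/\pi}{1 + \alpha^2 x^2}$ | | $\alpha/\pi$ | Tunable via $\alpha$ |
| Sigmoid | $\frac{1}{\beta}\,s(1 - s)$, $s = \mathrm{sigmoid}(x/\beta)$ | | $1/(4\beta)$, i.e. 0.25 at $\beta = 1$ | Narrower |
| Fast Sigmoid | $\frac{\alpha}{(\alpha + \lvert x\rvert)^2}$ | Lowest | $1/\alpha$ | |
| Piecewise Linear | $\alpha \cdot \mathbb{1}[\lvert x\rvert \le 1/\alpha]$ | Lowest | Constant ($\alpha$) | Fixed window $\lvert x\rvert \le 1/\alpha$ |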
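To make the comparison concrete, the sketch below evaluates each surrogate at a few distances from the threshold. It assumes $\alpha = 2$ (the layer default in Section 3) and $\beta = 1$. The fast sigmoid uses the squared form $\alpha/(\alpha + \lvert x\rvert)^2$ from `SurrogateGradientFunction`.

```python
import math
import torch

x = torch.tensor([0.0, 0.25, 0.5, 1.0, 2.0])  # distance V - V_th
alpha = 2.0

surrogates = {
    "arctan":       (alpha / math.pi) / (1 + alpha**2 * x**2),
    "sigmoid":      torch.sigmoid(x) * (1 - torch.sigmoid(x)),  # beta = 1
    "fast_sigmoid": alpha / (alpha + x.abs()) ** 2,
    "pwl":          alpha * (x.abs() <= 1 / alpha).float(),
}
for name, g in surrogates.items():
    print(f"{name:>12}: " + " ".join(f"{v:.3f}" for v in g.tolist()))
```

The first column gives the peak values from the table. The last columns show how quickly each surrogate cuts off gradient for neurons far from threshold.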

3. Backpropagation with Surrogate Gradients

3.1 Forward and Backward Passes

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SurrogateGradientFunction(torch.autograd.Function):
    """
    Custom autograd function: hard threshold forward, surrogate gradient backward
    """
    @staticmethod
    def forward(ctx, V, V_th, alpha, spike_fn='arctan'):
        # Save context for the backward pass
        ctx.save_for_backward(V)
        ctx.V_th = V_th
        ctx.alpha = alpha
        ctx.spike_fn = spike_fn

        # Forward: hard threshold (Heaviside step)
        return (V >= V_th).float()

    @staticmethod
    def backward(ctx, grad_output):
        V, = ctx.saved_tensors
        V_th = ctx.V_th
        alpha = ctx.alpha

        # Backward: replace dS/dV with the chosen surrogate
        x = V - V_th
        if ctx.spike_fn == 'arctan':
            grad = (alpha / torch.pi) / (alpha**2 * x**2 + 1)
        elif ctx.spike_fn == 'sigmoid':
            sig = torch.sigmoid(x / alpha)
            grad = sig * (1 - sig) / alpha
        elif ctx.spike_fn == 'fast_sigmoid':
            grad = alpha / (alpha + torch.abs(x))**2
        else:
            raise ValueError(f"Unknown spike function: {ctx.spike_fn}")

        # Gradient clipping to [0, 1] for stability
        grad = torch.clamp(grad, min=0, max=1)

        # One return value per forward input; V_th, alpha and spike_fn get no gradient
        return grad * grad_output, None, None, None


class SurrogateGradientLinear(nn.Module):
    """
    Linear layer followed by a spiking nonlinearity trained with surrogate gradients
    """
    def __init__(self, in_features, out_features, bias=True,
                 surrogate='arctan', alpha=2.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        self.V_th = 1.0  # fixed threshold
        self.surrogate = surrogate
        self.alpha = alpha

    def forward(self, x):
        # Linear transform gives the input current
        u = self.linear(x)

        # A full LIF neuron would add membrane decay:
        #   V[t] = beta * V[t-1] + u[t]
        # This layer is stateless, so it simply uses V = u
        V = u

        # Spike generation (surrogate gradient in the backward pass)
        spike = SurrogateGradientFunction.apply(
            V, self.V_th, self.alpha, self.surrogate
        )

        return spike


class SNNLayer(nn.Module):
    """
    Complete current-based LIF layer for one time step; unroll it over time
    """
    def __init__(self, in_features, out_features,
                 tau_mem=10.0, tau_syn=5.0,
                 V_th=1.0, V_reset=0.0,
                 surrogate='arctan', alpha=2.0,
                 dt=1.0):
        super().__init__()
        self.tau_mem = tau_mem
        self.tau_syn = tau_syn
        self.V_th = V_th
        self.V_reset = V_reset
        self.surrogate = surrogate
        self.alpha = alpha
        self.dt = dt

        # Decay factors as plain floats, so they work on any device
        self.beta_mem = math.exp(-dt / tau_mem)  # not used by the Euler update below
        self.beta_syn = math.exp(-dt / tau_syn)

        # Weights and bias
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def surrogate_gradient(self, V):
        """Surrogate dS/dV (same formulas as SurrogateGradientFunction.backward, unclipped)"""
        x = V - self.V_th
        if self.surrogate == 'arctan':
            return (self.alpha / torch.pi) / (self.alpha**2 * x**2 + 1)
        elif self.surrogate == 'sigmoid':
            sig = torch.sigmoid(x / self.alpha)
            return sig * (1 - sig) / self.alpha
        elif self.surrogate == 'fast_sigmoid':
            return self.alpha / (self.alpha + torch.abs(x))**2
        return torch.zeros_like(V)

    def forward(self, x_t, V_mem, I_syn):
        """
        Single-step forward pass
        x_t:   input at the current time step (batch, in_features)
        V_mem: membrane potential state (batch, out_features)
        I_syn: synaptic current state (batch, out_features)
        """
        # Synaptic current update (exponential filter of the weighted input)
        I_syn_new = (self.beta_syn * I_syn
                     + (1 - self.beta_syn) * (x_t @ self.weight.t())
                     + self.bias)

        # Membrane potential update (forward Euler, leak toward V_reset)
        dV = (-(V_mem - self.V_reset) + I_syn_new) / self.tau_mem
        V_mem_new = V_mem + dV * self.dt

        # Spike generation (surrogate gradient)
        spike = SurrogateGradientFunction.apply(
            V_mem_new, self.V_th, self.alpha, self.surrogate
        )

        # Hard reset of neurons that spiked; the reset is detached,
        # so no gradient flows through it
        s = spike.detach()
        V_mem_new = s * self.V_reset + (1.0 - s) * V_mem_new

        return V_mem_new, spike, I_syn_new
```
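As a usage sketch, the self-contained loop below shows that gradients reach the weights even though the forward pass uses a hard step. It uses a hypothetical arctan-only stand-in for `SurrogateGradientFunction` and a simplified membrane update without the synaptic current. It is an illustration of the mechanism, not a test of `SNNLayer` itself.

```python
import math
import torch

class ArctanSpike(torch.autograd.Function):
    """Hypothetical arctan-only stand-in for SurrogateGradientFunction."""
    @staticmethod
    def forward(ctx, V, V_th, alpha):
        ctx.save_for_backward(V)
        ctx.V_th, ctx.alpha = V_th, alpha
        return (V >= V_th).float()

    @staticmethod
    def backward(ctx, grad_output):
        (V,) = ctx.saved_tensors
        x = V - ctx.V_th
        grad = (ctx.alpha / math.pi) / (ctx.alpha**2 * x**2 + 1)
        return grad * grad_output, None, None

torch.manual_seed(0)
W = (0.5 * torch.randn(4, 3)).requires_grad_()  # 3 inputs -> 4 neurons
x = torch.rand(8, 3)                             # constant input, batch of 8
beta = math.exp(-1.0 / 10.0)                     # decay for dt = 1, tau_mem = 10

V = torch.zeros(8, 4)
total_spikes = torch.zeros(())
for t in range(5):
    V = beta * V + x @ W.t()         # simplified LIF update (no synaptic current)
    s = ArctanSpike.apply(V, 1.0, 2.0)
    V = V * (1.0 - s.detach())       # hard reset to 0, no gradient through the reset
    total_spikes = total_spikes + s.sum()

total_spikes.backward()
print(W.grad)  # every entry is nonzero although the forward pass used a hard step
```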

3.2 Backpropagation Through Time (BPTT)

SNNs are usually unrolled over time steps and then trained with BPTT:

```python
import torch
import torch.nn as nn

# Uses SNNLayer from Section 3.1.


class SNNBPTT(nn.Module):
    """
    Complete SNN trained with BPTT
    """
    def __init__(self, input_size, hidden_size, output_size,
                 num_steps=10, tau_mem=10.0):
        super().__init__()
        self.num_steps = num_steps
        self.hidden_size = hidden_size
        self.output_size = output_size  # needed to size the second layer's state

        self.fc1 = SNNLayer(input_size, hidden_size, tau_mem=tau_mem)
        self.fc2 = SNNLayer(hidden_size, output_size, tau_mem=tau_mem)

        # The readout layer is a conventional (non-spiking) linear layer
        self.fc_out = nn.Linear(output_size, output_size)

    def forward(self, x):
        """
        x: (batch, input_size)
        """
        batch_size = x.shape[0]

        # Initialize the neuron states
        V1 = torch.zeros(batch_size, self.hidden_size, device=x.device)
        I1 = torch.zeros(batch_size, self.hidden_size, device=x.device)
        V2 = torch.zeros(batch_size, self.output_size, device=x.device)
        I2 = torch.zeros(batch_size, self.output_size, device=x.device)

        # Collect the output spikes of every time step
        spike_outputs = []

        for t in range(self.num_steps):
            # Present the same static input at every step; for time-varying
            # input of shape (batch, num_steps, input_size), use x[:, t]
            x_t = x

            # Layer 1
            V1, spike1, I1 = self.fc1(x_t, V1, I1)

            # Layer 2
            V2, spike2, I2 = self.fc2(spike1, V2, I2)

            spike_outputs.append(spike2)

        # Temporal pooling: mean firing rate of the last layer over all steps
        spike_rates = torch.stack(spike_outputs, dim=0).mean(dim=0)

        # Classification output (logits)
        output = self.fc_out(spike_rates)

        return output


class SNNClassifier(nn.Module):
    """Complete SNN classifier"""
    def __init__(self, input_size, hidden_size, output_size,
                 num_steps=10):
        super().__init__()
        self.snn = SNNBPTT(input_size, hidden_size, output_size, num_steps)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, labels=None):
        logits = self.snn(x)
        if labels is not None:
            loss = self.criterion(logits, labels)
            return loss, logits
        return logits

    def training_step(self, batch):
        """Return the loss; calling loss.backward() on it performs BPTT."""
        x, labels = batch
        loss, logits = self(x, labels)
        return loss
```
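In the BPTT model above, a loss computed at the end of the simulation reaches early time steps only through the membrane recursion. The single-neuron sketch below makes this concrete. It uses a hypothetical `ShiftedArctanSpike`, no synaptic current, and no reset, and the inputs are small enough that the neuron never spikes. Under these assumptions, the credit assigned to the input at step $t$ is scaled by $\beta^{T-1-t}$.

```python
import math
import torch

class ShiftedArctanSpike(torch.autograd.Function):
    """Hypothetical spike function on x = V - V_th: step forward, arctan surrogate backward."""
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (ctx.alpha / math.pi) / (1 + (ctx.alpha * x) ** 2), None

T, beta, V_th = 6, 0.9, 1.0
x = torch.full((T,), 0.1, requires_grad=True)  # small inputs: V stays below V_th
V = torch.zeros(())
for t in range(T):
    V = beta * V + x[t]                         # unrolled membrane recursion
s_T = ShiftedArctanSpike.apply(V - V_th, 2.0)   # spike decision at the last step
s_T.backward()
print(x.grad)  # grows toward t = T - 1; consecutive ratios equal beta
```

With $\beta < 1$, the credit for early inputs shrinks geometrically. This is one reason long simulations (large `num_steps`) can make gradients for early time steps small.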

4. Theoretical Analysis of Surrogate Gradients

4.1 Convergence Guarantees

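Theorem (convergence of surrogate gradients, informal): under suitable conditions, SNN training with surrogate gradients converges to a critical point. A surrogate gradient is not the true gradient of $\mathcal{L}$, so such a statement also needs the surrogate to stay aligned with the true gradient. Read it as intuition rather than as a general guarantee.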

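Conditions:

  1. The loss function $\mathcal{L}$ is smooth, with an $L$-Lipschitz gradient.
  2. The surrogate gradient function $\sigma'_\alpha$ is bounded, e.g. $0 \le \sigma'_\alpha(x) \le M$ for all $x$.
  3. The learning rate satisfies $\eta \le 1/L$, where $L$ is the Lipschitz constant.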
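The standard descent lemma gives some intuition for where a step-size condition like this comes from. It is a textbook sketch, not a proof of the theorem above. Let $\nabla\mathcal{L}$ be $L$-Lipschitz, and let $g_k$ denote the (surrogate) gradient used in the update $\theta_{k+1} = \theta_k - \eta\, g_k$ with learning rate $\eta$. Then

$$\mathcal{L}(\theta_{k+1}) \le \mathcal{L}(\theta_k) - \eta\,\langle \nabla\mathcal{L}(\theta_k), g_k\rangle + \frac{L\eta^2}{2}\,\lVert g_k\rVert^2 .$$

Suppose $g_k = \nabla\mathcal{L}(\theta_k)$ and $\eta \le 1/L$. Then the right-hand side is at most $\mathcal{L}(\theta_k) - \frac{\eta}{2}\lVert\nabla\mathcal{L}(\theta_k)\rVert^2$, so the loss decreases until the gradient vanishes. With a surrogate, the bound only guarantees a decrease when $\langle \nabla\mathcal{L}(\theta_k), g_k\rangle > \frac{L\eta}{2}\lVert g_k\rVert^2$, i.e. when the surrogate is sufficiently aligned with the true gradient.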

4.2 The Gradient Scale Problem

The magnitude of the surrogate gradient affects training stability:

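| Surrogate | Gradient scale | Suggested learning-rate adjustment |
|---|---|---|
| Arctan | $\alpha/\pi$ at the peak | Normal |
| Sigmoid | $1/(4\beta)$ at the peak | Decrease |
| Fast Sigmoid | Varies with $\alpha$ (peak $1/\alpha$) | Decrease |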

4.3 Gradient Flow Analysis

```python
def analyze_gradient_flow(model):
    """Mean absolute gradient of every parameter, to see how the surrogate affects gradient flow.

    Call after loss.backward(). Keys are the full parameter names
    (e.g. 'snn.fc1.weight'), so this works for any model.
    """
    grad_stats = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_stats[name] = param.grad.abs().mean().item()
    return grad_stats
```
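Parameter gradients mix the effect of the weights with that of the surrogate. To look at the surrogate factor on its own, you can evaluate it on recorded membrane potentials. The helper below is hypothetical: it assumes you have collected a tensor `V` of membrane potentials from a forward pass, and it uses the arctan surrogate.

```python
import math
import torch

def surrogate_activity(V, V_th=1.0, alpha=2.0, frac=0.1):
    """Hypothetical diagnostic for 'dead' surrogate gradients.

    Returns the mean arctan surrogate value over V and the fraction of entries
    where it exceeds `frac` times its peak value alpha / pi.
    """
    x = V - V_th
    g = (alpha / math.pi) / (1 + (alpha * x) ** 2)
    active = (g > frac * alpha / math.pi).float().mean()
    return g.mean().item(), active.item()

V = torch.tensor([[0.9, 1.1, 5.0, -3.0]])  # two near threshold, two far away
mean_g, active = surrogate_activity(V)
print(f"mean surrogate = {mean_g:.3f}, active fraction = {active:.2f}")
```

A low active fraction means most neurons sit far from threshold and pass almost no gradient. Initialization and threshold choice, discussed next, are the usual levers to change that.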

5. Training Tips and Best Practices

5.1 Initialization

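```python
import torch
import torch.nn as nn

def initialize_snn_weights(layer):
    """SNN weight initialization (layer.weight has shape (out_features, in_features))"""
    nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
    # Scale the weights up so more membrane potentials start near or above
    # threshold, which encourages early spiking (1.5 is a heuristic factor)
    with torch.no_grad():
        layer.weight.mul_(1.5)
```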
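Here is a quick sanity check of what the 1.5x scale does. It is a sketch on random inputs, and the layer sizes, the zero bias, and the use of a plain `nn.Linear` are assumptions. Scaling the weights scales every pre-activation by 1.5, so at least as many units start above threshold.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(784, 256, bias=False)  # no bias, so the effect is a pure scaling
nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
x = torch.rand(64, 784)                  # stand-in for flattened MNIST pixels in [0, 1]
V_th = 1.0

with torch.no_grad():
    rate_before = (layer(x) >= V_th).float().mean().item()
    layer.weight.mul_(1.5)               # the heuristic 1.5x scale
    rate_after = (layer(x) >= V_th).float().mean().item()

print(f"fraction above threshold: {rate_before:.3f} -> {rate_after:.3f}")
```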

5.2 Threshold Adjustment

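```python
import torch
import torch.nn as nn

class AdaptiveThreshold(nn.Module):
    """Learnable (adaptive) threshold module"""
    def __init__(self, initial_th=1.0, min_th=0.5, max_th=2.0):
        super().__init__()
        self.threshold = nn.Parameter(torch.tensor(initial_th))
        self.min_th = min_th
        self.max_th = max_th

    def forward(self):
        # Clip the threshold to [min_th, max_th]; outside that range,
        # clamp passes no gradient back to self.threshold
        return torch.clamp(self.threshold, self.min_th, self.max_th)
```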
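One caveat applies when combining this module with `SurrogateGradientFunction` from Section 3.1. Its `backward` returns `None` for `V_th`, so a learnable threshold passed there receives no gradient. One way around this is to subtract the threshold before the spike function, so autograd differentiates through the subtraction. This is an assumption on my part, not the article's method, and the sketch uses a hypothetical `ShiftedArctanSpike`.

```python
import math
import torch
import torch.nn as nn

class ShiftedArctanSpike(torch.autograd.Function):
    """Hypothetical spike function on x = V - threshold (arctan surrogate backward)."""
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (ctx.alpha / math.pi) / (1 + (ctx.alpha * x) ** 2), None

threshold = nn.Parameter(torch.tensor(1.0))     # plays the role of AdaptiveThreshold.threshold
V = torch.tensor([0.5, 0.95, 1.2])
th = torch.clamp(threshold, 0.5, 2.0)           # same clipping as AdaptiveThreshold.forward
spikes = ShiftedArctanSpike.apply(V - th, 2.0)  # subtract first so autograd reaches th
spikes.sum().backward()
print(threshold.grad)  # negative: raising the threshold would lower the spike count
```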

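5.3 Alternative: Straight-Through Estimator

The simplest surrogate gradient method uses the hard threshold in the forward pass and passes the incoming gradient through unchanged (an identity surrogate):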

```python
import torch

class STEFunction(torch.autograd.Function):
    """Straight-Through Estimator"""
    @staticmethod
    def forward(ctx, V, V_th):
        return (V >= V_th).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None  # pass the gradient straight through; none for V_th
```
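The key difference from the surrogates in Section 2 is that STE ignores how far $V$ is from the threshold. In the sketch below, the class is repeated so it runs on its own. A neuron far below threshold gets the same gradient as one right at it, whereas the arctan surrogate would give it a much smaller one.

```python
import torch

class STEFunction(torch.autograd.Function):
    """Straight-Through Estimator (same as above)."""
    @staticmethod
    def forward(ctx, V, V_th):
        return (V >= V_th).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

V = torch.tensor([-5.0, 0.9, 1.1, 10.0], requires_grad=True)
STEFunction.apply(V, 1.0).sum().backward()
print(V.grad)  # tensor([1., 1., 1., 1.])
```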

6. Complete Training Example

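```python
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Uses SurrogateGradientFunction from Section 3.1.


class LIFNeuron(nn.Module):
    """Minimal LIF neuron (an assumed implementation).

    The interface forward(V, u) -> (V_new, spike) follows how SimpleSNN calls it.
    """
    def __init__(self, tau_mem=10.0, V_th=1.0, alpha=2.0, surrogate='arctan'):
        super().__init__()
        self.beta = math.exp(-1.0 / tau_mem)  # membrane decay per step (dt = 1)
        self.V_th = V_th
        self.alpha = alpha
        self.surrogate = surrogate

    def forward(self, V, u):
        V = self.beta * V + u  # leaky integration of the input current
        spike = SurrogateGradientFunction.apply(V, self.V_th, self.alpha, self.surrogate)
        V = V * (1.0 - spike.detach())  # hard reset to 0; no gradient through the reset
        return V, spike


def train_snn_mnist():
    """Complete SNN training pipeline for MNIST classification"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Model definition
    class SimpleSNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(784, 256)
            self.lif1 = LIFNeuron(tau_mem=10.0, V_th=1.0)
            self.fc2 = nn.Linear(256, 10)
            self.lif2 = LIFNeuron(tau_mem=10.0, V_th=1.0)

        def forward(self, x, num_steps=10):
            batch_size = x.shape[0]

            # Initialize the states
            V1 = torch.zeros(batch_size, 256, device=x.device)
            V2 = torch.zeros(batch_size, 10, device=x.device)

            spike_counts = torch.zeros(batch_size, 10, device=x.device)

            for _ in range(num_steps):
                # Layer 1 (the static image is presented at every step)
                u1 = self.fc1(x)
                V1, spike1 = self.lif1(V1, u1)

                # Layer 2
                u2 = self.fc2(spike1)
                V2, spike2 = self.lif2(V2, u2)

                spike_counts = spike_counts + spike2

            # Firing rates are used directly as class scores (logits)
            rates = spike_counts / num_steps
            return rates

    # Training loop
    model = SimpleSNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=128, shuffle=True
    )

    for epoch in range(10):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)  # flatten 28x28 images to 784

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()  # BPTT through all num_steps
            optimizer.step()

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

        print(f'Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, '
              f'Acc={100*correct/total:.2f}%')

    return model


if __name__ == '__main__':
    train_snn_mnist()
```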
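The training function above reports only training accuracy. A hypothetical evaluation helper for a held-out set might look like this. It assumes the same flattening as the training loop and that the model's outputs can be ranked with `argmax`.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def evaluate(model, loader, device):
    """Hypothetical accuracy helper (in percent) for a held-out set.

    Flattens each batch exactly like the training loop does.
    """
    model.eval()
    correct, total = 0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        output = model(data.view(data.size(0), -1))
        correct += output.argmax(dim=1).eq(target).sum().item()
        total += target.size(0)
    return 100.0 * correct / total
```

For MNIST, build the test loader like the training one but with `datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())`. You can take `device` from `next(model.parameters()).device`.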

7. Summary

Surrogate gradient learning is a core technique for training SNNs:

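| Key point | Description |
|---|---|
| Core problem | The spike function is non-differentiable, so backpropagation cannot be applied directly |
| Solution | Approximate its gradient with the derivative of a smooth function |
| Common methods | Arctan, Sigmoid, Fast Sigmoid |
| Theoretical guarantee | Converges under suitable conditions |
| Practical tips | Initialization, threshold adjustment, learning-rate tuning |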

References


  1. Neftci, E. O., Mostafa, H., & Zenke, F. (2019). Surrogate Gradient Learning in Spiking Neural Networks. IEEE Signal Processing Magazine, 36(6), 51-63.