代理梯度学习:解决SNN不可微问题
概述
脉冲神经网络的脉冲发放函数是不可微的:

$$S(V) = \Theta(V - V_{th}) = \begin{cases} 1, & V \ge V_{th} \\ 0, & V < V_{th} \end{cases}$$
这使得传统的反向传播算法无法直接应用于SNN训练。**代理梯度(Surrogate Gradient)**方法通过用一个光滑可微的函数近似脉冲函数的梯度来解决这一问题。[1]
1. 问题形式化
1.1 脉冲函数的梯度困境
真实脉冲函数 (不可微) 代理梯度 (可微)
S(V) S'(V)
│ 1 ┬───── ↑
│ │ │ ╭──
│ │ │ ╱
│───┼────────────→ V │ ╱
│0 │ V_th │╱
└───┼─────────────→ V ─┴──────────→ V
V_th
导数 = 0 (几乎处处) 导数 = 平滑峰值
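数学问题:

$$\frac{\partial S}{\partial V} = \delta(V - V_{th})$$

即导数在 $V \neq V_{th}$ 处恒为 0,在 $V = V_{th}$ 处发散,无法为反向传播提供有用的梯度信号。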
1.2 损失函数与梯度
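训练SNN的目标是最小化损失函数:

$$\min_{\theta} \mathcal{L}(\theta)$$

其中 $\theta$ 是网络参数。梯度计算需要链式法则:

$$\frac{\partial \mathcal{L}}{\partial \theta} = \frac{\partial \mathcal{L}}{\partial S} \cdot \frac{\partial S}{\partial V} \cdot \frac{\partial V}{\partial \theta}$$

问题在于 $\frac{\partial S}{\partial V}$ 无法计算。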
2. 代理梯度方法
2.1 核心思想
使用一个光滑可微的函数 $\sigma(V)$ 来近似脉冲函数 $S(V)$ 的梯度:

$$\frac{\partial S}{\partial V} \approx \sigma'(V; \alpha)$$

其中 $\alpha$ 是控制代理梯度形状的参数。
┌─────────────────────────────────────────────────────────────┐
│ 代理梯度示意图 │
├─────────────────────────────────────────────────────────────┤
│ │
│ σ(V) │
│ │ arctan │
│ │ ╱ │
│ │ ╱ ─── sigmoid │
│ │ ╱ ╱ │
│ │ ╱ │
│ │╱ │
│ ─┼─────────────── V │
│ V_th │
│ │
│ 关键洞察:在 V ≈ V_th 附近提供梯度,其余地方为0 │
│ │
└─────────────────────────────────────────────────────────────┘
2.2 常见代理函数
Arctan代理梯度
最常用的代理梯度函数之一:
def surrogate_atan(V, V_th, alpha):
    """Arctan surrogate gradient.

    Evaluates ``(alpha / pi) / (1 + alpha**2 * x**2)`` with ``x = V - V_th``,
    which is the derivative of ``arctan(alpha * x) / pi`` with respect to ``V``.

    Args:
        V: Membrane potential (tensor, or anything supporting the arithmetic).
        V_th: Firing threshold, broadcastable against ``V``.
        alpha: Shape parameter; the peak value at ``V == V_th`` is
            ``alpha / pi`` and larger values make the peak narrower.

    Returns:
        The surrogate gradient evaluated element-wise at ``V``.
    """
    x = V - V_th
    return (alpha / torch.pi) / (alpha**2 * x**2 + 1)

# 可视化:
alpha=1: ╭──
╱
╱
────────────●────────→ V
V_th
alpha=2: ╭───
╱
╱
────────────●────────→ V
V_th
alpha=0.5: ╭─
╱
╱
───────────●──────────→ V
V_th
Sigmoid代理梯度
def surrogate_sigmoid(V, V_th, beta=1.0):
    """Sigmoid surrogate gradient.

    Computes ``s * (1 - s)`` with ``s = sigmoid((V - V_th) / beta)``.

    Args:
        V: Membrane potential tensor.
        V_th: Firing threshold, broadcastable against ``V``.
        beta: Temperature dividing the distance to the threshold; larger
            values widen the bump. Defaults to 1.0.

    Returns:
        Tensor with the surrogate gradient; the peak value is 0.25 at
        ``V == V_th`` for every ``beta``.
    """
    x = V - V_th
    sig = torch.sigmoid(x / beta)
    # No extra 1/beta factor is applied, so the peak stays at 0.25.
    return sig * (1 - sig)

# Fast Sigmoid代理梯度
def surrogate_fast_sigmoid(V, V_th, alpha=1.0):
    """Fast-sigmoid surrogate gradient.

    Computes ``alpha / (alpha + |V - V_th|)``, which avoids any exponential.

    Args:
        V: Membrane potential tensor.
        V_th: Firing threshold, broadcastable against ``V``.
        alpha: Width parameter; the value is 1 at ``V == V_th`` and falls to
            0.5 at a distance of ``alpha`` from the threshold. Defaults to 1.0.

    Returns:
        Tensor with the surrogate gradient evaluated element-wise.
    """
    x = torch.abs(V - V_th)
    return alpha / (alpha + x)

# Piecewise Linear代理梯度
def surrogate_pwl(V, V_th, alpha=2.0):
    """Piecewise-linear surrogate gradient.

    Returns the constant ``alpha`` wherever ``|V - V_th| <= 1 / alpha`` and 0
    elsewhere, i.e. a rectangular window centred on the threshold.

    Args:
        V: Membrane potential tensor.
        V_th: Firing threshold, broadcastable against ``V``.
        alpha: Height of the window; its half-width is ``1 / alpha``.
            Defaults to 2.0.

    Returns:
        float32 tensor holding ``alpha`` inside the window and 0 outside.
    """
    x = torch.abs(V - V_th)
    # The boolean mask is cast with .float(), so the result is float32.
    return alpha * (x <= 1/alpha).float()

# 2.3 代理函数对比
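| 代理函数 | 公式 | 计算复杂度 | 梯度强度 | 梯度宽度 |
|---|---|---|---|---|
| Arctan | $\frac{\alpha/\pi}{1 + \alpha^2 x^2}$ | 中 | 中 | 可调 |
| Sigmoid | $\sigma(x/\beta)\,\bigl(1 - \sigma(x/\beta)\bigr)$ | 低 | 峰值为0.25 | 较窄 |
| Fast Sigmoid | $\frac{\alpha}{\alpha + \lvert x \rvert}$ | 最低 | 峰值为1 | 较宽 |
| Piecewise Linear | 阶梯+线性 | 最低 | 恒定 | 固定 |

其中 $x = V - V_{th}$。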
3. 代理梯度反向传播
3.1 前向传播与反向传播
import torch
import torch.nn as nn
import torch.nn.functional as F
class SurrogateGradientFunction(torch.autograd.Function):
    """Hard-threshold spike in the forward pass, surrogate gradient in backward.

    Call as ``SurrogateGradientFunction.apply(V, V_th, alpha, spike_fn)``.
    Only ``V`` receives a gradient; ``V_th``, ``alpha`` and ``spike_fn`` are
    treated as constants.
    """

    # Accepted values for ``spike_fn``. Checked in forward so that a typo
    # fails at the call site rather than at the first backward pass.
    _SUPPORTED_SPIKE_FNS = ('arctan', 'sigmoid', 'fast_sigmoid')

    @staticmethod
    def forward(ctx, V, V_th, alpha, spike_fn='arctan'):
        """Emit 1 where ``V >= V_th`` and 0 elsewhere.

        Args:
            ctx: Autograd context used to stash values for ``backward``.
            V: Membrane potential tensor.
            V_th: Firing threshold, broadcastable against ``V``.
            alpha: Shape parameter of the surrogate gradient.
            spike_fn: One of ``'arctan'``, ``'sigmoid'``, ``'fast_sigmoid'``.

        Returns:
            Spike tensor with the same floating dtype as ``V`` (float32 when
            ``V`` is not a floating-point tensor).

        Raises:
            ValueError: If ``spike_fn`` is not a supported name.
        """
        if spike_fn not in SurrogateGradientFunction._SUPPORTED_SPIKE_FNS:
            raise ValueError(f"Unknown spike function: {spike_fn}")
        # Stash everything backward needs.
        ctx.save_for_backward(V)
        ctx.V_th = V_th
        ctx.alpha = alpha
        ctx.spike_fn = spike_fn
        # Keep the caller's floating dtype (half/double) instead of always
        # producing float32.
        out_dtype = V.dtype if V.is_floating_point() else torch.float32
        return (V >= V_th).to(out_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Replace the spike's derivative with the selected surrogate.

        Returns:
            A 4-tuple matching the inputs of ``forward``: the gradient with
            respect to ``V`` followed by ``None`` for the three constants.

        Raises:
            ValueError: If the stored ``spike_fn`` is not a supported name.
        """
        V, = ctx.saved_tensors
        alpha = ctx.alpha
        x = V - ctx.V_th
        if ctx.spike_fn == 'arctan':
            grad = (alpha / torch.pi) / (alpha**2 * x**2 + 1)
        elif ctx.spike_fn == 'sigmoid':
            sig = torch.sigmoid(x / alpha)
            grad = sig * (1 - sig) / alpha
        elif ctx.spike_fn == 'fast_sigmoid':
            grad = alpha / (alpha + torch.abs(x))**2
        else:
            raise ValueError(f"Unknown spike function: {ctx.spike_fn}")
        # Clip the surrogate to [0, 1] before applying the chain rule.
        grad = torch.clamp(grad, min=0, max=1)
        return grad * grad_output, None, None, None
class SurrogateGradientLinear(nn.Module):
    """Linear layer followed by a spiking non-linearity with surrogate gradient.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the underlying ``nn.Linear`` has a bias term.
        surrogate: Surrogate name passed to ``SurrogateGradientFunction``.
        alpha: Shape parameter of the surrogate gradient.
    """

    def __init__(self, in_features, out_features, bias=True,
                 surrogate='arctan', alpha=2.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        self.V_th = 1.0  # fixed firing threshold
        self.surrogate = surrogate
        self.alpha = alpha

    def forward(self, x):
        """Return the spike tensor produced by thresholding ``linear(x)``."""
        # Linear transform.
        u = self.linear(x)
        # A leaky membrane would use V[t] = beta * V[t-1] + u[t]; this layer is
        # stateless and simplifies that to V = u.
        V = u
        # Spike generation; gradients flow through the surrogate.
        spike = SurrogateGradientFunction.apply(
            V, self.V_th, self.alpha, self.surrogate
        )
        return spike
class SNNLayer(nn.Module):
    """Single-step spiking layer with synaptic-current and membrane state.

    The caller owns the state tensors and unrolls the layer over time by
    feeding the returned ``V_mem`` / ``I_syn`` back in at the next step.

    Args:
        in_features: Size of each input sample.
        out_features: Number of neurons in the layer.
        tau_mem: Membrane time constant.
        tau_syn: Synaptic time constant.
        V_th: Firing threshold.
        V_reset: Potential a neuron is set to after it spikes.
        surrogate: Surrogate name passed to ``SurrogateGradientFunction``.
        alpha: Shape parameter of the surrogate gradient.
        dt: Simulation time step.
    """

    def __init__(self, in_features, out_features,
                 tau_mem=10.0, tau_syn=5.0,
                 V_th=1.0, V_reset=0.0,
                 surrogate='arctan', alpha=2.0,
                 dt=1.0):
        super().__init__()
        self.tau_mem = tau_mem
        self.tau_syn = tau_syn
        self.V_th = V_th
        self.V_reset = V_reset
        self.surrogate = surrogate
        self.alpha = alpha
        self.dt = dt
        # Decay factors. Registered as non-persistent buffers so they follow
        # .to(device) / .double() with the module while state_dict keeps
        # exactly the same keys as before (weight, bias).
        # beta_mem is not read by forward(), which integrates the membrane
        # with an explicit Euler step instead.
        self.register_buffer(
            'beta_mem', torch.exp(torch.tensor(-dt / tau_mem)), persistent=False
        )
        self.register_buffer(
            'beta_syn', torch.exp(torch.tensor(-dt / tau_syn)), persistent=False
        )
        # Weights and bias.
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def surrogate_gradient(self, V):
        """Evaluate the configured surrogate gradient at ``V``.

        NOTE: forward() does not call this helper; the gradient used during
        training comes from ``SurrogateGradientFunction.backward``, whose
        'sigmoid' branch carries an extra ``1 / alpha`` factor and whose
        'fast_sigmoid' branch squares the denominator.

        Returns:
            Tensor shaped like ``V``; all zeros for an unrecognised surrogate.
        """
        x = V - self.V_th
        if self.surrogate == 'arctan':
            return (self.alpha / torch.pi) / (self.alpha**2 * x**2 + 1)
        elif self.surrogate == 'sigmoid':
            sig = torch.sigmoid(x / self.alpha)
            return sig * (1 - sig)
        elif self.surrogate == 'fast_sigmoid':
            # The small epsilon guards against division by zero.
            return self.alpha / (self.alpha + torch.abs(x) + 1e-8)
        return torch.zeros_like(V)

    def forward(self, x_t, V_mem, I_syn):
        """Advance the layer by one time step.

        Args:
            x_t: Input at the current step, shape (batch, in_features).
            V_mem: Membrane potential state, shape (batch, out_features).
            I_syn: Synaptic current state, shape (batch, out_features).

        Returns:
            Tuple ``(V_mem_new, spike, I_syn_new)``: the membrane potential
            after reset, the emitted spikes, and the updated synaptic current.
        """
        # Synaptic current: leaky integration of the weighted input. The bias
        # is added in full at every step (it is not scaled by 1 - beta_syn).
        I_syn_new = self.beta_syn * I_syn + (1 - self.beta_syn) * x_t @ self.weight.t() + self.bias
        # Membrane potential: explicit Euler step towards V_reset driven by
        # the synaptic current.
        dV = (-(V_mem - self.V_reset) + I_syn_new) / self.tau_mem
        V_mem_new = V_mem + dV * self.dt
        # Spike generation; gradients flow through the surrogate.
        spike = SurrogateGradientFunction.apply(
            V_mem_new, self.V_th, self.alpha, self.surrogate
        )
        # Hard reset of neurons that fired. A tensor branch is used because
        # not every torch version accepts a Python scalar in torch.where.
        reset_value = torch.full_like(V_mem_new, self.V_reset)
        V_mem_new = torch.where(V_mem_new >= self.V_th, reset_value, V_mem_new)
        return V_mem_new, spike, I_syn_new

# 3.2 时间反向传播(BPTT)
SNN通常在时间维度展开,然后使用BPTT训练:
class SNNBPTT(nn.Module):
    """Two-layer SNN unrolled over time so it can be trained with BPTT.

    Args:
        input_size: Number of input features.
        hidden_size: Number of neurons in the first spiking layer.
        output_size: Number of neurons in the second spiking layer, and the
            number of logits returned.
        num_steps: Number of simulation steps per forward pass.
        tau_mem: Membrane time constant shared by both spiking layers.
    """

    def __init__(self, input_size, hidden_size, output_size,
                 num_steps=10, tau_mem=10.0):
        super().__init__()
        self.num_steps = num_steps
        self.hidden_size = hidden_size
        # forward() sizes the second layer's state from this attribute; it was
        # previously never assigned, so forward() raised AttributeError.
        self.output_size = output_size
        self.fc1 = SNNLayer(input_size, hidden_size, tau_mem=tau_mem)
        self.fc2 = SNNLayer(hidden_size, output_size, tau_mem=tau_mem)
        # Non-spiking linear read-out applied to the firing rates.
        self.fc_out = nn.Linear(output_size, output_size)

    def forward(self, x):
        """Run ``num_steps`` simulation steps and return the logits.

        Args:
            x: Static input of shape (batch, input_size); the same tensor is
                presented at every time step.

        Returns:
            Tensor of shape (batch, output_size).
        """
        batch_size = x.shape[0]
        # Initial state: zero membrane potential and synaptic current.
        V1 = torch.zeros(batch_size, self.hidden_size, device=x.device)
        I1 = torch.zeros(batch_size, self.hidden_size, device=x.device)
        V2 = torch.zeros(batch_size, self.output_size, device=x.device)
        I2 = torch.zeros(batch_size, self.output_size, device=x.device)
        # Output spikes of the second layer, one entry per time step.
        spike_outputs = []
        for t in range(self.num_steps):
            # Broadcast the input to every step (use x[:, t] for inputs that
            # carry their own time axis).
            x_t = x
            # First layer.
            V1, spike1, I1 = self.fc1(x_t, V1, I1)
            # Second layer.
            V2, spike2, I2 = self.fc2(spike1, V2, I2)
            spike_outputs.append(spike2)
        # Temporal pooling: mean firing rate of the last spiking layer.
        spike_rates = torch.stack(spike_outputs, dim=0).mean(dim=0)
        # Classification read-out.
        output = self.fc_out(spike_rates)
        return output
class SNNClassifier(nn.Module):
    """SNN classifier bundling an ``SNNBPTT`` network with cross-entropy loss.

    Args:
        input_size: Number of input features.
        hidden_size: Number of hidden spiking neurons.
        output_size: Number of classes.
        num_steps: Number of simulation steps per forward pass.
    """

    def __init__(self, input_size, hidden_size, output_size,
                 num_steps=10):
        super().__init__()
        self.snn = SNNBPTT(input_size, hidden_size, output_size, num_steps)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, labels=None):
        """Return ``logits``, or ``(loss, logits)`` when ``labels`` is given."""
        logits = self.snn(x)
        if labels is not None:
            loss = self.criterion(logits, labels)
            return loss, logits
        return logits

    def training_step(self, batch):
        """Compute the loss for one ``(inputs, labels)`` batch."""
        x, labels = batch
        loss, logits = self(x, labels)
        return loss

# 4. 代理梯度理论分析
4.1 收敛性保证
定理(代理梯度收敛性):在适当条件下,使用代理梯度的SNN训练收敛到临界点。
条件:
- 损失函数 $\mathcal{L}$ 是光滑的
- 代理梯度函数 $\sigma'(x)$ 满足有界性:$0 \le \sigma'(x) \le C < \infty$
- 学习率满足 $\eta < \frac{1}{L}$,其中 $L$ 是Lipschitz常数
4.2 梯度尺度问题
代理梯度的大小影响训练稳定性:
| 代理函数 | 梯度尺度 | 建议学习率调整 |
|---|---|---|
| Arctan | $\alpha/\pi$ at peak | 正常 |
| Sigmoid | $0.25$ at peak | 减小 |
| Fast Sigmoid | 变化 | 减小 |
4.3 梯度流分析
def analyze_gradient_flow(model):
    """Collect the mean absolute gradient of every parameter that has one.

    Each parameter is filed under a key made of the first two dot-separated
    components of its name joined by ``_`` (``'fc1.weight'`` becomes
    ``'fc1_weight'``); a name without a dot is used as-is.

    Args:
        model: Module whose ``named_parameters()`` are inspected. Call this
            after ``backward()``; parameters without a gradient are skipped.

    Returns:
        Dict mapping each key to a list of mean absolute gradient values. The
        keys ``'layer1_weight'``, ``'layer2_weight'`` and ``'proxy_grad_mean'``
        are always present, possibly with empty lists.
    """
    grad_stats = {
        'layer1_weight': [],
        'layer2_weight': [],
        'proxy_grad_mean': []
    }
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        key = '_'.join(name.split('.')[:2])
        # setdefault: parameter names rarely match the pre-seeded keys, and
        # indexing the dict directly raised KeyError for them.
        grad_stats.setdefault(key, []).append(param.grad.abs().mean().item())
    return grad_stats

# 5. 训练技巧与最佳实践
5.1 初始化
def initialize_snn_weights(layer, gain=1.5):
    """Initialise a layer's weights for spiking activity.

    Applies Kaiming-normal initialisation (``fan_out``, ReLU gain) to
    ``layer.weight`` and then multiplies the weights by ``gain``. The bias, if
    any, is left untouched.

    Args:
        layer: Module exposing a ``weight`` tensor.
        gain: Extra scale applied after initialisation so that more neurons
            reach threshold early in training. Defaults to 1.5.
    """
    nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
    # Scale in place without recording the operation in autograd.
    with torch.no_grad():
        layer.weight.mul_(gain)

# 5.2 阈值调整
class AdaptiveThreshold(nn.Module):
    """Learnable firing threshold clipped to a fixed range.

    Args:
        initial_th: Starting value of the threshold.
        min_th: Lower bound applied when the threshold is read.
        max_th: Upper bound applied when the threshold is read.
    """

    def __init__(self, initial_th=1.0, min_th=0.5, max_th=2.0):
        super().__init__()
        initial = torch.as_tensor(initial_th)
        # Parameters must be floating point to require gradients; an integer
        # initial value such as 1 would otherwise raise on construction.
        if not initial.is_floating_point():
            initial = initial.to(torch.get_default_dtype())
        self.threshold = nn.Parameter(initial.clone().detach())
        self.min_th = min_th
        self.max_th = max_th

    def forward(self):
        """Return the threshold clamped to ``[min_th, max_th]``."""
        # torch.clamp passes no gradient while the raw value lies outside the
        # range.
        return torch.clamp(self.threshold, self.min_th, self.max_th)

# 5.3 替代方法:Straight-Through Estimator
最简单的代理梯度方法:
class STEFunction(torch.autograd.Function):
    """Straight-Through Estimator: hard threshold forward, identity backward."""

    @staticmethod
    def forward(ctx, V, V_th):
        """Emit 1 where ``V >= V_th`` and 0 elsewhere.

        The result keeps the floating dtype of ``V`` (float32 when ``V`` is
        not a floating-point tensor) instead of always being float32.
        """
        out_dtype = V.dtype if V.is_floating_point() else torch.float32
        return (V >= V_th).to(out_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the incoming gradient to ``V`` unchanged; none for ``V_th``."""
        return grad_output, None  # gradient passed straight through

# 6. 完整训练示例
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
def train_snn_mnist(num_epochs=10, batch_size=128, lr=1e-3, device=None):
    """Train a two-layer SNN on MNIST and return the trained model.

    Downloads MNIST into ``data/`` if needed, trains with Adam and
    cross-entropy on the output firing rates, and prints the loss and training
    accuracy after every epoch.

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Mini-batch size.
        lr: Adam learning rate.
        device: Device to train on. Defaults to ``'cuda'`` when available and
            ``'cpu'`` otherwise, so the example also runs without a GPU.

    Returns:
        The trained ``SimpleSNN`` module.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Model definition.
    class SimpleSNN(nn.Module):
        """784-256-10 fully connected SNN read out by firing rate."""

        def __init__(self):
            super().__init__()
            # NOTE(review): LIFNeuron is not defined in this document; from
            # its use below it is called as (V, u) -> (V_new, spike). Confirm
            # against its definition.
            self.fc1 = nn.Linear(784, 256)
            self.lif1 = LIFNeuron(tau_mem=10.0, V_th=1.0)
            self.fc2 = nn.Linear(256, 10)
            self.lif2 = LIFNeuron(tau_mem=10.0, V_th=1.0)

        def forward(self, x, num_steps=10):
            batch_size = x.shape[0]
            # Initial state.
            V1 = torch.zeros(batch_size, 256, device=x.device)
            V2 = torch.zeros(batch_size, 10, device=x.device)
            spike_counts = torch.zeros(batch_size, 10, device=x.device)
            for _ in range(num_steps):
                # First layer.
                u1 = self.fc1(x)
                V1, spike1 = self.lif1(V1, u1)
                # Second layer.
                u2 = self.fc2(spike1)
                V2, spike2 = self.lif2(V2, u2)
                spike_counts += spike2
            # Classify by firing rate.
            rates = spike_counts / num_steps
            return rates

    # Training setup.
    model = SimpleSNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True
    )
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)  # flatten each image to 784
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
        print(f'Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, '
              f'Acc={100*correct/total:.2f}%')
    return model
# Run the MNIST example when this file is executed as a script.
if __name__ == '__main__':
    train_snn_mnist()

# 7. 总结
代理梯度学习是训练SNN的核心技术:
| 要点 | 说明 |
|---|---|
| 核心问题 | 脉冲函数不可微,无法直接反向传播 |
| 解决方案 | 用光滑函数近似梯度 |
| 常见方法 | Arctan、Sigmoid、Fast Sigmoid |
| 理论保证 | 在适当条件下收敛 |
| 实践技巧 | 初始化、阈值调整、学习率设置 |
参考
Footnotes

1. Neftci, E. O., Mostafa, H., & Zenke, F. (2019). Surrogate Gradient Learning in Spiking Neural Networks. IEEE Signal Processing Magazine, 36(6), 61-67.