脉冲神经网络架构
概述
SNN架构设计需要解决两个核心问题:
- 如何利用脉冲的稀疏性实现高效计算
- 如何在深度网络中保持稳定的脉冲发放
本章介绍从卷积SNN到SNN-Transformer的主要架构类型。
1. 卷积SNN (ConvSNN)
1.1 架构原理
卷积SNN将传统CNN的卷积操作替换为脉冲版本:
ANN卷积: SNN卷积:
输入: 连续值 输入: 脉冲序列
↓ ↓
权重: 连续值 权重: 连续值(与ANN相同,作用于脉冲输入)
↓ ↓
乘累加: 密集计算 事件驱动: 稀疏计算
↓ ↓
激活: ReLU/sigmoid 膜电位更新
↓
脉冲发放
1.2 脉冲卷积实现
class SpikingConv2d(nn.Module):
    """Convolution followed by leaky integrate-and-fire (LIF) dynamics.

    Each call advances the layer by one time step: the input is convolved
    into a synaptic current, the current is added to the decayed membrane
    potential, and every unit whose potential reaches ``V_th`` emits a spike
    and is reset to ``V_reset`` (0.0).

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias:
            Forwarded to ``nn.Conv2d``.
        tau_mem: Membrane time constant; the per-step decay factor is
            ``exp(-1 / tau_mem)``.
        V_th: Firing threshold.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, bias=False,
                 tau_mem=10.0, V_th=1.0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, bias=bias)
        self.tau_mem = tau_mem
        self.V_th = V_th
        self.V_reset = 0.0
        # Per-step membrane decay factor.
        self.beta = torch.exp(torch.tensor(-1.0 / tau_mem))

    def forward(self, x_spike, V_mem=None):
        """Advance the layer by one time step.

        Args:
            x_spike: Input for the current step, shaped (batch, C, H, W).
            V_mem: Membrane potential returned by the previous step, or
                ``None`` to start from an all-zero state.

        Returns:
            ``(spike, V_mem)``: the 0/1 spike tensor and the post-reset
            membrane potential, both shaped like the convolution output.
        """
        # Synaptic current. Computed before the state is created so that a
        # fresh state can take its shape, dtype and device from the actual
        # convolution output instead of a separately derived output size.
        I_syn = self.conv(x_spike)
        if V_mem is None:
            V_mem = torch.zeros_like(I_syn)

        # Leaky integration.
        V_mem = self.beta * V_mem + I_syn

        # Hard threshold; the comparison itself carries no gradient.
        spike = (V_mem >= self.V_th).float()

        # Units that fired are set to V_reset, the rest keep their potential.
        V_mem = V_mem * (1 - spike) + self.V_reset * spike
        return spike, V_mem


# 1.3 ConvSNN network structure
class ConvSNN(nn.Module):
    """Three-stage convolutional SNN for image classification.

    Each stage is a ``SpikingConv2d`` followed by 2x2 max pooling. The
    spikes of the last stage are averaged over time into firing rates,
    which a linear layer maps to class scores.

    Args:
        num_classes: Number of output classes.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extraction.
        self.conv1 = SpikingConv2d(1, 64, 3, padding=1, tau_mem=10.0)
        self.conv2 = SpikingConv2d(64, 128, 3, padding=1, tau_mem=10.0)
        self.conv3 = SpikingConv2d(128, 256, 3, padding=1, tau_mem=10.0)
        # Pooling, shared by all three stages.
        self.pool = nn.MaxPool2d(2, 2)
        # Classifier; 256 * 4 * 4 matches a 32x32 input after three poolings.
        self.fc = nn.Linear(256 * 4 * 4, num_classes)
        # LIF neuron intended for the last layer; it is constructed here but
        # forward() does not call it.
        self.fc_lif = LIFNeuron(tau_mem=10.0)

    def forward(self, x, num_steps=4):
        """Simulate ``num_steps`` time steps and return class scores.

        Args:
            x: Normalized images shaped (batch, 1, 32, 32).
            num_steps: Number of simulation steps.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        batch_size = x.shape[0]
        # Membrane potentials; None lets each layer start from zeros.
        V1 = V2 = V3 = None
        spike_counts = torch.zeros(batch_size, self.fc.in_features,
                                   device=x.device)

        for _ in range(num_steps):
            # Present the static image at every step so each step drives the
            # network. Feeding it only at t == 0 left all later steps with a
            # zero input and a purely decaying first layer.
            # A Poisson rate code, (torch.rand_like(x) < x).float(), is a
            # drop-in alternative.
            x_t = x

            s1, V1 = self.conv1(x_t, V1)
            s1 = self.pool(s1)

            s2, V2 = self.conv2(s1, V2)
            s2 = self.pool(s2)

            s3, V3 = self.conv3(s2, V3)
            s3 = self.pool(s3)

            # Accumulate the output spikes of this step.
            spike_counts += s3.flatten(1)

        # Firing rate = spike count averaged over the simulated steps.
        rates = spike_counts / num_steps
        return self.fc(rates)


# 2. Recurrent SNN (RSNN)
2.1 循环连接的优势
循环连接让SNN能够:
- 建模时间依赖关系
- 增加网络容量而不增加参数
- 模拟工作记忆
┌──────────────────────────────────────────────────────┐
│ 循环SNN结构 │
├──────────────────────────────────────────────────────┤
│ │
│ x_t ──→ [Conv] ──→ s_t ──→ [Recurrent LIF] ──→ h_t
│ ↑ │
│ │ │
│ └──────────────────┘
│ │
│ 状态更新: h_t = f(h_{t-1}, s_t) │
│ │
└──────────────────────────────────────────────────────┘
2.2 循环SNN实现
class RecurrentLIFNeuron(nn.Module):
    """Layer of LIF neurons with input and recurrent synapses.

    The layer keeps its own state (membrane potential ``V``, synaptic
    current ``I_syn`` and the previous step's ``spike``) between calls.

    Args:
        input_size: Size of the input vector at each step.
        hidden_size: Number of neurons in the layer.
        tau_mem: Membrane time constant.
        tau_syn: Synaptic time constant.
        V_th: Firing threshold.
    """

    def __init__(self, input_size, hidden_size, tau_mem=10.0,
                 tau_syn=5.0, V_th=1.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.tau_mem = tau_mem
        self.tau_syn = tau_syn
        self.V_th = V_th
        # Resting/reset potential. forward() reads it, so it has to exist.
        self.V_reset = 0.0
        # Weight matrices.
        self.W_in = nn.Linear(input_size, hidden_size, bias=False)
        self.W_rec = nn.Linear(hidden_size, hidden_size, bias=False)
        # Decay factors. beta_mem is not used by forward(), which integrates
        # the membrane with an explicit step scaled by 1 / tau_mem.
        self.beta_mem = torch.exp(torch.tensor(-1.0 / tau_mem))
        self.beta_syn = torch.exp(torch.tensor(-1.0 / tau_syn))
        # State; allocated by init_state().
        self.V = None
        self.I_syn = None
        self.spike = None

    def init_state(self, batch_size, device):
        """Reset membrane potential, synaptic current and spikes to zero."""
        shape = (batch_size, self.hidden_size)
        self.V = torch.zeros(shape, device=device)
        self.I_syn = torch.zeros(shape, device=device)
        # Cleared as well so spikes from an earlier sequence cannot feed the
        # recurrent synapses of the next one.
        self.spike = torch.zeros(shape, device=device)

    def forward(self, x_t):
        """Advance the layer by one step.

        Args:
            x_t: Input of shape (batch, input_size).

        Returns:
            0/1 spike tensor of shape (batch, hidden_size).
        """
        if self.V is None:
            # Allow use without an explicit init_state() call.
            self.init_state(x_t.shape[0], x_t.device)

        # Synaptic current: decayed previous current plus the new input and
        # the recurrent drive from the previous step's spikes.
        I_in = self.W_in(x_t)
        I_rec = self.W_rec(self.spike)
        self.I_syn = self.beta_syn * self.I_syn + I_in + I_rec

        # Membrane update: leak towards V_reset, driven by I_syn.
        dV = (-(self.V - self.V_reset) + self.I_syn) / self.tau_mem
        self.V = self.V + dV

        # Spike generation.
        self.spike = (self.V >= self.V_th).float()

        # Neurons that fired go back to V_reset.
        self.V = self.V * (1 - self.spike) + self.V_reset * self.spike
        return self.spike


# 2.3 Liquid State Machine (LSM)
LSM是一种特殊的循环SNN:
class LiquidStateMachine(nn.Module):
    """Liquid state machine: fixed random reservoir plus a trained readout.

    Args:
        input_size: Size of the input vector at each step.
        reservoir_size: Number of reservoir neurons.
        readout_size: Number of readout outputs.
        tau_mem: Membrane time constant passed to the reservoir neurons.
        V_th: Firing threshold passed to the reservoir neurons.
    """

    def __init__(self, input_size, reservoir_size=500,
                 readout_size=10, tau_mem=10.0, V_th=1.0):
        super().__init__()
        # Random input projection into the reservoir. Frozen, because only
        # the readout is meant to be learned.
        self.W_reservoir = nn.Parameter(
            torch.randn(input_size, reservoir_size) * 0.1,
            requires_grad=False,
        )
        # Sparse random recurrent weights (about 10% of entries non-zero).
        # Frozen as well, which also keeps the masked entries at zero.
        rec_mask = (torch.rand(reservoir_size, reservoir_size) < 0.1).float()
        self.W_rec = nn.Parameter(
            torch.randn(reservoir_size, reservoir_size) * 0.05 * rec_mask,
            requires_grad=False,
        )
        self.reservoir = ReservoirNeuron(
            reservoir_size, tau_mem=tau_mem, V_th=V_th
        )
        # Readout layer.
        self.readout = nn.Linear(reservoir_size, readout_size)

    def forward(self, x, num_steps=20):
        """Run the reservoir over the input sequence and classify it.

        Args:
            x: Input of shape (batch, input_size, T).
            num_steps: Accepted for interface compatibility; the number of
                simulated steps is taken from ``x.shape[2]``.

        Returns:
            Readout applied to the time-averaged reservoir spikes, shaped
            (batch, readout_size).
        """
        batch_size = x.shape[0]
        self.reservoir.init_state(batch_size, x.device)

        # NOTE(review): assumes ReservoirNeuron takes a (batch,
        # reservoir_size) input current and returns spikes of the same shape,
        # with no recurrence of its own. Confirm against its definition.
        prev_spike = x.new_zeros(batch_size, self.W_rec.shape[0])
        spike_history = []
        for t in range(x.shape[2]):
            # Project the input into reservoir space and add the recurrent
            # drive from the previous step's spikes.
            current = x[:, :, t] @ self.W_reservoir + prev_spike @ self.W_rec
            prev_spike = self.reservoir(current)
            spike_history.append(prev_spike)

        # Reservoir state = mean firing over time.
        reservoir_state = torch.stack(spike_history, dim=-1).mean(dim=-1)
        return self.readout(reservoir_state)


# 3. Deep SNN: Spiking ResNet
3.1 深度SNN的挑战
深度SNN面临的主要问题:
- 梯度消失:深层网络难以传递梯度
- 脉冲同步:层间脉冲同步导致信息丢失
- 膜电位爆炸:深层累积导致不稳定
3.2 Spiking ResNet架构
class SpikingResidualBlock(nn.Module):
    """Two spiking convolutions with an additive shortcut.

    Args:
        channels: Output channels of both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module producing the shortcut from the block
            input; called as ``downsample(x, state)`` and expected to return
            ``(output, state)``.
        tau_mem: Membrane time constant of both convolutions.
        V_th: Firing threshold of both convolutions.
        in_channels: Channels of the block input; defaults to ``channels``.
    """

    def __init__(self, channels, stride=1, downsample=None,
                 tau_mem=10.0, V_th=1.0, in_channels=None):
        super().__init__()
        if in_channels is None:
            in_channels = channels
        self.conv1 = SpikingConv2d(in_channels, channels, 3, stride, 1,
                                   tau_mem=tau_mem, V_th=V_th)
        self.conv2 = SpikingConv2d(channels, channels, 3, 1, 1,
                                   tau_mem=tau_mem, V_th=V_th)
        self.downsample = downsample
        self.V_th = V_th
        # Learnable scale applied to the shortcut branch.
        self.residual_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x, V=None):
        """Advance the block by one time step.

        Args:
            x: Block input for the current step.
            V: State returned by the previous call of this same block, or
                ``None`` for a fresh state. A bare tensor is also accepted
                and used as the state of the first convolution only.

        Returns:
            ``(out, state)`` where ``state`` is the tuple
            ``(V_conv1, V_conv2, V_downsample)`` to pass back on the next
            step.
        """
        # Every spiking layer keeps its own membrane potential; the output
        # shapes differ between layers, so one tensor cannot serve them all.
        if V is None:
            V1 = V2 = V_ds = None
        elif isinstance(V, (tuple, list)):
            V1, V2, V_ds = V
        else:
            V1, V2, V_ds = V, None, None

        # Main path.
        out, V1 = self.conv1(x, V1)
        out, V2 = self.conv2(out, V2)

        # Shortcut path.
        identity = x
        if self.downsample is not None:
            identity, V_ds = self.downsample(x, V_ds)

        out = out + self.residual_scale * identity
        return out, (V1, V2, V_ds)


class SpikingResNet(nn.Module):
    """Spiking ResNet for image classification.

    A strided spiking stem and max pooling are followed by four stages of
    3, 4, 6 and 3 residual blocks, global average pooling and a linear
    classifier.

    Args:
        num_classes: Number of output classes.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Stem.
        self.stem = SpikingConv2d(3, 64, 7, stride=2, padding=3)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)
        # Residual stages.
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, stride=2)
        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks."""
        layers = []
        # The first block may change resolution and channel count, in which
        # case the shortcut needs a matching 1x1 spiking convolution.
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = SpikingConv2d(in_channels, out_channels,
                                       1, stride, bias=False)
        layers.append(SpikingResidualBlock(out_channels, stride, downsample,
                                           in_channels=in_channels))
        # Remaining blocks keep shape.
        for _ in range(1, blocks):
            layers.append(SpikingResidualBlock(out_channels))
        return nn.ModuleList(layers)

    def forward(self, x, num_steps=4):
        """Simulate ``num_steps`` time steps and return class scores.

        Args:
            x: Input images with 3 channels, shaped (batch, 3, H, W); the
                same image is presented at every step.
            num_steps: Number of simulation steps (must be >= 1).

        Returns:
            Tensor of shape (batch, num_classes).

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        if num_steps < 1:
            raise ValueError("num_steps must be >= 1")

        blocks = [block
                  for layer in (self.layer1, self.layer2,
                                self.layer3, self.layer4)
                  for block in layer]
        # One independent state per spiking layer, carried across steps.
        stem_state = None
        block_states = [None] * len(blocks)

        rate_sum = None
        for _ in range(num_steps):
            spike, stem_state = self.stem(x, stem_state)
            spike = self.pool(spike)
            for i, block in enumerate(blocks):
                spike, block_states[i] = block(spike, block_states[i])
            # Spatial average of the final feature map at this step.
            pooled = self.avgpool(spike).flatten(1)
            rate_sum = pooled if rate_sum is None else rate_sum + pooled

        # Average over time, then classify.
        rates = rate_sum / num_steps
        return self.fc(rates)


# 4. SNN-Transformer architecture
4.1 挑战
将Transformer的自注意力机制适配到SNN面临:
- Softmax不适合脉冲计算:其指数与归一化(除法)运算作用于连续值,无法直接以脉冲事件驱动的方式实现
- 计算密集:注意力计算量大
- 时间维度:需要考虑脉冲的时序特性
4.2 Spike-driven Self-Attention
核心思想:用脉冲驱动的方式实现自注意力
class SpikeSelfAttention(nn.Module):
    """Multi-head self-attention whose Q, K and V are spike tensors.

    The input is projected to Q, K and V, each projection is passed through
    its own LIF neuron, and attention is computed per head on the results.

    Args:
        embed_dim: Feature size of the input and output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        V_th: Firing threshold of the Q/K/V LIF neurons.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads=8, V_th=1.0):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Q/K/V projections.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # LIF neurons that turn the projections into spikes.
        self.q_lif = LIFNeuron(tau_mem=10.0, V_th=V_th)
        self.k_lif = LIFNeuron(tau_mem=10.0, V_th=V_th)
        self.v_lif = LIFNeuron(tau_mem=10.0, V_th=V_th)
        # Output projection.
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def spike_attention(self, Q_spike, K_spike, V):
        """Attention over the last two dimensions (sequence, feature).

        Args:
            Q_spike: Query spikes, shaped (..., seq_len, dim).
            K_spike: Key spikes, shaped (..., seq_len, dim).
            V: Values, shaped (..., seq_len, dim).

        Returns:
            Tensor with the same shape as ``V``.
        """
        # Query-key overlap, scaled.
        attn_weights = torch.matmul(Q_spike, K_spike.transpose(-2, -1))
        attn_weights = attn_weights * self.scale

        # Normalize each query row by that query's own spike count. Summing
        # over the feature axis gives (..., seq_len, 1), which broadcasts
        # against the (..., seq_len, seq_len) score matrix.
        Q_count = Q_spike.sum(dim=-1, keepdim=True) + 1e-6
        attn_weights = attn_weights / Q_count

        # Standard softmax over the keys.
        attn_weights = torch.softmax(attn_weights, dim=-1)

        return torch.matmul(attn_weights, V)

    def _split_heads(self, t):
        """Reshape (batch, seq, embed_dim) to (batch, heads, seq, head_dim)."""
        batch, seq_len, _ = t.shape
        return t.reshape(batch, seq_len, self.num_heads,
                         self.head_dim).transpose(1, 2)

    def forward(self, x, num_steps=1):
        """Apply spiking self-attention for one step.

        Args:
            x: Input features of shape (batch, seq_len, embed_dim).
            num_steps: Accepted for interface compatibility; not used here.

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        batch, seq_len, _ = x.shape

        # Q/K/V projections turned into spikes.
        # NOTE(review): assumes each LIFNeuron call returns a spike tensor
        # shaped like its input. Confirm against LIFNeuron.
        Q_spike = self.q_lif(self.q_proj(x))
        K_spike = self.k_lif(self.k_proj(x))
        V_spike = self.v_lif(self.v_proj(x))

        # Per-head attention, so that `scale` (derived from head_dim)
        # matches the width of the dot products it scales.
        attn_out = self.spike_attention(self._split_heads(Q_spike),
                                        self._split_heads(K_spike),
                                        self._split_heads(V_spike))

        # Merge heads back to (batch, seq_len, embed_dim).
        attn_out = attn_out.transpose(1, 2).reshape(batch, seq_len,
                                                    self.embed_dim)
        return self.out_proj(attn_out)


# 4.3 Spikeformer architecture
class SpikingTransformerBlock(nn.Module):
    """Pre-norm transformer block built from spiking attention and MLP.

    Args:
        embed_dim: Feature size of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        V_th: Firing threshold passed to the attention and MLP modules.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, V_th=1.0):
        super().__init__()
        # Spiking self-attention.
        self.norm1 = SpikingLayerNorm(embed_dim)
        self.attn = SpikeSelfAttention(embed_dim, num_heads, V_th)
        # Spiking feed-forward network. mlp_ratio defaults to a float, so the
        # product is converted to an int before it is used as a layer width.
        self.norm2 = SpikingLayerNorm(embed_dim)
        self.ffn = SpikingMLP(embed_dim, int(embed_dim * mlp_ratio), V_th)

    def forward(self, x, num_steps=1):
        """Apply attention and MLP, each with a residual connection.

        Args:
            x: Token tensor; passed through unchanged in shape.
            num_steps: Forwarded to the attention and MLP modules.
        """
        # Self-attention + residual.
        x = x + self.attn(self.norm1(x), num_steps)
        # Feed-forward + residual.
        x = x + self.ffn(self.norm2(x), num_steps)
        return x
class Spikeformer(nn.Module):
    """Spiking vision transformer.

    Images are cut into patches by a strided spiking convolution, a class
    token is prepended, the tokens pass through ``depth`` spiking
    transformer blocks, and the class token is classified.

    Args:
        img_size: Nominal input resolution; not used by the layers built
            here.
        patch_size: Side length (and stride) of each patch.
        in_channels: Number of image channels.
        num_classes: Number of output classes.
        embed_dim: Token feature size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
    """

    def __init__(self, img_size=224, patch_size=16,
                 in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        # Patch embedding.
        self.patch_embed = SpikingConv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        # Class token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Transformer blocks.
        self.blocks = nn.ModuleList([
            SpikingTransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])
        # Classification head.
        self.norm = SpikingLayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x, num_steps=4):
        """Classify a batch of images.

        Args:
            x: Images shaped (batch, in_channels, H, W).
            num_steps: Forwarded to every transformer block.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        # Patch embedding. SpikingConv2d returns (spike, membrane state);
        # only the spikes are tokens, and the state is dropped because the
        # embedding is evaluated once per call.
        x, _ = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)  # (B, N, C)

        # Prepend the class token.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, x], dim=1)

        # Transformer blocks.
        for block in self.blocks:
            x = block(x, num_steps)

        # Classify from the class token.
        x = self.norm(x)
        cls_output = x[:, 0]
        return self.head(cls_output)


# 4.4 Latest architecture: SpikingResformer
CVPR 2024提出的SpikingResformer将ResNet和ViT架构融合到SNN中:
class SpikingResformerBlock(nn.Module):
    """SpikingResformer block: attention-style mixing plus FFN, both residual.

    Args:
        dim: Feature size handed to every sub-module.
        num_heads: Head count handed to the multi-scale perceiver.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        # Sub-modules are created in the original order so that parameter
        # registration and random initialization are unchanged.
        self.mspa = MultiScalePulsePerceiver(dim, num_heads)  # token mixing
        self.ffn = SpikingFFN(dim)                            # spiking FFN
        self.norm1 = SpikingLayerNorm(dim)                    # before mspa
        self.norm2 = SpikingLayerNorm(dim)                    # before ffn

    def forward(self, x):
        """Return ``x`` after the two residual sub-layers."""
        # First residual branch: normalize, mix, add back.
        mixed = x + self.mspa(self.norm1(x))
        # Second residual branch: normalize, feed-forward, add back.
        return mixed + self.ffn(self.norm2(mixed))


# 5. Architecture comparison
| 架构 | 适用任务 | 深度 | 计算效率 | 性能 |
|---|---|---|---|---|
| ConvSNN | 图像分类、检测 | 中 | 高 | 中-高 |
| RSNN/LSM | 时序处理、语音 | 浅 | 中 | 中 |
| Spiking ResNet | 图像分类 | 深 | 中 | 高 |
| SNN-Transformer | 复杂视觉任务 | 深 | 中-低 | 高 |
6. 架构设计原则
6.1 避免同步问题
- 脉冲随机化:使用概率脉冲而非硬阈值
- 异步更新:不同层独立更新
- 膜电位噪声:添加随机扰动
6.2 平衡精度与效率
class EfficientSNNDesign:
    """Design principles for efficient SNNs."""
    # 1. Spike sparsification
    #    - use a lower firing threshold
    #      NOTE(review): lowering the threshold normally makes neurons fire
    #      more often, which works against sparsity. Confirm whether a
    #      higher threshold was intended here.
    #    - add regularization that encourages sparse firing
    # 2. Controlling the number of time steps
    #    - task complexity determines how many steps are needed
    #    - avoid overly long simulation times
    # 3. Mixed precision
    #    - use low precision in the early layers
    #    - use high precision in the later layers


# 7. Summary
SNN架构设计的关键要点:
- 卷积SNN:利用局部连接和稀疏激活
- 循环SNN:建模时序依赖和短期记忆
- 深度SNN:残差连接解决梯度问题
- SNN-Transformer:将自注意力机制脉冲化
不同架构适用于不同任务,选择时需要权衡性能、效率和硬件友好性。