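Spiking Neural Network Architectures

Overview

SNN architecture design has to solve two core problems:

  1. How to exploit spike sparsity for efficient computation
  2. How to maintain stable spike firing in deep networks

This chapter introduces the main architecture types, from convolutional SNNs to SNN-Transformers.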


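1. Convolutional SNN (ConvSNN)

1.1 Architecture Principle

A convolutional SNN replaces the convolution of a conventional CNN with a spiking version:

ANN convolution:                       SNN convolution:

Input: continuous values               Input: spike trains
     ↓                                      ↓
Weights: continuous values             Weights: continuous values
     ↓                                      ↓
Multiply-accumulate: dense compute     Event-driven: sparse compute
     ↓                                      ↓
Activation: ReLU/sigmoid               Membrane potential update
                                            ↓
                                       Spike firing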
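The difference between the dense and event-driven columns can be illustrated with a small framework-free sketch. It assumes binary input spikes and real-valued weights; the function names `dense_mac` and `event_driven_sum` are hypothetical helpers, not part of any library.

```python
def dense_mac(weights, inputs):
    """ANN-style weighted sum: one multiply-accumulate per input."""
    total = 0.0
    for w, x in zip(weights, inputs):
        total += w * x
    return total


def event_driven_sum(weights, spikes):
    """SNN-style weighted sum: add a weight only where a spike occurred."""
    total = 0.0
    num_adds = 0
    for w, s in zip(weights, spikes):
        if s:  # silent inputs are skipped entirely
            total += w
            num_adds += 1
    return total, num_adds


weights = [0.5, -0.25, 0.75, 1.0, -0.5, 0.125]  # real-valued synaptic weights
spikes = [1, 0, 0, 1, 0, 0]                     # sparse binary input

dense = dense_mac(weights, spikes)
sparse, adds = event_driven_sum(weights, spikes)
print(dense, sparse, adds)
```

Both functions return the same sum, but the event-driven version performs only as many additions as there are spikes and no multiplications, which is where the efficiency of sparse spiking input comes from.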

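1.2 Spiking Convolution Implementation

import torch
import torch.nn as nn


class SpikingConv2d(nn.Module):
    """
    Spiking convolution layer: a 2D convolution followed by LIF dynamics.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, bias=False,
                 tau_mem=10.0, V_th=1.0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, bias=bias)
        self.tau_mem = tau_mem
        self.V_th = V_th
        self.V_reset = 0.0

        # Decay factor per time step: beta = exp(-1 / tau_mem).
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("beta", torch.exp(torch.tensor(-1.0 / tau_mem)))

    def forward(self, x_spike, V_mem=None):
        """
        x_spike: (batch, C, H, W) - input spikes at the current time step
        V_mem: membrane potential state from the previous step (None at t = 0)
        """
        # Synaptic current: convolution of the input spikes
        I_syn = self.conv(x_spike)

        # Initialize the membrane potential with the conv output shape
        if V_mem is None:
            V_mem = torch.zeros_like(I_syn)

        # Membrane potential update (leaky integration)
        V_mem = self.beta * V_mem + I_syn

        # Spike firing (hard threshold; training requires a surrogate gradient)
        spike = (V_mem >= self.V_th).float()

        # Reset the neurons that fired
        V_mem = V_mem * (1 - spike) + self.V_reset * spike

        return spike, V_mem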
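To make the update rule concrete, here is a minimal single-neuron sketch in plain Python. It uses the same three steps as SpikingConv2d (leaky integration with beta = exp(-1/tau_mem), threshold, hard reset); the helper name `simulate_lif` and the constant input current of 0.4 are assumptions chosen for illustration.

```python
import math


def simulate_lif(inputs, tau_mem=10.0, v_th=1.0, v_reset=0.0):
    """Simulate one LIF neuron; returns the spike train and the post-reset potentials."""
    beta = math.exp(-1.0 / tau_mem)  # per-step decay factor, about 0.905 for tau_mem = 10
    v = 0.0
    spikes, trace = [], []
    for i_syn in inputs:
        v = beta * v + i_syn          # leaky integration
        s = 1 if v >= v_th else 0     # threshold crossing
        if s:
            v = v_reset               # hard reset after a spike
        spikes.append(s)
        trace.append(v)
    return spikes, trace


spikes, trace = simulate_lif([0.4] * 6)
print(spikes)  # [0, 0, 1, 0, 0, 1]
```

With a constant input of 0.4 the potential climbs 0.4, about 0.76, about 1.09, crosses the threshold on the third step, resets, and repeats, so the neuron fires once every three steps.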

1.3 Convolutional SNN Network Structure

class ConvSNN(nn.Module):
    """
    Convolutional SNN for image classification.
    """
    def __init__(self, num_classes=10):
        super().__init__()

        # Feature extraction
        self.conv1 = SpikingConv2d(1, 64, 3, padding=1, tau_mem=10.0)
        self.conv2 = SpikingConv2d(64, 128, 3, padding=1, tau_mem=10.0)
        self.conv3 = SpikingConv2d(128, 256, 3, padding=1, tau_mem=10.0)

        # Pooling
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected layer (a 32x32 input is 4x4 after three 2x2 poolings)
        self.fc = nn.Linear(256 * 4 * 4, num_classes)

        # Learnable membrane potential (for the last layer).
        # LIFNeuron is assumed to be defined elsewhere; it is not used in forward().
        self.fc_lif = LIFNeuron(tau_mem=10.0)

    def forward(self, x, num_steps=4):
        """
        x: (batch, 1, 32, 32) - normalized image
        """
        batch_size = x.shape[0]

        # Initialize the membrane potentials
        V1 = V2 = V3 = None
        spike_counts = torch.zeros(batch_size, 256 * 4 * 4, device=x.device)

        for t in range(num_steps):
            # Broadcast the static input over time: present it at every step
            x_t = x
            # Alternatively, use Poisson encoding
            # x_t = (torch.rand_like(x) < x).float()

            # Convolution layer 1
            s1, V1 = self.conv1(x_t, V1)
            s1 = self.pool(s1)

            # Convolution layer 2
            s2, V2 = self.conv2(s1, V2)
            s2 = self.pool(s2)

            # Convolution layer 3
            s3, V3 = self.conv3(s2, V3)
            s3 = self.pool(s3)

            # Accumulate spikes
            spike_counts += s3.flatten(1)

        # Firing rate
        rates = spike_counts / num_steps

        # Classification
        out = self.fc(rates)

        return out
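The Poisson encoding mentioned in the comment above can be sketched without any framework. The version below is the usual per-step Bernoulli approximation: at every time step a pixel fires with probability equal to its normalized intensity. The helper name `poisson_encode`, the pixel values, and the seed are illustrative assumptions.

```python
import random


def poisson_encode(pixels, num_steps, rng):
    """Return num_steps binary frames; each pixel fires with probability equal to its intensity."""
    return [[1 if rng.random() < p else 0 for p in pixels]
            for _ in range(num_steps)]


rng = random.Random(0)                 # fixed seed so the sketch is reproducible
pixels = [0.0, 0.2, 0.8, 1.0]          # normalized intensities in [0, 1]
frames = poisson_encode(pixels, num_steps=2000, rng=rng)

# The empirical firing rate of each pixel approaches its intensity
rates = [sum(frame[i] for frame in frames) / len(frames)
         for i in range(len(pixels))]
print(rates)
```

With only a few time steps (such as num_steps=4 in ConvSNN) the rate estimate is coarse, which is one reason presenting the analog image directly at every step is a common alternative.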

2. Recurrent SNN (RSNN)

2.1 Advantages of Recurrent Connections

Recurrent connections allow an SNN to:

  • Model temporal dependencies
  • Increase network capacity without adding parameters
  • Emulate working memory

┌──────────────────────────────────────────────────────┐
│               Recurrent SNN structure                │
├──────────────────────────────────────────────────────┤
│                                                      │
│  x_t ──→ [Conv] ──→ s_t ──→ [Recurrent LIF] ──→ h_t  │
│                                    ↑             │   │
│                                    └─────────────┘   │
│                                                      │
│  State update: h_t = f(h_{t-1}, s_t)                 │
│                                                      │
└──────────────────────────────────────────────────────┘

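2.2 Recurrent SNN Implementation

class RecurrentLIFNeuron(nn.Module):
    """
    Recurrent LIF neuron layer.
    """
    def __init__(self, input_size, hidden_size, tau_mem=10.0,
                 tau_syn=5.0, V_th=1.0):
        super().__init__()

        self.hidden_size = hidden_size
        self.tau_mem = tau_mem
        self.tau_syn = tau_syn
        self.V_th = V_th
        self.V_reset = 0.0

        # Weight matrices
        self.W_in = nn.Linear(input_size, hidden_size, bias=False)
        self.W_rec = nn.Linear(hidden_size, hidden_size, bias=False)

        # Decay factors (buffers, so they move with the module across devices).
        # beta_mem is kept for reference; the membrane update below uses an Euler step.
        self.register_buffer("beta_mem", torch.exp(torch.tensor(-1.0 / tau_mem)))
        self.register_buffer("beta_syn", torch.exp(torch.tensor(-1.0 / tau_syn)))

        # State (created by init_state)
        self.V = None
        self.I_syn = None
        self.spike = None

    def init_state(self, batch_size, device):
        """Initialize the state."""
        self.V = torch.zeros(batch_size, self.hidden_size, device=device)
        self.I_syn = torch.zeros(batch_size, self.hidden_size, device=device)
        self.spike = torch.zeros(batch_size, self.hidden_size, device=device)

    def forward(self, x_t):
        """
        Single-step forward pass.
        x_t: (batch, input_size)
        """
        # Create the state lazily if init_state was not called
        if self.V is None:
            self.init_state(x_t.shape[0], x_t.device)

        # Synaptic current update: input plus recurrent spikes from the previous step
        I_in = self.W_in(x_t)
        I_rec = self.W_rec(self.spike)
        self.I_syn = self.beta_syn * self.I_syn + I_in + I_rec

        # Membrane potential update (explicit Euler step, dt = 1)
        dV = (-(self.V - self.V_reset) + self.I_syn) / self.tau_mem
        self.V = self.V + dV

        # Spike firing
        self.spike = (self.V >= self.V_th).float()

        # Reset
        self.V = self.V * (1 - self.spike) + self.V_reset * self.spike

        return self.spike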
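For reference, the single-step update in RecurrentLIFNeuron can be written as discrete-time equations. This is a transcription of the code under the assumption that V_reset = 0 (as in SpikingConv2d) and that the step function Θ returns 1 at zero, matching the `>=` comparison; the symbols I[t], Ṽ[t] and s[t] are notation introduced here.

```latex
\begin{aligned}
I[t] &= \beta_{\mathrm{syn}}\, I[t-1] + W_{\mathrm{in}}\, x[t] + W_{\mathrm{rec}}\, s[t-1],
  \qquad \beta_{\mathrm{syn}} = e^{-1/\tau_{\mathrm{syn}}} \\
\tilde{V}[t] &= V[t-1] + \frac{1}{\tau_{\mathrm{mem}}}
  \Bigl( -\bigl(V[t-1] - V_{\mathrm{reset}}\bigr) + I[t] \Bigr) \\
s[t] &= \Theta\bigl(\tilde{V}[t] - V_{\mathrm{th}}\bigr) \\
V[t] &= \tilde{V}[t]\,\bigl(1 - s[t]\bigr)
\end{aligned}
```

With V_reset = 0 the second line reduces to Ṽ[t] = (1 - 1/τ_mem) V[t-1] + I[t]/τ_mem. For τ_mem = 10 the leak factor is 0.9, close to the exponential factor e^{-0.1} ≈ 0.905 used in SpikingConv2d, but here the input current is additionally scaled by 1/τ_mem.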

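2.3 Liquid State Machine (LSM)

An LSM is a special kind of recurrent SNN in which the reservoir is random and fixed and only the readout is trained:

class LiquidStateMachine(nn.Module):
    """
    Liquid state machine
    - Randomly connected reservoir (fixed weights)
    - Trainable readout layer
    """
    def __init__(self, input_size, reservoir_size=500,
                 readout_size=10, tau_mem=10.0, V_th=1.0):
        super().__init__()
        self.reservoir_size = reservoir_size
        self.V_th = V_th

        # Random input-to-reservoir weights (fixed, so stored as a buffer)
        self.register_buffer(
            "W_reservoir", torch.randn(input_size, reservoir_size) * 0.1
        )

        # Recurrent connections (sparse random, about 10% density, fixed)
        rec_mask = (torch.rand(reservoir_size, reservoir_size) < 0.1).float()
        self.register_buffer(
            "W_rec", torch.randn(reservoir_size, reservoir_size) * 0.05 * rec_mask
        )

        # Membrane decay factor of the reservoir neurons
        self.register_buffer("beta", torch.exp(torch.tensor(-1.0 / tau_mem)))

        # Readout layer (the only trainable part)
        self.readout = nn.Linear(reservoir_size, readout_size)

    def forward(self, x, num_steps=20):
        """
        x: (batch, input_size, T)
        The number of steps is taken from T; num_steps is unused.
        """
        batch_size = x.shape[0]
        V = torch.zeros(batch_size, self.reservoir_size, device=x.device)
        spike = torch.zeros_like(V)

        spike_history = []

        for t in range(x.shape[2]):
            # Input current plus recurrent current from the previous spikes
            I_syn = x[:, :, t] @ self.W_reservoir + spike @ self.W_rec
            # LIF dynamics of the reservoir neurons
            V = self.beta * V + I_syn
            spike = (V >= self.V_th).float()
            V = V * (1 - spike)
            spike_history.append(spike)

        # Read out the reservoir state as the time-averaged firing rate
        reservoir_state = torch.stack(spike_history, dim=-1).mean(dim=-1)

        # Classification
        return self.readout(reservoir_state)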
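The readout step above (stack the spike history, average over time, apply a linear layer) can be traced by hand on a tiny example. The sketch below uses plain lists; the three-neuron spike history and the readout weights are made-up values for illustration, and the helper names are hypothetical.

```python
def mean_firing_rates(spike_history):
    """spike_history[t][i] is the 0/1 spike of neuron i at step t; returns per-neuron rates."""
    num_steps = len(spike_history)
    num_neurons = len(spike_history[0])
    return [sum(step[i] for step in spike_history) / num_steps
            for i in range(num_neurons)]


def linear_readout(rates, weights, bias):
    """Plain linear layer: one output per weight row."""
    return [sum(w * r for w, r in zip(row, rates)) + b
            for row, b in zip(weights, bias)]


# Four time steps of a three-neuron reservoir (illustrative values)
spike_history = [[1, 0, 0],
                 [1, 1, 0],
                 [0, 1, 0],
                 [1, 0, 0]]
rates = mean_firing_rates(spike_history)

weights = [[1.0, 0.0, 0.0],   # hypothetical readout weights
           [0.0, 2.0, 1.0]]
bias = [0.0, -0.5]
logits = linear_readout(rates, weights, bias)
print(rates, logits)
```

Because only this linear map is trained, fitting an LSM readout is a linear problem on the rate features, whatever the reservoir dynamics look like.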

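3. Deep SNNs: Spiking ResNet

3.1 Challenges of Deep SNNs

Deep SNNs face three main problems:

  1. Vanishing gradients: gradients are hard to propagate through deep networks
  2. Spike synchronization: synchronized firing across layers causes information loss
  3. Membrane potential explosion: accumulation across deep layers causes instability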
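The first problem can be illustrated with a toy scalar calculation. The sketch assumes a fast-sigmoid style surrogate derivative for the spike function, 1 / (1 + slope·|v − v_th|)², and multiplies it across layers; it ignores weights and temporal dynamics, so it shows the tendency rather than the behavior of a real network. The function names and the values v = 0.8, slope = 5 are illustrative assumptions.

```python
def surrogate_grad(v, v_th=1.0, slope=5.0):
    """Surrogate derivative of the spike function (fast-sigmoid style), always in (0, 1]."""
    return 1.0 / (1.0 + slope * abs(v - v_th)) ** 2


def chain_gradient(num_layers, v, residual):
    """Product of per-layer derivatives along a chain of spiking layers."""
    grad = 1.0
    for _ in range(num_layers):
        g = surrogate_grad(v)
        # A residual block y = x + f(x) contributes 1 + f'(x) instead of f'(x)
        grad *= (1.0 + g) if residual else g
    return grad


plain = chain_gradient(20, v=0.8, residual=False)
with_skip = chain_gradient(20, v=0.8, residual=True)
print(plain, with_skip)
```

In the plain chain each layer multiplies the gradient by about 0.25, so after 20 layers almost nothing is left, while the identity path keeps every factor at or above 1. The same calculation also hints at the opposite risk, since factors above 1 can compound.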

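3.2 Spiking ResNet Architecture

class SpikingResidualBlock(nn.Module):
    """
    Spiking residual block.
    """
    def __init__(self, in_channels, out_channels, stride=1, downsample=None,
                 tau_mem=10.0, V_th=1.0):
        super().__init__()

        # conv1 changes channels and resolution; conv2 keeps them
        self.conv1 = SpikingConv2d(in_channels, out_channels, 3, stride, 1,
                                   tau_mem=tau_mem, V_th=V_th)
        self.conv2 = SpikingConv2d(out_channels, out_channels, 3, 1, 1,
                                   tau_mem=tau_mem, V_th=V_th)

        self.downsample = downsample
        self.V_th = V_th

        # Scale of the residual connection (learnable)
        self.residual_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x, state=None):
        """
        state: (V1, V2, V_ds), one membrane potential per spiking layer
        """
        V1, V2, V_ds = state if state is not None else (None, None, None)
        identity = x

        # Main path
        out, V1 = self.conv1(x, V1)
        out, V2 = self.conv2(out, V2)

        # Residual connection
        if self.downsample is not None:
            identity, V_ds = self.downsample(x, V_ds)

        # Spiking residual (intended to avoid synchronization);
        # note that the sum is no longer binary
        out = out + self.residual_scale * identity

        return out, (V1, V2, V_ds)


class SpikingResNet(nn.Module):
    """
    Spiking ResNet for image classification.
    """
    def __init__(self, num_classes=1000):
        super().__init__()

        # Stem convolution
        self.stem = SpikingConv2d(3, 64, 7, stride=2, padding=3)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        layers = []

        # Downsampling shortcut when resolution or channel count changes
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = SpikingConv2d(in_channels, out_channels,
                                       1, stride, bias=False)

        layers.append(SpikingResidualBlock(in_channels, out_channels,
                                           stride, downsample))

        # Remaining blocks
        for _ in range(1, blocks):
            layers.append(SpikingResidualBlock(out_channels, out_channels))

        return nn.ModuleList(layers)

    def forward(self, x, num_steps=4):
        blocks = [block
                  for layer in (self.layer1, self.layer2, self.layer3, self.layer4)
                  for block in layer]

        # Every spiking layer keeps its own membrane state across time steps
        V_stem = None
        states = [None] * len(blocks)
        rate_sum = 0.0

        for t in range(num_steps):
            # Stem (the static image is presented at every step)
            spike, V_stem = self.stem(x, V_stem)
            spike = self.pool(spike)

            # Residual stages
            for i, block in enumerate(blocks):
                spike, states[i] = block(spike, states[i])

            # Spatial average of the final feature map
            rate_sum = rate_sum + self.avgpool(spike).flatten(1)

        # Classification on the time-averaged rates
        rates = rate_sum / num_steps
        out = self.fc(rates)

        return out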
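As a sanity check on the layer configuration above, the spatial size of the feature map can be traced with the standard convolution output-size formula, floor((size + 2·padding − kernel) / stride) + 1, which also applies to max pooling. The 224×224 input resolution is an assumption (the usual ImageNet setting for num_classes=1000); the helper name `conv_out_size` is hypothetical.

```python
def conv_out_size(size, kernel, stride, padding):
    """Output size of a convolution or pooling window along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224  # assumed input resolution
sizes = []

size = conv_out_size(size, kernel=7, stride=2, padding=3)   # stem
sizes.append(size)
size = conv_out_size(size, kernel=3, stride=2, padding=1)   # max pool
sizes.append(size)
for stride in (1, 2, 2, 2):                                  # layer1 .. layer4
    size = conv_out_size(size, kernel=3, stride=stride, padding=1)
    sizes.append(size)

print(sizes)  # [112, 56, 56, 28, 14, 7]
```

The final 7×7 map with 512 channels is what the adaptive average pooling reduces to the 512-dimensional vector fed into the classifier. The same trace shows why a membrane state cannot be shared between layers: each stage has a different shape.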

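4. SNN-Transformer Architecture

4.1 Challenges

Adapting the Transformer's self-attention mechanism to SNNs faces three challenges:

  1. Softmax is not spike-friendly: its exponentiation and division cannot be applied directly to binary spikes
  2. Compute-intensive: attention requires a large amount of computation
  3. Time dimension: the temporal nature of spikes has to be taken into account

4.2 Spike-driven Self-Attention

Core idea: implement self-attention in a spike-driven manner.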

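class SpikeSelfAttention(nn.Module):
    """
    Spiking self-attention
    - Query, Key and Value are all spike trains
    - Uses spike sparsity to reduce computation
    LIFNeuron is assumed to be defined elsewhere and to map input
    currents to binary spikes of the same shape.
    """
    def __init__(self, embed_dim, num_heads=8, V_th=1.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # QKV projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # LIF neurons that turn Q, K and V into spikes
        self.q_lif = LIFNeuron(tau_mem=10.0, V_th=V_th)
        self.k_lif = LIFNeuron(tau_mem=10.0, V_th=V_th)
        self.v_lif = LIFNeuron(tau_mem=10.0, V_th=V_th)

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def spike_attention(self, Q_spike, K_spike, V_spike):
        """
        Spike-driven attention.
        All inputs: (batch, num_heads, seq_len, head_dim)
        """
        # QK dot product: for binary spikes this counts coincident spikes.
        # Only positions where Q spikes contribute (dense matmul is used here).
        attn_weights = torch.matmul(Q_spike, K_spike.transpose(-2, -1))
        attn_weights = attn_weights * self.scale

        # Normalize each query row by its spike count
        Q_count = Q_spike.sum(dim=-1, keepdim=True) + 1e-6
        attn_weights = attn_weights / Q_count

        # Softmax over keys (kept for simplicity; it is not itself spike-driven)
        attn_weights = torch.softmax(attn_weights, dim=-1)

        # Output
        out = torch.matmul(attn_weights, V_spike)

        return out

    def forward(self, x, num_steps=1):
        """
        x: (batch, seq_len, embed_dim) - input features
        """
        batch, seq_len, _ = x.shape

        # QKV projections
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Convert Q, K and V into spikes
        Q_spike = self.q_lif(Q)
        K_spike = self.k_lif(K)
        V_spike = self.v_lif(V)

        # Split into heads: (batch, num_heads, seq_len, head_dim)
        def split_heads(t):
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention
        attn_out = self.spike_attention(
            split_heads(Q_spike), split_heads(K_spike), split_heads(V_spike)
        )

        # Merge heads back: (batch, seq_len, embed_dim)
        attn_out = attn_out.transpose(1, 2).reshape(batch, seq_len, self.embed_dim)

        # Output
        out = self.out_proj(attn_out)

        return out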
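Because spike-form queries and keys are binary, their dot product is simply a count of coincident spikes, which is never negative. Some spiking self-attention designs use this to drop softmax and rely on a fixed scale instead. The sketch below illustrates that softmax-free variant in plain Python; it is not the implementation above, and the helper names, the tiny matrices, and the scale of 0.5 are illustrative assumptions.

```python
def spike_scores(Q, K):
    """scores[i][j] = number of channels in which query i and key j both spike."""
    return [[sum(q & k for q, k in zip(q_row, k_row)) for k_row in K]
            for q_row in Q]


def attend(scores, V, scale):
    """Scaled weighted sum of value rows using the integer scores directly (no softmax)."""
    dim = len(V[0])
    return [[scale * sum(s * v_row[d] for s, v_row in zip(row, V))
             for d in range(dim)]
            for row in scores]


Q = [[1, 0, 1, 0],   # two query tokens, four channels, binary spikes
     [0, 1, 0, 0]]
K = [[1, 0, 0, 0],   # two key tokens
     [1, 0, 1, 1]]
V = [[1, 0, 0, 1],   # two value tokens
     [0, 1, 0, 0]]

scores = spike_scores(Q, K)
out = attend(scores, V, scale=0.5)
print(scores)  # [[1, 2], [0, 0]]
print(out)     # [[0.5, 1.0, 0.0, 0.5], [0.0, 0.0, 0.0, 0.0]]
```

The second query shares no active channel with any key, so its output is all zeros. In a full model the scaled result would pass through another spiking neuron layer to become binary again.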

4.3 Spikeformer Architecture

class SpikingTransformerBlock(nn.Module):
    """
    Spiking Transformer block.
    SpikingLayerNorm and SpikingMLP are assumed to be defined elsewhere.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, V_th=1.0):
        super().__init__()

        # Spiking self-attention
        self.norm1 = SpikingLayerNorm(embed_dim)
        self.attn = SpikeSelfAttention(embed_dim, num_heads, V_th)

        # Spiking FFN (the hidden width must be an integer)
        self.norm2 = SpikingLayerNorm(embed_dim)
        self.ffn = SpikingMLP(embed_dim, int(embed_dim * mlp_ratio), V_th)

    def forward(self, x, num_steps=1):
        # Self-attention + residual
        x = x + self.attn(self.norm1(x), num_steps)

        # FFN + residual
        x = x + self.ffn(self.norm2(x), num_steps)

        return x


class Spikeformer(nn.Module):
    """
    Spikeformer: a complete spiking Vision Transformer.
    """
    def __init__(self, img_size=224, patch_size=16,
                 in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12):
        super().__init__()

        # Patch embedding
        self.patch_embed = SpikingConv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

        # cls token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer blocks
        self.blocks = nn.ModuleList([
            SpikingTransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

        # Classification head
        self.norm = SpikingLayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x, num_steps=4):
        # Patch embedding (SpikingConv2d returns both spikes and membrane state)
        x, _ = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)  # (B, N, C)

        # Prepend the cls token
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, x], dim=1)

        # Transformer blocks
        for block in self.blocks:
            x = block(x, num_steps)

        # Classification
        x = self.norm(x)
        cls_output = x[:, 0]

        return self.head(cls_output)
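The default hyperparameters above fix the token layout of the model. The following arithmetic sketch derives it; the variable names are introduced here for illustration only.

```python
img_size, patch_size = 224, 16
embed_dim, num_heads, mlp_ratio = 768, 12, 4.0

patches_per_side = img_size // patch_size   # 14 patches along each image side
num_patches = patches_per_side ** 2         # 196 patch tokens
num_tokens = num_patches + 1                # plus one cls token -> 197
head_dim = embed_dim // num_heads           # 64 channels per attention head
mlp_hidden = int(embed_dim * mlp_ratio)     # 3072 hidden units in the FFN

print(num_tokens, head_dim, mlp_hidden)     # 197 64 3072
```

The attention matrix per head is therefore 197×197 at every time step, which is the source of the compute cost listed among the challenges. Since mlp_ratio is a float, the product with embed_dim has to be cast to an integer before it can be used as a layer width.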

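4.4 Latest Architecture: SpikingResformer

SpikingResformer (CVPR 2024) combines ResNet-style and ViT-style architectures in an SNN:

class SpikingResformerBlock(nn.Module):
    """
    SpikingResformer block: combines residual connections and attention.
    MultiScalePulsePerceiver, SpikingFFN and SpikingLayerNorm are assumed
    to be defined elsewhere.
    """
    def __init__(self, dim, num_heads=8):
        super().__init__()

        # Multi-scale spiking perceiver
        self.mspa = MultiScalePulsePerceiver(dim, num_heads)

        # Spiking FFN
        self.ffn = SpikingFFN(dim)

        # Normalization applied before each residual branch
        self.norm1 = SpikingLayerNorm(dim)
        self.norm2 = SpikingLayerNorm(dim)

    def forward(self, x):
        # MSPA + residual
        x = x + self.mspa(self.norm1(x))

        # FFN + residual
        x = x + self.ffn(self.norm2(x))

        return x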

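5. Architecture Comparison

Architecture | Suitable tasks | Depth | Compute efficiency | Performance
ConvSNN | Image classification, detection | medium-high
RSNN/LSM | Temporal processing, speech
Spiking ResNet | Image classification
SNN-Transformer | Complex vision tasks | medium-low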

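6. Architecture Design Principles

6.1 Avoiding Synchronization

  1. Spike randomization: use probabilistic spiking instead of a hard threshold
  2. Asynchronous updates: update different layers independently
  3. Membrane potential noise: add random perturbations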
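The first principle can be sketched as follows: instead of firing exactly when the potential reaches the threshold, the neuron fires with a probability that rises smoothly with the potential. The sigmoid form, the `temperature` parameter, and the helper names are assumptions chosen for illustration; other escape-noise functions would serve the same purpose.

```python
import math
import random


def spike_probability(v, v_th=1.0, temperature=0.1):
    """Firing probability as a sigmoid of the distance to threshold."""
    return 1.0 / (1.0 + math.exp(-(v - v_th) / temperature))


def stochastic_spike(v, rng, v_th=1.0, temperature=0.1):
    """Sample a 0/1 spike with the probability above."""
    return 1 if rng.random() < spike_probability(v, v_th, temperature) else 0


# Probability below, at, and above the threshold
probs = [spike_probability(v) for v in (0.5, 1.0, 1.5)]

# Neurons sitting exactly at threshold no longer all fire at the same step
rng = random.Random(0)
samples = [stochastic_spike(1.0, rng) for _ in range(1000)]
empirical = sum(samples) / len(samples)
print(probs, empirical)
```

At the threshold the firing probability is exactly 0.5, so a population of identical neurons splits roughly in half rather than firing in lockstep. As `temperature` approaches zero the rule approaches the hard threshold again.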

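6.2 Balancing Accuracy and Efficiency

class EfficientSNNDesign:
    """Design principles for efficient SNNs"""

    # 1. Spike sparsification
    # - Use a higher firing threshold
    # - Add regularization that encourages sparsity

    # 2. Controlling the number of time steps
    # - Task complexity determines how many time steps are needed
    # - Avoid overly long simulation times

    # 3. Mixed precision
    # - Use low precision in early layers
    # - Use high precision in later layers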
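The interaction between sparsity and the number of time steps can be estimated with a simple operation count. The sketch assumes that an ANN layer performs one multiply-accumulate per synapse, while an SNN layer performs one accumulate per synapse per incoming spike, so its cost scales with firing rate times time steps. The function name, the synapse count, and the rate and step values are illustrative assumptions, not measurements.

```python
def snn_synaptic_ops(num_synapses, firing_rate, num_steps):
    """Expected accumulate operations of an SNN layer over the whole simulation."""
    return num_synapses * firing_rate * num_steps


num_synapses = 1_000_000          # an ANN layer of this size needs 1e6 MACs
settings = [(0.05, 4),            # sparse and short
            (0.25, 4),            # dense spiking cancels the advantage
            (0.05, 20)]           # long simulation cancels it too

ratios = []
for firing_rate, num_steps in settings:
    ops = snn_synaptic_ops(num_synapses, firing_rate, num_steps)
    ratios.append(ops / num_synapses)   # SNN ops relative to ANN MACs

print(ratios)
```

Only the first setting stays well below the ANN operation count. An accumulate is cheaper than a multiply-accumulate, so the real break-even point is somewhat more favorable to the SNN, but the product of firing rate and time steps is the quantity to keep small.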

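7. Summary

Key points of SNN architecture design:

  1. Convolutional SNNs: exploit local connectivity and sparse activation
  2. Recurrent SNNs: model temporal dependencies and short-term memory
  3. Deep SNNs: residual connections address the gradient problem
  4. SNN-Transformers: make the self-attention mechanism spike-based

Different architectures suit different tasks; choosing one means trading off performance, efficiency, and hardware friendliness.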


References