Overview

Although Transformers dominate vision and language tasks, pure-MLP architectures saw a notable revival in 2024–2025. MLP-Mixer demonstrated that an all-MLP architecture can handle vision tasks, while Kolmogorov-Arnold Networks (KAN) proposed an entirely new paradigm: moving the activation functions from the nodes to the edges.[1]


MLP-Mixer: An All-MLP Architecture for Vision

Background and Motivation

A standard Vision Transformer (ViT) splits an image into patches and mixes information across patches with self-attention. The question driving MLP-Mixer: can MLPs alone do the same job?

Core Design

MLP-Mixer uses two types of MLP:

  1. Token-mixing MLP: mixes information across patches (acts along the token dimension), with weights shared across channels
  2. Channel-mixing MLP: mixes information across channels (acts along the channel dimension), with weights shared across patches

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPMixerBlock(nn.Module):
    """
    MLP-Mixer block.

    Two mixing operations:
    1. Token-mixing: mixes information across spatial locations (patches)
    2. Channel-mixing: mixes information across channels
    """
    def __init__(self, num_patches, d_model, expansion_factor=4):
        super().__init__()
        self.num_patches = num_patches
        self.d_model = d_model
        d_ff = d_model * expansion_factor

        # Token-mixing: acts along the patch dimension (shared across channels)
        self.norm1 = nn.LayerNorm(d_model)
        self.token_mixing = nn.Sequential(
            nn.Linear(num_patches, num_patches),  # same weights for every channel
            nn.GELU(),
            nn.Linear(num_patches, num_patches),
        )

        # Channel-mixing: acts along the channel dimension (shared across patches)
        self.norm2 = nn.LayerNorm(d_model)
        self.channel_mixing = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        """
        Args:
            x: (batch, num_patches, d_model)
        """
        # Token-mixing
        residual = x
        x = self.norm1(x)
        # transpose so the linear layers act on the num_patches dimension
        x = x.transpose(1, 2)  # (batch, d_model, num_patches)
        x = self.token_mixing(x)
        x = x.transpose(1, 2)  # (batch, num_patches, d_model)
        x = x + residual

        # Channel-mixing
        residual = x
        x = self.norm2(x)
        x = self.channel_mixing(x)
        x = x + residual

        return x


class MLPMixer(nn.Module):
    """
    Full MLP-Mixer model.
    """
    def __init__(self, img_size=224, patch_size=16, num_classes=1000,
                 depth=12, d_model=512, expansion_factor=4):
        super().__init__()

        # number of patches
        num_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = nn.Conv2d(
            3, d_model,
            kernel_size=patch_size,
            stride=patch_size
        )  # (batch, 3, H, W) -> (batch, d_model, num_patches_h, num_patches_w)

        # Learnable class token and positional embedding.
        # (Note: the original MLP-Mixer uses neither -- it relies on global
        # average pooling instead; this variant keeps a ViT-style head.)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))

        # Mixer blocks (num_patches + 1 tokens because of the class token)
        self.blocks = nn.ModuleList([
            MLPMixerBlock(num_patches + 1, d_model, expansion_factor)
            for _ in range(depth)
        ])

        # classification head
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, 3, H, W) -> (B, d_model, h, w)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, d_model)

        # prepend the class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches + 1, d_model)

        # add the positional embedding
        x = x + self.pos_embed

        # Mixer blocks
        for block in self.blocks:
            x = block(x)

        # classify from the class token
        x = self.norm(x)
        cls_token = x[:, 0]

        return self.head(cls_token)
```
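
As a quick sanity check on the block sizes above, the parameter count of one Mixer block can be worked out by hand and cross-checked against PyTorch's own count (`mixer_block_params` is a helper introduced here purely for illustration):

```python
import torch.nn as nn

def mixer_block_params(n, d, f=4):
    """Hand-computed parameter count of one Mixer block with n patches,
    width d, and expansion factor f."""
    token = 2 * (n * n + n)                          # two Linear(n, n): weights + biases
    channel = (d * f * d + f * d) + (f * d * d + d)  # Linear(d, f*d) + Linear(f*d, d)
    norms = 2 * 2 * d                                # two LayerNorms (weight + bias each)
    return token + channel + norms

# cross-check against an equivalent stack of torch modules
ref = nn.Sequential(
    nn.LayerNorm(512),
    nn.Linear(196, 196), nn.Linear(196, 196),
    nn.LayerNorm(512),
    nn.Linear(512, 2048), nn.Linear(2048, 512),
)
assert mixer_block_params(196, 512) == sum(p.numel() for p in ref.parameters())
print(mixer_block_params(196, 512))  # 2178984
```

With the B/16-like configuration (196 patches, width 512), the token-mixing MLP accounts for only a small fraction of the block's parameters; most sit in the channel-mixing MLP.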

Comparison with ViT

| Property | ViT | MLP-Mixer |
|---|---|---|
| Token mixing | Self-attention | MLP (across patches) |
| Channel mixing | FFN | MLP (across channels) |
| Attention complexity | O(n²·d) | none (token mixing is a linear map) |
| Parameter count | moderate | fewer |
| Inductive bias | very little | even less |

Performance Comparison

| Model | Top-1 Acc (ImageNet) | Params | FLOPs |
|---|---|---|---|
| ViT-B/16 | 79.8% | 86M | 17.6G |
| MLP-Mixer-B/16 | 80.6% | 59M | 12.6G |
| DeiT-B | 83.1% | 86M | 17.6G |
| MLP-Mixer-L/16 | 87.4% | 208M | 44.6G |

CS-Mixer: Cross-Scale MLPs

Core Idea

CS-Mixer (Cross-Scale Mixer) introduces a cross-scale mixing mechanism that models spatial dependencies at several scales simultaneously.[2]

```python
class CrossScaleMixing(nn.Module):
    """
    Cross-scale mixing module.
    """
    def __init__(self, d_model, num_scales=4):
        super().__init__()
        self.num_scales = num_scales

        # multi-scale projections with growing receptive fields
        self.scales = nn.ModuleList([
            nn.Conv2d(d_model, d_model, kernel_size=2 * s + 1, padding=s)
            for s in range(num_scales)
        ])

        # per-scale attention
        self.cross_attention = nn.ModuleList([
            nn.MultiheadAttention(d_model, num_heads=8)
            for _ in range(num_scales)
        ])

        # fusion of all scales
        self.fusion = nn.Linear(d_model * num_scales, d_model)

    def forward(self, x, H, W):
        """
        Args:
            x: (B, H*W, d_model)
            H, W: spatial dimensions
        Returns:
            out: (B, H*W, d_model)
        """
        B, L, D = x.shape

        # reshape to image layout
        x_img = x.view(B, H, W, D).permute(0, 3, 1, 2)  # (B, D, H, W)

        # process each scale
        multi_scale_features = []
        for conv, attn in zip(self.scales, self.cross_attention):
            # convolution at this scale (padding keeps H and W unchanged)
            feat = conv(x_img)  # (B, D, H, W)

            # back to sequence layout
            feat_seq = feat.permute(0, 2, 3, 1).reshape(B, -1, D)

            # self-attention (nn.MultiheadAttention expects (L, B, D) by default)
            feat_seq = feat_seq.transpose(0, 1)  # (L, B, D)
            feat_seq, _ = attn(feat_seq, feat_seq, feat_seq)
            feat_seq = feat_seq.transpose(0, 1)  # (B, L, D)

            multi_scale_features.append(feat_seq)

        # concatenate and fuse
        fused = torch.cat(multi_scale_features, dim=-1)
        out = self.fusion(fused)

        return out
```

Kolmogorov-Arnold Networks (KAN)

Mathematical Foundation: The Kolmogorov-Arnold Representation Theorem

The Kolmogorov-Arnold theorem (1957)

Any continuous multivariate function $f : [0, 1]^n \to \mathbb{R}$ can be written as

$$f(x_1, \ldots, x_n) = \sum_{q=0}^{2n} \Phi_q\!\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)$$

where all $\Phi_q$ and $\phi_{q,p}$ are continuous univariate functions.

Core insight: any multivariate function can be decomposed into sums and compositions of univariate functions!
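
A toy illustration of this insight (not the theorem's general construction): for positive inputs, the product x·y can be rewritten entirely with univariate functions (log, exp) and addition:

```python
import math

def f_direct(x, y):
    return x * y

def f_decomposed(x, y):
    # outer function Phi = exp, inner functions phi = log:
    # x * y = exp(log x + log y), valid for x, y > 0
    return math.exp(math.log(x) + math.log(y))

for x, y in [(2.0, 3.0), (0.5, 8.0), (1.7, 4.2)]:
    assert abs(f_direct(x, y) - f_decomposed(x, y)) < 1e-9
```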

KAN's Core Design

KAN applies this theorem to neural networks:

  • Traditional MLP: fixed activation functions sit on the nodes, and the learnable weights are linear
  • KAN: learnable activation functions sit on the edges, and the nodes simply sum their inputs

```python
class KANLayer(nn.Module):
    """
    KAN layer.

    Core difference from an MLP:
    - MLP: x -> linear -> fixed activation
    - KAN: x -> learnable activation on each edge -> sum at each node

    The activation functions are learnable!
    """
    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # grid points for the B-splines, extended beyond [0, 1]
        h = 1.0 / grid_size
        grid = torch.linspace(-h * spline_order, 1 + h * spline_order,
                              grid_size + 2 * spline_order + 1)
        self.register_buffer('grid', grid)

        # learnable B-spline coefficients:
        # grid_size + spline_order coefficients per input-output pair
        self.coeff = nn.Parameter(
            torch.randn(out_features, in_features, grid_size + spline_order)
        )

        # weights for the SiLU residual branch
        self.base_weight = nn.Parameter(torch.randn(out_features, in_features))

        self._init_activation()

    def _init_activation(self):
        """Initialize the learnable activations."""
        nn.init.kaiming_normal_(self.coeff, mode='fan_in', nonlinearity='linear')
        nn.init.normal_(self.base_weight, std=0.1)

    def b_spline_basis(self, x):
        """
        Evaluate the B-spline basis functions.

        Args:
            x: (batch, in_features), values in [0, 1]
        Returns:
            bases: (batch, in_features, grid_size + spline_order)
        """
        x = x.unsqueeze(-1)  # (batch, in, 1)
        grid = self.grid  # (grid_size + 2*spline_order + 1,)
        return self._de_boor(x, grid, self.spline_order)

    def _de_boor(self, x, grid, k):
        """Cox-de Boor recursion for the B-spline bases (simplified)."""
        x = x.clamp(0, 1)

        # order-0 basis functions: indicator of each grid interval
        d = torch.zeros(x.shape[0], x.shape[1], grid.shape[0] - 1,
                        device=x.device, dtype=x.dtype)
        for i in range(grid.shape[0] - 1):
            d[:, :, i] = ((grid[i] <= x.squeeze(-1)) &
                          (x.squeeze(-1) < grid[i + 1])).float()

        # recursively build the higher-order bases
        for p in range(1, k + 1):
            for i in range(grid.shape[0] - p - 1):
                left = (x - grid[i]) / (grid[i + p] - grid[i] + 1e-8)
                right = (grid[i + p + 1] - x) / (grid[i + p + 1] - grid[i + 1] + 1e-8)
                d[:, :, i] = (left.squeeze(-1) * d[:, :, i]
                              + right.squeeze(-1) * d[:, :, i + 1])

        # after k recursions only the first grid_size + spline_order bases are valid
        return d[:, :, :self.grid_size + self.spline_order]

    def forward(self, x):
        """
        Args:
            x: (batch, in_features), values in [0, 1]
        Returns:
            y: (batch, out_features)
        """
        # keep the inputs inside [0, 1]
        x = x.clamp(0, 1)

        # spline branch
        bases = self.b_spline_basis(x)  # (batch, in, grid_size + spline_order)
        spline_output = torch.einsum('bik,oik->bo', bases, self.coeff)

        # SiLU residual branch: linear map applied to silu(x)
        base_output = torch.einsum('bi,oi->bo', F.silu(x), self.base_weight)

        return spline_output + base_output


class KAN(nn.Module):
    """
    Full Kolmogorov-Arnold Network.
    """
    def __init__(self, layer_dims, grid_size=5, spline_order=3):
        """
        Args:
            layer_dims: list of layer widths, e.g. [2, 3, 1]
        """
        super().__init__()
        self.layers = nn.ModuleList()

        for i in range(len(layer_dims) - 1):
            self.layers.append(
                KANLayer(layer_dims[i], layer_dims[i + 1], grid_size, spline_order)
            )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```
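
The edge-activation idea can also be sketched without B-splines. The minimal layer below parameterizes each edge's learnable univariate function with Gaussian RBF bumps instead (a simplification for illustration; `RBFKANLayer` is not part of the KAN paper):

```python
import torch
import torch.nn as nn

class RBFKANLayer(nn.Module):
    """KAN-style layer: a learnable univariate function on every edge,
    parameterized with Gaussian RBF bases instead of B-splines."""
    def __init__(self, in_features, out_features, num_centers=8):
        super().__init__()
        self.register_buffer('centers', torch.linspace(0.0, 1.0, num_centers))
        self.gamma = float(num_centers)  # inverse width of the bumps
        # one coefficient per (output, input, center) triple
        self.coeff = nn.Parameter(
            torch.randn(out_features, in_features, num_centers) * 0.1)

    def forward(self, x):  # x: (batch, in_features), values roughly in [0, 1]
        # evaluate all RBF bases: (batch, in, num_centers)
        phi = torch.exp(-self.gamma * (x.unsqueeze(-1) - self.centers) ** 2)
        # weight the bases per edge, then sum the edge activations at each node
        return torch.einsum('bik,oik->bo', phi, self.coeff)

layer = RBFKANLayer(3, 2)
out = layer(torch.rand(4, 3))
print(out.shape)  # torch.Size([4, 2])
```

Swapping the basis family changes smoothness and locality properties, but the structural point is the same: the learnable parameters live on the edges, and the nodes only sum.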

KAN vs MLP: Core Comparison

| Property | MLP | KAN |
|---|---|---|
| Activation location | nodes | edges |
| Activation function | fixed (ReLU, GELU) | learnable (splines) |
| Parameters per edge | 1 weight | grid_size + spline_order coefficients (plus a base weight) |
| Interpretability | low | high (activations can be visualized) |
| Training speed | fast | slow (needs more epochs) |
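
The parameter-count difference in the table can be made concrete. Following the `KANLayer` parameterization above (one spline with `grid_size + spline_order` coefficients plus one base weight per edge), a rough comparison for a 64→64 layer:

```python
def mlp_layer_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias vector

def kan_layer_params(n_in, n_out, grid_size=5, spline_order=3):
    spline = n_in * n_out * (grid_size + spline_order)  # spline coefficients
    base = n_in * n_out                                 # base (residual) weights
    return spline + base

print(mlp_layer_params(64, 64))  # 4160
print(kan_layer_params(64, 64))  # 36864
```

So with the default grid, a KAN layer carries roughly `grid_size + spline_order + 1` times as many parameters per connection, which is one reason KANs are usually used at much smaller widths than MLPs.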

Visualizing KAN

```python
def visualize_kan(model, x_range=(0, 1), save_path='kan_activation.png'):
    """
    Visualize the learned activation function on every edge of the first layer.
    """
    import matplotlib.pyplot as plt

    layer = model.layers[0]
    fig, axes = plt.subplots(layer.out_features, layer.in_features,
                             figsize=(15, 15), squeeze=False)

    x = torch.linspace(x_range[0], x_range[1], 100)

    with torch.no_grad():
        # evaluate the B-spline bases once for a sweep over the input range
        x_grid = x.unsqueeze(-1).expand(-1, layer.in_features)  # (100, in)
        bases = layer.b_spline_basis(x_grid)  # (100, in, n_coeff)

        for out_idx in range(layer.out_features):
            for in_idx in range(layer.in_features):
                ax = axes[out_idx][in_idx]
                # activation on this edge: bases weighted by its coefficients
                activation = bases[:, in_idx, :] @ layer.coeff[out_idx, in_idx]
                ax.plot(x.numpy(), activation.numpy())
                ax.set_title(f'Out {out_idx} ← In {in_idx}')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
```

Bilinear MLPs: An Interpretability Breakthrough

Core Idea

Research from 2025 found that removing the activation function can itself yield interpretability.[3]

The bilinear form

Standard MLP layer:

$$h = \sigma(W_1 x), \qquad y = W_2 h$$

Bilinear MLP (the elementwise activation $\sigma$ is dropped and replaced by a product of two linear maps):

$$h = (W x) \odot (V x), \qquad y = W_{\text{out}}\, h$$

```python
class BilinearMLP(nn.Module):
    """
    Bilinear MLP.

    Essentially a GLU with the gating activation removed: the hidden layer
    is an elementwise product of two linear maps. Each output logit is then
    a quadratic form in the input:
        y_c = x^T B_c x,  with  B_c = sum_h head[c, h] * w_h v_h^T
    """
    def __init__(self, d_model, d_hidden, num_classes):
        super().__init__()

        # bilinear hidden layer: h = (W x) * (V x)
        self.W = nn.Linear(d_model, d_hidden, bias=False)
        self.V = nn.Linear(d_model, d_hidden, bias=False)

        # linear readout (its bias provides the constant term)
        self.head = nn.Linear(d_hidden, num_classes)

    def forward(self, x):
        """
        Args:
            x: (batch, d_model)
        Returns:
            y: (batch, num_classes)
        """
        # elementwise product of two linear maps -- no activation function
        h = self.W(x) * self.V(x)
        return self.head(h)
```

Where the interpretability comes from

Key insight: the output of a bilinear MLP decomposes into a sum of feature-feature interactions:

  • Bilinear term: pairwise feature interactions, a quadratic form $x^\top B_c\, x$ built only from the weights
  • Linear term (when included): independent per-feature contributions
  • Bias: a constant offset

This makes it possible to analyze directly which feature pairs contribute to each output!
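
This decomposition can be verified numerically. The sketch below assumes the GLU-style bilinear form h = (Wx) ⊙ (Vx) followed by a linear readout, and checks that each logit equals a quadratic form xᵀB_c x whose interaction matrix B_c is built purely from the weights:

```python
import torch

torch.manual_seed(0)
d, hidden, classes = 4, 3, 2
W = torch.randn(hidden, d)
V = torch.randn(hidden, d)
head = torch.randn(classes, hidden)
x = torch.randn(d)

# forward pass of the bilinear MLP
y = head @ ((W @ x) * (V @ x))

# weight-only interaction matrices: B_c = sum_h head[c, h] * w_h v_h^T
B = torch.einsum('ch,hi,hj->cij', head, W, V)

# each logit is the quadratic form x^T B_c x
y_decomposed = torch.einsum('i,cij,j->c', x, B, x)
assert torch.allclose(y, y_decomposed, atol=1e-5)
```

Because `B` depends only on the weights, the dominant feature-pair interactions for each class can be read off (e.g. by an eigendecomposition of `B[c]`) without ever running the model on data.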


Scaling Laws for MLPs

Empirical Findings

Bachmann et al. (2024) found that MLPs follow power-law scaling:[4]

$$\mathcal{L}(N) \propto N^{-\alpha}$$

where $N$ is the number of parameters and $\alpha$ is the power-law exponent.

MLP vs Transformer Scaling

| Property | MLP | Transformer |
|---|---|---|
| Scaling exponent α | ~0.1 | ~0.05 |
| Parameter requirements | fewer | more |
| Data efficiency | higher | lower |
| Optimal regime | more data, fewer parameters | less data, more parameters |
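
Estimating the exponent α from measurements is a linear fit in log-log space. The sketch below uses synthetic losses generated with α = 0.1 (an assumed value, matching the table, not real experimental data) and recovers it with `numpy.polyfit`:

```python
import numpy as np

# synthetic losses following L(N) = a * N^(-alpha)
alpha_true, a = 0.1, 5.0
N = np.array([1e6, 1e7, 1e8, 1e9])
L = a * N ** (-alpha_true)

# log L = log a - alpha * log N, so -alpha is the slope in log-log space
slope, intercept = np.polyfit(np.log(N), np.log(L), 1)
alpha_hat = -slope
print(round(alpha_hat, 3))  # 0.1
```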

Practical Recommendations

Guidelines for Choosing an Architecture

| Scenario | Recommended architecture |
|---|---|
| Image classification (simple) | MLP-Mixer |
| High interpretability requirements | KAN |
| Tabular data | MLP + embeddings |
| Rapid prototyping | MLP |
| High accuracy requirements | ViT / Transformer |

KAN Training Tips

```python
def train_kan(model, train_loader, epochs=100, lr=1e-3):
    """
    Training tips for KAN.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()

            # normalize inputs to [0, 1] -- the splines are defined on this range
            # (in practice, prefer dataset-level statistics over per-batch min/max)
            x = (x - x.min()) / (x.max() - x.min() + 1e-8)

            pred = model(x)
            loss = F.cross_entropy(pred, y)

            loss.backward()
            optimizer.step()

        scheduler.step()

        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
```


Footnotes

  1. Tolstikhin, I., et al. (2021). “MLP-Mixer: An all-MLP Architecture for Vision”. NeurIPS.

  2. “CS-Mixer: Cross-Scale Vision Multilayer Perceptron”. arXiv:2308.13363.

  3. “Bilinear MLPs enable weight-based mechanistic interpretability”. ICLR 2025.

  4. Bachmann, G., et al. (2024). “Scaling MLPs: A Tale of Inductive Bias”. arXiv:2306.13575.