概述
尽管 Transformer 在视觉和语言任务中占据主导地位,纯 MLP 架构在 2024–2025 年迎来了重要复兴。MLP-Mixer(2021)证明了纯 MLP 架构同样可以胜任视觉任务,而 Kolmogorov-Arnold Networks(KAN,2024)则提出了一种全新的范式:把可学习的激活函数从节点移到边上。1
MLP-Mixer:视觉任务的纯 MLP 架构
背景与动机
传统 Vision Transformer (ViT) 将图像划分为 patch,然后使用自注意力机制混合 patch 信息。MLP-Mixer 的核心问题是:能否只用 MLP 来完成同样的任务?
核心设计
MLP-Mixer 引入两种类型的 MLP:
- Token-mixing MLP:沿 patch(token)维度混合,让不同空间位置的信息交互;同一组权重在所有通道间共享
- Channel-mixing MLP:沿通道维度混合,让同一位置的不同特征交互;同一组权重在所有 patch 间共享
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLPMixerBlock(nn.Module):
"""
MLP-Mixer Block
两种混合操作:
1. Token-mixing: 混合不同位置的信息
2. Channel-mixing: 混合不同通道的信息
"""
def __init__(self, num_patches, d_model, expansion_factor=4):
super().__init__()
self.num_patches = num_patches
self.d_model = d_model
d_ff = d_model * expansion_factor
# Token-mixing: 作用于 patch 维度
self.norm1 = nn.LayerNorm(d_model)
self.token_mixing = nn.Sequential(
            nn.Linear(num_patches, num_patches),  # 作用于 patch 维度;同一组权重在所有通道间共享
nn.GELU(),
nn.Linear(num_patches, num_patches),
)
# Channel-mixing: 作用于通道维度
self.norm2 = nn.LayerNorm(d_model)
self.channel_mixing = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model),
)
def forward(self, x):
"""
Args:
x: (batch, num_patches, d_model)
"""
# Token-mixing
residual = x
x = self.norm1(x)
# 转置以在 num_patches 维度上操作
x = x.transpose(1, 2) # (batch, d_model, num_patches)
x = self.token_mixing(x) # 线性变换跨 patch
x = x.transpose(1, 2) # (batch, num_patches, d_model)
x = x + residual
# Channel-mixing
residual = x
x = self.norm2(x)
x = self.channel_mixing(x)
x = x + residual
return x
class MLPMixer(nn.Module):
"""
完整的 MLP-Mixer 模型
"""
def __init__(self, img_size=224, patch_size=16, num_classes=1000,
depth=12, d_model=512, expansion_factor=4):
super().__init__()
# 计算 patch 数量
num_patches = (img_size // patch_size) ** 2
# Patch embedding
self.patch_embed = nn.Conv2d(
3, d_model,
kernel_size=patch_size,
stride=patch_size
) # (batch, 3, H, W) -> (batch, d_model, num_patches_h, num_patches_w)
        # 注意:与 ViT 不同,MLP-Mixer 不使用类别 token 和位置编码:
        # token-mixing MLP 本身对 patch 顺序敏感,分类时改用全局平均池化
# Mixer blocks
self.blocks = nn.ModuleList([
MLPMixerBlock(num_patches, d_model, expansion_factor)
for _ in range(depth)
])
# 分类头
self.norm = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, num_classes)
# 初始化
self._init_weights()
    def _init_weights(self):
        self.apply(self._init_module_weights)
def _init_module_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# Patch embedding
x = self.patch_embed(x) # (B, 3, H, W) -> (B, d_model, h, w)
x = x.flatten(2).transpose(1, 2) # (B, num_patches, d_model)
        # Mixer blocks(无类别 token、无位置编码)
        for block in self.blocks:
            x = block(x)
        # 全局平均池化后分类
        x = self.norm(x)
        x = x.mean(dim=1)  # (B, d_model)
        return self.head(x)
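下面给出一个最小的使用示例(纯属示意,超参数为随意取的值),用随机输入检查前向传播的输出形状:

model = MLPMixer(img_size=224, patch_size=16, num_classes=1000, depth=8, d_model=512)
images = torch.randn(2, 3, 224, 224)  # 两张随机图像
logits = model(images)
print(logits.shape)  # torch.Size([2, 1000])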
与 ViT 的对比
| 特性 | ViT | MLP-Mixer |
|---|---|---|
| Token-mixing | Self-attention | MLP(跨 patch) |
| Channel-mixing | FFN | MLP(跨通道) |
| 注意力复杂度 | O(N²)(N 为 patch 数) | 无注意力(由 token-mixing MLP 替代) |
| 参数量 | 中等 | 较少 |
| 归纳偏置 | 极少 | 更少 |
性能对比
| 模型 | Top-1 Acc (ImageNet) | 参数量 | FLOPs |
|---|---|---|---|
| ViT-B/16 | 79.8% | 86M | 17.6G |
| MLP-Mixer-B/16 | 80.6% | 59M | 12.6G |
| DeiT-B | 83.1% | 86M | 17.6G |
| MLP-Mixer-L/16 | 87.4% | 208M | 44.6G |
CS-Mixer:跨尺度 MLP
核心思想
CS-Mixer(Cross-Scale Mixer)提出跨尺度混合机制,同时建模不同尺度的空间依赖。2
class CrossScaleMixing(nn.Module):
"""
跨尺度混合模块
"""
def __init__(self, d_model, num_scales=4):
super().__init__()
self.num_scales = num_scales
# 多尺度投影
self.scales = nn.ModuleList([
nn.Conv2d(d_model, d_model, kernel_size=2*s+1, padding=s)
for s in range(num_scales)
])
# 跨尺度注意力
self.cross_attention = nn.ModuleList([
nn.MultiheadAttention(d_model, num_heads=8)
for _ in range(num_scales)
])
# 融合
self.fusion = nn.Linear(d_model * num_scales, d_model)
def forward(self, x, H, W):
"""
Args:
x: (B, H*W, d_model)
H, W: 空间维度
Returns:
out: (B, H*W, d_model)
"""
B, L, D = x.shape
# 重塑为图像格式
x_img = x.view(B, H, W, D).permute(0, 3, 1, 2) # (B, D, H, W)
# 多尺度处理
multi_scale_features = []
for scale, (conv, attn) in enumerate(zip(self.scales, self.cross_attention)):
# 卷积多尺度
feat = conv(x_img) # (B, D, H, W)
# 转回序列格式
feat_seq = feat.permute(0, 2, 3, 1).reshape(B, -1, D)
# 自注意力
feat_seq = feat_seq.transpose(0, 1) # (L, B, D)
feat_seq, _ = attn(feat_seq, feat_seq, feat_seq)
feat_seq = feat_seq.transpose(0, 1) # (B, L, D)
multi_scale_features.append(feat_seq)
# 拼接并融合
fused = torch.cat(multi_scale_features, dim=-1)
out = self.fusion(fused)
        return out
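一个最小的调用示意(其中的维度与超参数均为随意取的值),用来说明输入输出的序列格式:

mixer = CrossScaleMixing(d_model=64, num_scales=4)
x = torch.randn(2, 14 * 14, 64)  # (B, H*W, d_model)
out = mixer(x, H=14, W=14)
print(out.shape)  # torch.Size([2, 196, 64])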
Kolmogorov-Arnold Networks (KAN)
数学基础:Kolmogorov-Arnold 表示定理
Kolmogorov-Arnold 定理(1957):
定义在 $[0,1]^n$ 上的任意多元连续函数 $f(x_1, \dots, x_n)$ 可以表示为:

$$
f(x_1, \dots, x_n) = \sum_{q=1}^{2n+1} \Phi_q\!\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)
$$

其中 $\Phi_q$ 和 $\phi_{q,p}$ 都是一元连续函数。
核心洞见:多元函数可以分解为一元函数的组合!例如乘法 $xy = \tfrac{1}{4}\left[(x+y)^2 - (x-y)^2\right]$,就可以写成加法与一元函数(平方)的复合。
KAN 的核心设计
KAN 将这一理论应用到神经网络:
- 传统 MLP:固定的激活函数位于节点(神经元)上,边上只是可学习的线性权重
- KAN:可学习的激活函数位于边上,节点只做求和(两者的公式对比见下)
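用公式写出两者在单个输出单元上的差别(即 KAN 论文中的核心对比,符号仅作示意):

$$
\text{MLP:}\quad y_i = \sigma\!\Big(\sum_j w_{ij}\, x_j + b_i\Big)
\qquad\qquad
\text{KAN:}\quad y_i = \sum_j \phi_{ij}(x_j)
$$

MLP 中激活函数 $\sigma$ 固定、权重 $w_{ij}$ 可学习;KAN 中每条边上的一元函数 $\phi_{ij}$ 本身可学习,下面的实现用 B-样条对其参数化。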
class KANLayer(nn.Module):
"""
KAN Layer
与 MLP 的核心区别:
    - MLP: 先对输入做线性组合,再经过固定的激活函数
    - KAN: 先对每个输入施加可学习的激活(样条),再求和
激活函数是可学习的!
"""
def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
# 网格点(用于 B-样条)
h = (1.0 - 0.0) / grid_size
grid = torch.linspace(0 - h * spline_order, 1 + h * spline_order,
grid_size + 2 * spline_order + 1)
self.register_buffer('grid', grid)
# B-样条系数(可学习)
# 每个输入-输出对有 grid_size + spline_order 个系数
self.coeff = nn.Parameter(
torch.randn(out_features, in_features, grid_size + spline_order)
)
# SiLU 激活的残差连接
self.base_weight = nn.Parameter(torch.randn(out_features, in_features))
# 激活函数的初始化
        self._init_activation()
def _init_activation(self):
"""初始化激活函数"""
nn.init.kaiming_normal_(self.coeff, mode='fan_in', nonlinearity='linear')
nn.init.normal_(self.base_weight, std=0.1)
def b_spline_basis(self, x):
"""
计算 B-样条基函数
Args:
x: (batch, in_features) 输入,范围 [0, 1]
Returns:
bases: (batch, in_features, grid_size + spline_order)
"""
x = x.unsqueeze(-1) # (batch, in, 1)
grid = self.grid # (grid_size + 2*spline_order + 1,)
# De Boor 算法
bases = self._de_boor(x, grid, self.spline_order)
        return bases  # (batch, in, grid_size + spline_order)
def _de_boor(self, x, grid, k):
"""
De Boor 算法计算 B-样条
"""
# 简化实现
x = x.clamp(0, 1)
# 计算基函数值
        # 为每个基函数分配独立存储(不能用 expand,否则所有切片共享同一块内存)
        d = x.new_zeros(x.shape[0], x.shape[1], grid.shape[0] - 1)
# 0阶基函数
for i in range(grid.shape[0] - 1):
d[:, :, i] = ((grid[i] <= x.squeeze(-1)) &
(x.squeeze(-1) < grid[i+1])).float()
# 递归计算高阶基函数
for p in range(1, k + 1):
for i in range(grid.shape[0] - p - 1):
left = (x - grid[i]) / (grid[i+p] - grid[i] + 1e-8)
right = (grid[i+p+1] - x) / (grid[i+p+1] - grid[i+1] + 1e-8)
d[:, :, i] = left.squeeze(-1) * d[:, :, i] + right.squeeze(-1) * d[:, :, i+1]
        # k 阶时有效基函数只有 grid_size + spline_order 个
        return d[:, :, :grid.shape[0] - 1 - k]  # (batch, in, grid_size + spline_order)
def forward(self, x):
"""
Args:
x: (batch, in_features) 输入,范围 [0, 1]
Returns:
y: (batch, out_features)
"""
# 确保输入在 [0, 1] 范围内
x = x.clamp(0, 1)
# B-样条激活
bases = self.b_spline_basis(x) # (batch, in, grid_size + spline)
spline_output = torch.einsum('bik,oik->bo', bases, self.coeff)
# SiLU 激活的残差
        base_output = F.linear(F.silu(x), self.base_weight)  # 基础项:对输入先做 SiLU,再线性组合
return spline_output + base_output
class KAN(nn.Module):
"""
完整的 Kolmogorov-Arnold Network
"""
def __init__(self, layer_dims, grid_size=5, spline_order=3):
"""
Args:
layer_dims: 每层的维度列表,如 [2, 3, 1]
"""
super().__init__()
self.layers = nn.ModuleList()
for i in range(len(layer_dims) - 1):
self.layers.append(
KANLayer(layer_dims[i], layer_dims[i+1], grid_size, spline_order)
)
def forward(self, x):
for layer in self.layers:
x = layer(x)
        return x
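一个最小的使用示意(层宽与超参数均为随意取的值):

kan = KAN(layer_dims=[2, 5, 1], grid_size=5, spline_order=3)
x = torch.rand(32, 2)  # 输入需落在 [0, 1] 区间
y = kan(x)
print(y.shape)  # torch.Size([32, 1])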
KAN vs MLP:核心对比
| 特性 | MLP | KAN |
|---|---|---|
| 激活位置 | 节点 | 边 |
| 激活函数 | 固定(ReLU, GELU) | 可学习(样条) |
| 参数量 | 每条边 1 个权重 | 每条边约 G + k 个样条系数(G 为网格数、k 为样条阶数),同等宽深下更多 |
| 可解释性 | 低 | 高(可视化激活函数) |
| 训练速度 | 快 | 慢(需要更多 epoch) |
KAN 的可视化
def visualize_kan(model, x_range=(0, 1), save_path='kan_activation.png'):
"""
可视化 KAN 的激活函数
"""
import matplotlib.pyplot as plt
    fig, axes = plt.subplots(model.layers[0].out_features,
                             model.layers[0].in_features,
                             figsize=(15, 15), squeeze=False)
x = torch.linspace(x_range[0], x_range[1], 100)
for out_idx in range(model.layers[0].out_features):
for in_idx in range(model.layers[0].in_features):
ax = axes[out_idx, in_idx]
# 获取该连接的激活函数
with torch.no_grad():
# 构造只有一个输入非零的数据
x_single = torch.zeros(100, model.layers[0].in_features)
x_single[:, in_idx] = x
# 前向传播(只取对应输出的对应输入)
coeff = model.layers[0].coeff[out_idx, in_idx]
# 简化的可视化:直接显示系数
ax.plot(coeff.cpu().numpy())
ax.set_title(f'Out {out_idx} ← In {in_idx}')
plt.tight_layout()
plt.savefig(save_path)
    plt.close()
双线性 MLP:可解释性突破
核心思想
2025 年的研究发现:把 GLU 中的逐元素激活函数移除,得到的双线性层反而能带来基于权重的机制可解释性!3
双线性形式
标准 MLP:

$$
y = W_2\,\sigma(W_1 x)
$$

双线性 MLP(去掉逐元素激活 $\sigma$,改为两路线性投影的逐元素乘积):

$$
y = W_3\big[(W_1 x) \odot (W_2 x)\big]
$$
class BilinearMLP(nn.Module):
"""
双线性 MLP
本质上等同于 GLU(没有激活函数部分)
"""
    def __init__(self, d_model, num_classes):
        super().__init__()
        # 双线性项:h = (W_1 x) ⊙ (W_2 x),再由 head 读出到各类别
        self.W_1 = nn.Linear(d_model, d_model, bias=False)
        self.W_2 = nn.Linear(d_model, d_model, bias=False)
        self.head = nn.Linear(d_model, num_classes, bias=False)
        # 线性项
        self.b_1 = nn.Linear(d_model, num_classes, bias=False)
        # 偏置
        self.b_2 = nn.Parameter(torch.zeros(num_classes))
    def forward(self, x):
        """
        Args:
            x: (batch, d_model)
        Returns:
            y: (batch, num_classes)
        """
        # 双线性项:h = (W_1 x) ⊙ (W_2 x),对每个类别等价于一个二次型 x^T B_c x
        bilinear = self.head(self.W_1(x) * self.W_2(x))
        # 线性项
        linear = self.b_1(x)
        # 总输出
        return bilinear + linear + self.b_2
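最小调用示意(维度为随意取的值):

blin = BilinearMLP(d_model=16, num_classes=3)
x = torch.randn(8, 16)
print(blin(x).shape)  # torch.Size([8, 3])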
可解释性的来源
关键洞察:双线性 MLP 的输出可以分解为特征-特征交互的和:
- 第一项:特征-特征交互项(对每个类别 $c$ 可写成二次型 $x^\top B_c x$)
- 第二项:线性项
- 第三项:偏置
这允许直接从权重分析哪些特征对(feature pairs)对输出有贡献,如下面的草图所示。
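基于上面 BilinearMLP 草图中的命名(其中 W_1、W_2、head 是本文示例里的假设命名,并非论文的官方实现),可以直接从权重构造每个类别的交互矩阵 $B_c$,并找出贡献最大的特征对:

def interaction_matrix(model, c):
    # B_c[i, k] = Σ_j head[c, j] · W_1[j, i] · W_2[j, k]
    B = torch.einsum('j,ji,jk->ik', model.head.weight[c],
                     model.W_1.weight, model.W_2.weight)
    return 0.5 * (B + B.T)  # 对称化,不改变二次型取值

model = BilinearMLP(d_model=16, num_classes=3)
B0 = interaction_matrix(model, c=0)
top = torch.topk(B0.abs().flatten(), k=5).indices
print([(i.item() // B0.shape[1], i.item() % B0.shape[1]) for i in top])  # 贡献最大的 5 对特征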
MLP 的缩放定律
实证发现
Bachmann 等人(2024)发现 MLP 的测试损失同样遵循幂律缩放:4

$$
L(N) \propto N^{-\alpha}
$$

其中 $N$ 是参数数量,$\alpha$ 是幂律指数。
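在 log–log 坐标下做线性拟合即可估计 $\alpha$。下面是一个最小草图(其中的 (N, L) 数据是随意编造的示例值,只为演示拟合步骤):

import numpy as np

# 假设的 (参数量 N, 测试损失 L) 观测值,仅作演示
N = np.array([1e6, 4e6, 1.6e7, 6.4e7])
L = np.array([2.10, 1.82, 1.60, 1.41])

# log L = -alpha * log N + c:在对数空间做一次线性回归
slope, intercept = np.polyfit(np.log(N), np.log(L), deg=1)
print(f'estimated alpha ≈ {-slope:.3f}')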
MLP vs Transformer 缩放对比
| 特性 | MLP | Transformer |
|---|---|---|
| 缩放指数 | ~0.1 | ~0.05 |
| 参数量需求 | 较少 | 较多 |
| 数据效率 | 更高 | 较低 |
| 最优配置 | 更多数据,更少参数 | 更少数据,更多参数 |
实践建议
选择架构的指南
| 场景 | 推荐架构 |
|---|---|
| 图像分类(简单) | MLP-Mixer |
| 可解释性要求高 | KAN |
| 表格数据 | MLP + Embedding |
| 快速原型 | MLP |
| 高精度需求 | ViT / Transformer |
KAN 训练技巧
def train_kan(model, train_loader, epochs=100, lr=1e-3):
"""
KAN 训练技巧
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
for epoch in range(epochs):
for x, y in train_loader:
optimizer.zero_grad()
# 标准化输入到 [0, 1]
x = (x - x.min()) / (x.max() - x.min() + 1e-8)
pred = model(x)
loss = F.cross_entropy(pred, y)
loss.backward()
optimizer.step()
scheduler.step()
if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
参考
相关阅读
- MLP 理论基础 — 宽度-深度权衡与表达能力
- 反向传播与梯度流 — MLP 训练理论
- Vision Mamba — 视觉 SSM 架构
- ViT vs CNN — 视觉架构对比
Footnotes
1. Tolstikhin, I., et al. (2021). “MLP-Mixer: An all-MLP Architecture for Vision”. NeurIPS.
2. “CS-Mixer: Cross-Scale Vision Multilayer Perceptron”. arXiv:2308.13363.
3. “Bilinear MLPs enable weight-based mechanistic interpretability”. ICLR 2025.
4. Bachmann, G., et al. (2024). “Scaling MLPs: A Tale of Inductive Bias”. arXiv:2306.13575.