Overview
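Since the KAN (Kolmogorov-Arnold Networks) paper appeared in April 2024, the research community has rapidly produced many variants. These variants either address limitations of the original KAN or extend it to new settings.[1] This article surveys the main variants: Wav-KAN (wavelet activations), Graph KAN (graph-structured data), FastKAN (computational efficiency), CKAN (Chebyshev polynomials), and several others.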
1. Taxonomy of Variants
1.1 By Design Goal
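| Category | Variants | Core goal |
|---|---|---|
| Interpretability | Wav-KAN, Spline-KAN | Clearer visualization of activation functions |
| Efficiency | FastKAN, CKAN, GPU-KAN | Lower computational cost |
| Structural extension | Graph KAN, RNN-KAN | Adapt KAN to graphs / recurrent models |
| Accuracy | Kolmogorov-GAN, Ada-KAN | Better fitting capacity |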
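1.2 Timeline of Variants
- 2024.04 KAN (MIT)
- 2024.05 Wav-KAN (wavelet activations)
- 2024.06 FastKAN (computational efficiency)
- 2024.07 Graph KAN (graph data)
- 2024.08 KAN 2.0 (scientific applications)
- 2024.09 CKAN (Chebyshev polynomials)
- 2024.10 Ada-KAN (adaptive)
- 2024.11 RNN-KAN (sequence modeling)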
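2. Wav-KAN: Wavelet Activation Functions
2.1 Core Idea
Wav-KAN (Wavelet Kolmogorov-Arnold Networks) uses wavelet functions as activation functions in place of B-splines. This improves interpretability and enables multi-scale analysis.[2]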
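Why wavelets?
- Multi-scale analysis: all wavelets are dilations of one mother wavelet, so they are inherently multi-scale
- Sparse representation: the wavelet transform of a piecewise-smooth signal has few significant coefficients
- Time-frequency localization: wavelets are localized in both time and frequency
- Interpretability: wavelet basis functions have a clear physical meaning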
2.2 Wavelet Basics
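Haar wavelet
The simplest orthogonal wavelet:
$$\psi(t) = \begin{cases} 1, & 0 \le t < \tfrac{1}{2} \\ -1, & \tfrac{1}{2} \le t < 1 \\ 0, & \text{otherwise} \end{cases}$$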
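The Haar filters are just scaled pairwise sums (low-pass) and differences (high-pass). Applying them repeatedly to the low-pass output gives a multi-level, i.e. multi-scale, transform. The NumPy sketch below is an illustration, not code from the Wav-KAN paper; `haar_dwt` is a hypothetical helper. It shows the sparsity claim from the list above: a piecewise-constant signal with one jump yields only one nonzero detail coefficient per level.

```python
import numpy as np


def haar_dwt(signal, levels):
    """Multi-level orthonormal Haar DWT: [approx, coarsest detail, ..., finest detail]."""
    approx = np.asarray(signal, dtype=float)
    details = []
    for _ in range(levels):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2))  # high-pass: scaled local differences
        approx = (even + odd) / np.sqrt(2)         # low-pass: scaled local sums
    return [approx] + details[::-1]


# Piecewise-constant signal of length 16 with a single jump
signal = np.array([3.0] * 5 + [7.0] * 11)
coeffs = haar_dwt(signal, levels=4)
flat = np.concatenate(coeffs)
print(np.count_nonzero(np.abs(flat) > 1e-12), "of", flat.size)  # 5 of 16
```

Of the 16 coefficients, only the final approximation and one detail coefficient per level are nonzero. Each of those detail coefficients sits at the position of the jump.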
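Daubechies wavelets
Daubechies wavelets are smoother than Haar and have compact support. The order-$N$ Daubechies wavelet, denoted $\mathrm{db}N$, has $N$ vanishing moments and filters of length $2N$; Haar is $\mathrm{db}1$.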
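The 4-tap Daubechies filter has a closed form. The NumPy check below is an illustration, not part of the Wav-KAN code. It reproduces the taps that the class that follows uses under the key `'db4'` and verifies three properties:

- the taps sum to $\sqrt{2}$;
- the filter is orthonormal to its shifts by two;
- the matching high-pass filter has exactly two vanishing moments.

```python
import numpy as np

# Closed-form taps of the 4-tap Daubechies scaling (low-pass) filter
s3 = np.sqrt(3.0)
h = np.array([1 + s3, 3 + s3, 3 - s3, 1 - s3]) / (4 * np.sqrt(2.0))

# Matching high-pass (wavelet) filter: g_k = (-1)^k h_{3-k}
g = np.array([(-1) ** k * h[3 - k] for k in range(4)])
k = np.arange(4)

print(np.round(h, 4))             # taps: 0.4830, 0.8365, 0.2241, -0.1294
print(h.sum(), (h ** 2).sum())    # approximately sqrt(2) and 1
print(h[0] * h[2] + h[1] * h[3])  # orthogonality to the shift by 2: approximately 0
# Moments sum_k k^p g_k for p = 0, 1, 2: the first two vanish, the third does not
print([float(np.dot(k ** p, g)) for p in range(3)])
```

Two vanishing moments means this 4-tap filter is $\mathrm{db}2$ in the $\mathrm{db}N$ naming above. The `'db4'` key in the code refers to its four taps.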
```python
import torch
import torch.nn as nn


class WaveletBasis(nn.Module):
    """
    Wavelet basis (simplified discrete filter bank).
    Supports several wavelet types.
    """

    def __init__(self, wavelet_type='db4', dilation=1):
        super().__init__()
        self.wavelet_type = wavelet_type
        # Spacing between filter taps (à trous style): 1, 2, 4, ... for coarser scales
        self.dilation = dilation
        # Pre-computed, fixed filter coefficients
        self.register_buffer('wavelet_coeffs', self._get_wavelet_coeffs())

    def _get_wavelet_coeffs(self):
        """Return the filter coefficients of the chosen wavelet."""
        if self.wavelet_type == 'haar':
            # Haar wavelet (high-pass) filter, unnormalized
            return torch.tensor([1.0, -1.0])
        elif self.wavelet_type == 'db4':
            # 4-tap Daubechies scaling filter (D4; 'db2' in PyWavelets naming),
            # which has 2 vanishing moments
            return torch.tensor([0.4830, 0.8365, 0.2241, -0.1294])
        elif self.wavelet_type == 'sym4':
            # Symlet-4 scaling (decomposition low-pass) filter, 8 taps
            return torch.tensor([-0.0758, -0.0296, 0.4976, 0.8037,
                                 0.2979, -0.0992, -0.0126, 0.0322])
        else:
            raise ValueError(f"Unknown wavelet type: {self.wavelet_type}")

    def forward(self, x):
        """
        Apply the filter as a circular convolution along the feature axis.
        x: (batch, in) -> (batch, in)
        """
        # Simplified discrete wavelet transform (no downsampling)
        output = torch.zeros_like(x)
        for i, c in enumerate(self.wavelet_coeffs):
            shifted = torch.roll(x, shifts=i * self.dilation, dims=-1)
            output = output + c * shifted
        return output


class WavKANLayer(nn.Module):
    """
    Wav-KAN layer: a KAN layer built on wavelet filters.
    Note: this simplified layer is linear in x (apart from clamping); the
    Wav-KAN paper builds its edge functions from continuous mother wavelets.
    """

    def __init__(self, in_features, out_features,
                 num_scales=4, wavelet_type='db4'):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_scales = num_scales
        self.wavelet_type = wavelet_type
        # Multi-scale wavelet bases; scale s uses dilation 2**s
        self.wavelet_bases = nn.ModuleList([
            WaveletBasis(wavelet_type, dilation=2 ** scale)
            for scale in range(num_scales)
        ])
        # Learnable wavelet (detail) coefficients per edge and scale
        self.wavelet_coeffs = nn.Parameter(
            torch.randn(out_features, in_features, num_scales)
        )
        # Scaling-function (approximation) coefficients per edge
        self.scale_coeffs = nn.Parameter(
            torch.randn(out_features, in_features)
        )
        self._init_weights()

    def _init_weights(self):
        nn.init.kaiming_normal_(self.wavelet_coeffs)
        nn.init.normal_(self.scale_coeffs, std=0.1)

    def forward(self, x):
        """
        Args:
            x: (batch, in_features)
        Returns:
            y: (batch, out_features)
        """
        # Keep inputs in [0, 1]
        x = x.clamp(0, 1)
        # Multi-scale wavelet responses
        wavelet_outputs = []
        for scale in range(self.num_scales):
            # Wavelet response at this scale
            psi = self.wavelet_bases[scale](x)  # (batch, in)
            # Weight each edge (o, i)
            weighted = psi.unsqueeze(1) * self.wavelet_coeffs[:, :, scale]  # (batch, out, in)
            wavelet_outputs.append(weighted.sum(dim=-1))  # (batch, out)
        # Scaling-function (low-frequency) part
        scale_output = torch.einsum('bi,oi->bo', x, self.scale_coeffs)
        # Fuse the scales
        wavelet_stack = torch.stack(wavelet_outputs, dim=-1)  # (batch, out, scales)
        wavelet_sum = wavelet_stack.sum(dim=-1)  # (batch, out)
        return scale_output + wavelet_sum


class WavKAN(nn.Module):
    """
    Full Wav-KAN model.
    """

    def __init__(self, layer_dims, num_scales=4, wavelet_type='db4'):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(len(layer_dims) - 1):
            self.layers.append(
                WavKANLayer(layer_dims[i], layer_dims[i + 1],
                            num_scales, wavelet_type)
            )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```
2.3 Wav-KAN vs B-Spline KAN
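| Feature | B-Spline KAN | Wav-KAN |
|---|---|---|
| Basis functions | B-splines | Wavelets |
| Multi-scale analysis | No | Yes |
| Sparse representation | Medium | High |
| Computational cost | $O(G)$ basis terms per edge ($G$: grid size) | $O(S)$ wavelet terms per edge ($S$: number of scales) |
| Interpretability | Medium | High (wavelets have a physical meaning) |
| Typical use | General | Signals/images |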
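The simplified filter-bank layer above mixes neighboring features linearly. The Wav-KAN paper, as we read it, instead gives every edge a continuous mother wavelet (for example the Mexican hat) with its own learnable scale and translation. The NumPy forward pass below is a minimal sketch of that idea. Assumptions: the Mexican hat is $\psi(t) = (1 - t^2)e^{-t^2/2}$ without its normalizing constant, and `wav_kan_forward` is a hypothetical name, not the paper's code.

```python
import numpy as np


def mexican_hat(t):
    # Mexican hat (Ricker) mother wavelet, unnormalized
    return (1.0 - t ** 2) * np.exp(-0.5 * t ** 2)


def wav_kan_forward(x, weight, scale, translation):
    """
    Hypothetical Wav-KAN layer forward pass.
    x: (batch, in); weight, scale, translation: (out, in)
    y[b, o] = sum_i weight[o, i] * psi((x[b, i] - translation[o, i]) / scale[o, i])
    """
    t = (x[:, None, :] - translation[None, :, :]) / scale[None, :, :]  # (batch, out, in)
    return (weight[None, :, :] * mexican_hat(t)).sum(axis=-1)            # (batch, out)


rng = np.random.default_rng(0)
batch, d_in, d_out = 4, 3, 2
x = rng.uniform(-1, 1, size=(batch, d_in))
weight = rng.normal(size=(d_out, d_in))
scale = np.ones((d_out, d_in))          # learnable in a real model
translation = np.zeros((d_out, d_in))   # learnable in a real model
y = wav_kan_forward(x, weight, scale, translation)
print(y.shape)  # (4, 2)
```

Because scale and translation are per edge, each edge learns its own localized, nonlinear bump. This differs from the shared linear filter in the simplified layer.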
3. FastKAN: Efficient Computation
3.1 Motivation and Challenges
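The main computational bottleneck of the original KAN is evaluating the B-spline activation functions: for every input, the basis values come from the recursive Cox-de Boor formula. FastKAN addresses this by replacing the splines with a cheaper approximation.[3]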
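The sketch below shows where that cost comes from, using a vectorized Cox-de Boor evaluation in NumPy. Assumptions: the uniform grid extended by $k$ knots on each side mirrors common KAN setups, and `bspline_basis` is a hypothetical name. Each of the $k$ recursion rounds updates every basis for every input, so one edge costs roughly $O(k(G+k))$ operations per sample before any coefficients are applied.

```python
import numpy as np


def bspline_basis(x, grid, k):
    """
    Cox-de Boor recursion for all order-k B-spline bases on a knot vector.
    x: (n,) inputs; grid: (m,) increasing knots -> (n, m - k - 1) basis values
    """
    x = x[:, None]
    # Order 0: indicator of each knot interval [t_i, t_{i+1})
    B = ((x >= grid[:-1]) & (x < grid[1:])).astype(float)
    for p in range(1, k + 1):
        # B_{i,p} = (x - t_i)/(t_{i+p} - t_i) B_{i,p-1}
        #         + (t_{i+p+1} - x)/(t_{i+p+1} - t_{i+1}) B_{i+1,p-1}
        left = (x - grid[:-(p + 1)]) / (grid[p:-1] - grid[:-(p + 1)]) * B[:, :-1]
        right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p]) * B[:, 1:]
        B = left + right
    return B


G, k = 5, 3
h = 1.0 / G
# Uniform grid on [0, 1], extended by k knots on each side
grid = np.linspace(-k * h, 1 + k * h, G + 2 * k + 1)
x = np.linspace(0.01, 0.99, 50)
B = bspline_basis(x, grid, k)
print(B.shape)                        # (50, 8): G + k bases per input
print(np.allclose(B.sum(axis=1), 1))  # True: partition of unity on [0, 1]
```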
3.2 RBF-Net Approximation
FastKAN approximates the B-splines with **radial basis functions (RBFs)**, specifically Gaussians:
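A Gaussian is a reasonable stand-in for a cubic B-spline. With unit knot spacing, the centered cubic B-spline is the density of a sum of four uniform variables on $[-\tfrac12, \tfrac12]$, so it has variance $1/3$ and is close to the Gaussian of the same variance. The NumPy sketch below measures the gap. The moment-matched width is our own illustrative choice; FastKAN fixes its own RBF width.

```python
import numpy as np


def cubic_bspline(x):
    """Centered cardinal cubic B-spline (unit knot spacing, support [-2, 2])."""
    ax = np.abs(x)
    return np.where(ax < 1, 2 / 3 - ax ** 2 + ax ** 3 / 2,
                    np.where(ax < 2, (2 - ax) ** 3 / 6, 0.0))


def matched_gaussian(x, var=1 / 3):
    """Gaussian density with the same mean (0) and variance (1/3) as the B-spline."""
    return np.exp(-x ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)


x = np.linspace(-3, 3, 601)
err = np.abs(cubic_bspline(x) - matched_gaussian(x))
print(f"max |B3 - Gaussian| = {err.max():.4f}")  # about 0.024, at x = 0
```

The largest error, about 0.024, occurs at the peak (value $2/3$). That is under 4% of the peak, which is why a learned combination of Gaussians can play the role of a spline basis.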
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FastKANLayer(nn.Module):
    """
    FastKAN layer: uses Gaussian RBFs instead of B-splines for efficient computation.
    """

    def __init__(self, in_features, out_features,
                 num_grids=5, smooth_threshold=0.1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_grids = num_grids
        # Learnable RBF centers, one set per input feature, initialized on a
        # uniform grid over [0, 1]
        grid = torch.linspace(0, 1, num_grids)
        self.rbf_centers = nn.Parameter(grid.unsqueeze(1).repeat(1, in_features))  # (num_grids, in)
        # Learnable RBF widths, initialized to the grid spacing
        h = 1.0 / max(num_grids - 1, 1)
        self.rbf_scales = nn.Parameter(torch.full((num_grids, in_features), h))  # (num_grids, in)
        # Weights of the base (SiLU) branch
        self.base_weight = nn.Parameter(torch.randn(out_features, in_features))
        # Weights combining the RBF features of each edge (o, i)
        self.combine_weight = nn.Parameter(torch.randn(out_features, in_features, num_grids))
        # Smoothing threshold (kept for configuration; unused in this layer)
        self.smooth_threshold = smooth_threshold

    def rbf_gaussian(self, x):
        """
        Gaussian RBF applied to each input feature separately:
        phi_{g,i}(x_i) = exp(-(x_i - c_{g,i})^2 / (2 s_{g,i}^2))
        """
        # x: (batch, in); centers, scales: (num_grids, in)
        x_expanded = x.unsqueeze(1)                 # (batch, 1, in)
        c_expanded = self.rbf_centers.unsqueeze(0)  # (1, num_grids, in)
        s_expanded = self.rbf_scales.unsqueeze(0)   # (1, num_grids, in)
        # Squared distance scaled by the width
        dist = (x_expanded - c_expanded) ** 2 / (2 * s_expanded ** 2 + 1e-8)
        # No sum over inputs: every edge uses a univariate function
        return torch.exp(-dist)                     # (batch, num_grids, in)

    def forward(self, x):
        """
        Args:
            x: (batch, in_features), in [0, 1]
        Returns:
            y: (batch, out_features)
        """
        x = x.clamp(0, 1)
        # RBF features
        rbf_act = self.rbf_gaussian(x)  # (batch, num_grids, in)
        # Sum over grid points g and inputs i for every output o
        combine = torch.einsum('bgi,oig->bo', rbf_act, self.combine_weight)
        # Base branch: SiLU on the input, then a linear map
        base = F.silu(x) @ self.base_weight.t()
        return combine + base


class FastKAN(nn.Module):
    """
    Full FastKAN model.
    """

    def __init__(self, layer_dims, num_grids=5):
        super().__init__()
        # Note: each layer clamps its input to [0, 1]
        self.layers = nn.ModuleList()
        for i in range(len(layer_dims) - 1):
            self.layers.append(
                FastKANLayer(layer_dims[i], layer_dims[i + 1], num_grids)
            )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```
3.3 Chebyshev KAN (CKAN)
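CKAN uses Chebyshev polynomials as basis functions. They have good numerical properties for three reasons: they are orthogonal on $[-1, 1]$, bounded by 1 in magnitude there, and computable with a cheap three-term recurrence.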
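The recurrence can be checked against the closed form $T_n(\cos\theta) = \cos(n\theta)$. That identity also explains why $|T_n(x)| \le 1$ on $[-1, 1]$. A short NumPy sketch follows; `chebyshev_T` is a hypothetical helper name.

```python
import numpy as np


def chebyshev_T(x, degree):
    """Stack T_0..T_degree at x using T_{n+1} = 2x T_n - T_{n-1}."""
    T = [np.ones_like(x), x]
    for n in range(1, degree):
        T.append(2 * x * T[n] - T[n - 1])
    return np.stack(T[:degree + 1], axis=-1)


x = np.linspace(-1, 1, 201)
T = chebyshev_T(x, degree=5)                           # (201, 6)
closed = np.cos(np.arange(6) * np.arccos(x)[:, None])  # cos(n * arccos x)
print(np.abs(T - closed).max())  # tiny: recurrence matches the closed form up to round-off
print(np.abs(T).max())           # approximately 1: bounded on [-1, 1]
```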
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChebyshevKANLayer(nn.Module):
    """
    Chebyshev KAN layer: uses Chebyshev polynomials as basis functions.
    """

    def __init__(self, in_features, out_features, degree=5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.degree = degree
        # Learnable Chebyshev coefficients, one set per edge (o, i)
        self.coeffs = nn.Parameter(
            torch.randn(out_features, in_features, degree + 1)
        )
        # Base-branch weights
        self.base_weight = nn.Parameter(torch.randn(out_features, in_features))
        self._init_coeffs()

    def _init_coeffs(self):
        # Start from small values
        nn.init.normal_(self.coeffs, std=0.01)
        nn.init.normal_(self.base_weight, std=0.1)

    def chebyshev_polynomials(self, x):
        """
        Evaluate Chebyshev polynomials elementwise:
        T_0(x) = 1
        T_1(x) = x
        T_{n+1}(x) = 2x T_n(x) - T_{n-1}(x)
        x: (batch, in) -> (batch, in, degree + 1)
        """
        # Build a list instead of writing into a preallocated tensor;
        # this avoids in-place writes that can break autograd
        T = [torch.ones_like(x)]
        if self.degree >= 1:
            T.append(x)
        for n in range(1, self.degree):
            T.append(2 * x * T[n] - T[n - 1])
        return torch.stack(T, dim=-1)

    def forward(self, x):
        """
        Args:
            x: (batch, in_features), expected in [0, 1]
        Returns:
            y: (batch, out_features)
        """
        # Map [0, 1] to [-1, 1], the domain of the Chebyshev polynomials
        x = (x * 2 - 1).clamp(-1, 1)
        T = self.chebyshev_polynomials(x)  # (batch, in, degree + 1)
        # Sum over inputs i and degrees d for every output o
        cheb_out = torch.einsum('bid,oid->bo', T, self.coeffs)
        # Base branch: SiLU on the input, then a linear map
        base = F.silu(x) @ self.base_weight.t()
        return cheb_out + base
```
4. Graph KAN: Graph-Structured Data
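4.1 Core Idea
Graph KAN extends the KAN idea to graph-structured data, using KAN layers for node representation learning in graph neural networks.[4]
4.2 Design Motivation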
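Conventional GNNs use:
- MLPs: fixed nonlinear activations
- GAT: attention-weighted neighborhood aggregation

Graph KAN introduces:
- Learnable edge activations: each edge's message is computed by a KAN (learnable univariate functions) applied to its endpoint features
- Topology-aware representations: the graph structure determines which feature pairs the activations act on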
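The message-passing pattern above can be sketched in plain NumPy, without PyTorch Geometric. Each edge's message is a sum of learnable univariate functions of its endpoint features, and messages are summed at the target node. For illustration only, the univariate functions here are Gaussian-RBF expansions on a fixed grid. `edge_kan` and `aggregate_messages` are hypothetical names, not a library API.

```python
import numpy as np


def rbf_features(z, centers, width):
    """Gaussian RBF expansion of every scalar in z: (..., d) -> (..., d, K)."""
    return np.exp(-((z[..., None] - centers) / width) ** 2)


def edge_kan(z, coeffs, centers, width):
    """KAN on edge inputs: out[e, o] = sum_i phi_{o,i}(z[e, i]), phi = RBF combination."""
    feats = rbf_features(z, centers, width)         # (E, d, K)
    return np.einsum('edk,odk->eo', feats, coeffs)  # (E, out)


def aggregate_messages(x, edge_index, coeffs, centers, width):
    """Sum messages from sources j to targets i (edge_index[0] = j, edge_index[1] = i)."""
    src, dst = edge_index
    z = np.concatenate([x[dst], x[src]], axis=-1)   # [x_i, x_j] per edge
    msgs = edge_kan(z, coeffs, centers, width)      # (E, out)
    out = np.zeros((x.shape[0], msgs.shape[1]))
    np.add.at(out, dst, msgs)                       # 'add' aggregation
    return out


rng = np.random.default_rng(0)
x = rng.uniform(0, 1, size=(4, 2))                   # 4 nodes, 2 features
edge_index = np.array([[0, 1, 2, 3], [1, 2, 3, 0]])  # a directed 4-cycle
centers = np.linspace(0, 1, 5)                       # K = 5 grid points
coeffs = rng.normal(size=(3, 4, 5))                  # out=3, d=2*2, K=5
h = aggregate_messages(x, edge_index, coeffs, centers, width=0.25)
print(h.shape)  # (4, 3)
```

The same KAN is shared by all edges. Messages differ only because each edge feeds it different endpoint features.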
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

# KANLayer is not defined in this article. As an assumption, reuse the
# FastKANLayer from Section 3.2; any layer mapping (N, in) -> (N, out) works.
KANLayer = FastKANLayer


class GraphKANConv(MessagePassing):
    """
    Graph KAN convolution layer.
    Embeds a KAN activation in the message-passing framework.
    """

    def __init__(self, in_features, out_features,
                 num_edges_hidden=5, edge_dim=0, aggr='add'):
        super().__init__(aggr=aggr)
        self.in_features = in_features
        self.out_features = out_features
        self.num_edges_hidden = num_edges_hidden
        self.edge_dim = edge_dim
        # Edge transform (KAN activation) on [edge_attr, x_i, x_j]
        self.edge_kan = KANLayer(in_features * 2 + edge_dim, num_edges_hidden)
        # Edge representation -> message
        self.edge_to_msg = nn.Linear(num_edges_hidden, out_features)
        # Node update
        self.node_update = nn.Sequential(
            nn.Linear(in_features + out_features, out_features),
            nn.LayerNorm(out_features)
        )
        # Dropout on edge messages
        self.edge_dropout = nn.Dropout(0.1)

    def forward(self, x, edge_index, edge_attr=None):
        """
        Args:
            x: (num_nodes, in_features) node features
            edge_index: (2, num_edges) edge indices
            edge_attr: (num_edges, edge_dim) edge features (optional)
        Returns:
            out: (num_nodes, out_features) updated node features
        """
        # Message passing
        messages = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # Update node features
        combined = torch.cat([x, messages], dim=-1)
        out = self.node_update(combined)
        return out

    def message(self, x_i, x_j, edge_attr):
        """
        Build messages: endpoint features of each edge -> KAN activation -> message.
        x_i: target-node features, x_j: source-node features (PyG convention)
        """
        if edge_attr is not None:
            # With edge features
            edge_input = torch.cat([edge_attr, x_i, x_j], dim=-1)
        else:
            # Node features only
            edge_input = torch.cat([x_i, x_j], dim=-1)
        # KAN activation
        kan_act = self.edge_kan(edge_input)  # (num_edges, num_edges_hidden)
        # Convert to messages
        messages = self.edge_to_msg(kan_act)  # (num_edges, out_features)
        # Dropout
        messages = self.edge_dropout(messages)
        return messages


class GraphKAN(nn.Module):
    """
    Graph KAN model.
    """

    def __init__(self, in_features, hidden_features, out_features,
                 num_layers=3, dropout=0.5):
        super().__init__()
        self.num_layers = num_layers
        # Input layer
        self.input_conv = GraphKANConv(in_features, hidden_features)
        # Hidden layers (num_layers counts the input and output layers, so >= 2)
        self.hidden_convs = nn.ModuleList([
            GraphKANConv(hidden_features, hidden_features)
            for _ in range(num_layers - 2)
        ])
        # Output layer
        self.output_conv = GraphKANConv(hidden_features, out_features)
        self.dropout = dropout
        self.activation = nn.ReLU()

    def forward(self, x, edge_index):
        """
        Args:
            x: (num_nodes, in_features)
            edge_index: (2, num_edges)
        Returns:
            out: (num_nodes, out_features)
        """
        # Input layer
        x = self.activation(self.input_conv(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Hidden layers
        for conv in self.hidden_convs:
            x = self.activation(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)
        # Output layer
        x = self.output_conv(x, edge_index)
        return x


class GraphKANWithEdgeFeatures(nn.Module):
    """
    Graph KAN with edge features,
    for graphs whose edges carry rich attributes.
    """

    def __init__(self, node_features, edge_features, hidden_features,
                 out_features, num_layers=3):
        super().__init__()
        # Node feature encoder
        self.node_encoder = nn.Linear(node_features, hidden_features)
        # Edge feature encoder
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features)
        )
        # Graph KAN layers; encoded edge features have hidden_features dims
        self.convs = nn.ModuleList([
            GraphKANConv(hidden_features, hidden_features,
                         edge_dim=hidden_features)
            for _ in range(num_layers)
        ])
        # Classification / regression head
        self.head = nn.Linear(hidden_features, out_features)

    def forward(self, x, edge_index, edge_attr):
        """
        Args:
            x: (num_nodes, node_features)
            edge_index: (2, num_edges)
            edge_attr: (num_edges, edge_features)
        Returns:
            out: (num_nodes, out_features)
        """
        # Encode features
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        # Graph KAN layers
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_attr))
        # Output
        out = self.head(x)
        return out
```
4.3 Advantages of Graph KAN
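| Feature | GCN/GAT | Graph KAN |
|---|---|---|
| Edge weighting | Static / simple attention | Learnable nonlinear functions |
| Edge features | Simple concatenation | KAN activation |
| Expressiveness | Medium | High |
| Computational cost | | |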
5. Other Notable Variants
5.1 Ada-KAN: Adaptive Variant
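Reading the code below, the layer computes the following. Here $g$ is the two-layer selector MLP, $f_t$ ranges over the $T$ chosen activation types, $B$ is the batch size, and $\bar{x}$ is the batch mean of the input. Because the selector sees only $\bar{x}$, a single mixture $s$ is shared by the whole batch.

$$\bar{x} = \frac{1}{B}\sum_{b=1}^{B} x_b, \qquad s = \operatorname{softmax}\big(g(\bar{x})\big) \in \mathbb{R}^{T}$$

$$y_{b,o} = \sum_{t=1}^{T} s_t \sum_{i=1}^{d_{in}} W_{o,i,t}\, f_t(x_{b,i})$$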
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaKANLayer(nn.Module):
    """
    Adaptive KAN layer: mixes activation-function types,
    with mixing weights predicted from the input.
    """

    def __init__(self, in_features, out_features,
                 activation_types=('linear', 'relu', 'gelu', 'sin', 'exp')):
        super().__init__()
        # Tuple default: immutable, safe to share across calls
        self.activation_types = list(activation_types)
        self.num_types = len(self.activation_types)
        # Per-edge weight for each activation type
        self.type_weights = nn.Parameter(
            torch.randn(out_features, in_features, self.num_types)
        )
        # Activation library (parameter-free functions)
        self.activation_funcs = {
            'linear': lambda x: x,
            'relu': F.relu,
            'gelu': F.gelu,
            'sin': torch.sin,
            'exp': torch.exp,  # can overflow for large inputs
            'log': lambda x: torch.log(x.clamp(min=1e-8)),
        }
        # Selector network: predicts mixing weights over activation types
        hidden = max(in_features // 2, 1)
        self.selector = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.num_types),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """
        Adaptively combine activation functions.
        x: (batch, in) -> (batch, out)
        """
        # Selection weights from the batch mean: one mixture for the whole batch
        selection = self.selector(x.mean(dim=0, keepdim=True))  # (1, num_types)
        # Evaluate every activation type
        outputs = []
        for i, act_name in enumerate(self.activation_types):
            act_out = self.activation_funcs[act_name](x)  # (batch, in)
            # Weight each edge (o, i) by its own coefficient
            weighted = torch.einsum('bi,oi->bo', act_out, self.type_weights[:, :, i])
            outputs.append(weighted)
        # Mix the activation types
        all_outputs = torch.stack(outputs, dim=-1)  # (batch, out, num_types)
        selection = selection.unsqueeze(1)  # (1, 1, num_types)
        out = (all_outputs * selection).sum(dim=-1)  # (batch, out)
        return out
```
5.2 RNN-KAN: Sequence Modeling
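The cell below follows the LSTM pattern, with the candidate update produced by two KAN layers instead of affine maps. Here $\sigma$ is the logistic sigmoid, $\odot$ the elementwise product, and $W, b$ the weights of the three gate layers:

$$z_t = [x_t; h_{t-1}], \quad i_t = \sigma(W_i z_t + b_i), \quad f_t = \sigma(W_f z_t + b_f), \quad o_t = \sigma(W_o z_t + b_o)$$

$$c_t = f_t \odot c_{t-1} + i_t \odot \tanh\big(\mathrm{KAN}_x(x_t) + \mathrm{KAN}_h(h_{t-1})\big), \qquad h_t = o_t \odot \tanh(c_t)$$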
```python
import torch
import torch.nn as nn

# KANLayer is not defined in this article. As an assumption, reuse the
# FastKANLayer from Section 3.2 (note: it clamps its inputs to [0, 1]).
KANLayer = FastKANLayer


class RNNKANCell(nn.Module):
    """
    RNN-KAN cell: an LSTM-style recurrent unit for sequence modeling
    whose candidate update is computed by KAN layers.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Input -> hidden KAN
        self.input_kan = KANLayer(input_size, hidden_size)
        # Hidden -> hidden KAN
        self.hidden_kan = KANLayer(hidden_size, hidden_size)
        # Gates
        self.input_gate = nn.Sequential(
            nn.Linear(input_size + hidden_size, hidden_size),
            nn.Sigmoid()
        )
        self.forget_gate = nn.Sequential(
            nn.Linear(input_size + hidden_size, hidden_size),
            nn.Sigmoid()
        )
        self.output_gate = nn.Sequential(
            nn.Linear(input_size + hidden_size, hidden_size),
            nn.Sigmoid()
        )

    def forward(self, x, h_prev, c_prev):
        """
        Args:
            x: (batch, input_size) current input
            h_prev: (batch, hidden_size) previous hidden state
            c_prev: (batch, hidden_size) previous cell state
        Returns:
            h, c: (batch, hidden_size) current hidden and cell states
        """
        combined = torch.cat([x, h_prev], dim=-1)
        # Gates
        i = self.input_gate(combined)
        f = self.forget_gate(combined)
        o = self.output_gate(combined)
        # KAN activations
        kan_in = self.input_kan(x)
        kan_hidden = self.hidden_kan(h_prev)
        # Cell-state update: the forget gate acts on the previous cell state
        c = f * c_prev + i * torch.tanh(kan_in + kan_hidden)
        # Hidden-state output
        h = o * torch.tanh(c)
        return h, c


class RNNKAN(nn.Module):
    """
    RNN-KAN: a stacked recurrent network of RNNKANCell layers.
    """

    def __init__(self, input_size, hidden_size, num_layers=2):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.cells = nn.ModuleList([
            RNNKANCell(
                input_size if i == 0 else hidden_size,
                hidden_size
            )
            for i in range(num_layers)
        ])

    def forward(self, x, state=None):
        """
        Args:
            x: (batch, seq_len, input_size)
            state: optional (h0, c0), each (num_layers, batch, hidden_size)
        Returns:
            outputs: (batch, seq_len, hidden_size)
            (hn, cn): each (num_layers, batch, hidden_size)
        """
        batch_size, seq_len, _ = x.shape
        # Initialize hidden and cell states
        if state is None:
            zeros = x.new_zeros(batch_size, self.hidden_size)
            hn = [zeros] * self.num_layers
            cn = [zeros] * self.num_layers
        else:
            h0, c0 = state
            hn = list(h0.unbind(0))
            cn = list(c0.unbind(0))
        outputs = []
        for t in range(seq_len):
            xt = x[:, t, :]
            for layer in range(self.num_layers):
                hn[layer], cn[layer] = self.cells[layer](xt, hn[layer], cn[layer])
                xt = hn[layer]
            # Output of the top layer at step t
            outputs.append(xt)
        outputs = torch.stack(outputs, dim=1)  # (batch, seq_len, hidden)
        hn = torch.stack(hn, dim=0)            # (num_layers, batch, hidden)
        cn = torch.stack(cn, dim=0)            # (num_layers, batch, hidden)
        return outputs, (hn, cn)
```
6. Comparison Summary
6.1 Computational Efficiency
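| Variant | Per-layer complexity | Memory | Relative speed |
|---|---|---|---|
| Original KAN | $O(d_{in} d_{out} G)$ | High | 1x |
| FastKAN | $O(d_{in} d_{out} K)$ | Medium | 2-3x |
| CKAN | $O(d_{in} d_{out} D)$ | Medium | 2-3x |
| Wav-KAN | $O(d_{in} d_{out} S)$ | High | 0.8-1.2x |
| Ada-KAN | $O(d_{in} d_{out} T)$ | High | 0.5-1x |

In the table above:

- $d_{in}$ and $d_{out}$ are the layer's input and output widths;
- $G$ is the grid size;
- $K$ is the number of RBFs;
- $D$ is the Chebyshev degree;
- $S$ is the number of wavelet scales;
- $T$ is the number of activation types.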
6.2 Expressiveness
| Variant | Nonlinear expressiveness | Multi-scale analysis | Interpretability |
|---|---|---|---|
| Original KAN | ★★★★ | ★★ | ★★★★ |
| FastKAN | ★★★ | ★★ | ★★★ |
| CKAN | ★★★★ | ★★ | ★★★★ |
| Wav-KAN | ★★★★ | ★★★★★ | ★★★★★ |
| Graph KAN | ★★★★ | ★★ | ★★★ |
| Ada-KAN | ★★★★★ | ★★★ | ★★★ |
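6.3 Use Cases
| Scenario | Recommended variant |
|---|---|
| Scientific formula discovery | KAN 2.0, CKAN |
| Signal/image processing | Wav-KAN |
| Graph learning | Graph KAN |
| Large-scale training | FastKAN |
| Sequence modeling | RNN-KAN |
| High expressiveness required | Ada-KAN |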
7. Practical Recommendations
7.1 Choosing a Variant
```python
def select_kan_variant(task_type, data_scale, interpretability_required):
    """
    Pick a suitable KAN variant (heuristic).
    Args:
        task_type: e.g. 'regression', 'classification', 'graph', 'sequence';
                   names containing 'signal' or 'image' are treated as signal/image tasks
        data_scale: 'small', 'medium', 'large'
        interpretability_required: bool
    Returns:
        variant: str
    """
    if task_type == 'graph':
        return 'GraphKAN'
    if task_type == 'sequence':
        return 'RNN-KAN'
    if interpretability_required:
        if 'signal' in task_type or 'image' in task_type:
            return 'Wav-KAN'
        # Small data: the original spline KAN is affordable
        if data_scale == 'small':
            return 'KAN (original)'
        return 'CKAN'
    # Large data, or no interpretability requirement: default to FastKAN
    return 'FastKAN'


# Example
variant = select_kan_variant(
    task_type='scientific_regression',
    data_scale='medium',
    interpretability_required=True
)
print(f"Recommended variant: {variant}")  # CKAN
```
7.2 Hyperparameter Settings
```python
# Starting configurations; keys match the names returned by select_kan_variant
KAN_VARIANT_CONFIGS = {
    'KAN (original)': {
        'grid_size': 5,
        'spline_order': 3,
        'lr': 1e-3,
        'epochs': 200
    },
    'FastKAN': {
        'num_grids': 8,
        'smooth_threshold': 0.1,
        'lr': 1e-3,
        'epochs': 150
    },
    'CKAN': {
        'degree': 5,
        'lr': 5e-4,
        'epochs': 200
    },
    'Wav-KAN': {
        'num_scales': 4,
        'wavelet_type': 'db4',
        'lr': 1e-3,
        'epochs': 180
    },
    'GraphKAN': {
        'num_edges_hidden': 5,
        'aggr': 'add',
        'lr': 1e-3,
        'epochs': 300
    }
}
```
8. Future Research Directions
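8.1 Current Limitations
- Computational efficiency: every variant still faces high computational cost
- Scalability: on large datasets they are less efficient than Transformers
- Theory: the theory of their expressive power is still incomplete
- Hardware support: there is no dedicated hardware acceleration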
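8.2 Promising Directions
- Hardware-aware design: optimizing for GPUs/TPUs
- Sparsification: combining KANs with pruning
- Hybrid architectures: e.g., KAN-Transformer hybrids
- Automated design: AutoML for KAN architecture search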
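References
Related Reading
- Modern MLP Architectures: an introduction to KAN basics
- KAN 2.0 for Scientific Discovery: an in-depth look at KAN 2.0
- GCN Explained: graph neural network fundamentals
- Wavelet Transform: fundamentals of time-frequency analysis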
Footnotes
1. Liu, Z., et al. (2024). "KAN: Kolmogorov-Arnold Networks". arXiv:2404.19756.
2. Wav-KAN authors. (2024). "Wav-KAN: Wavelet Kolmogorov-Arnold Networks". arXiv:2405.12832.
3. FastKAN authors. (2024). "FastKAN: Efficient KAN via Radial Basis Functions".
4. Graph KAN authors. (2024). "Graph Kolmogorov-Arnold Networks".