Overview

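The KAN (Kolmogorov-Arnold Networks) paper was published in April 2024. Since then, the research community has produced many variants that address limitations of the original KAN or extend it to new settings [1]. This article surveys the main variants: Wav-KAN (wavelet activations), Graph KAN (graph-structured data), FastKAN (computational efficiency), CKAN (Chebyshev polynomials), and several others.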


1. Taxonomy of variants

1.1 Classification by design goal

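| Category | Variants | Core goal |
|---|---|---|
| Interpretability | Wav-KAN, Spline-KAN | Clearer visualization of the activation functions |
| Efficiency | FastKAN, CKAN, GPU-KAN | Lower computational complexity |
| Structural extension | Graph KAN, RNN-KAN | Adapt KAN to graph-structured data and recurrent models |
| Accuracy | Kolmogorov-GAN, Ada-KAN | Better fitting capacity |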

1.2 Timeline of variants

2024.04  KAN (MIT)
2024.05  Wav-KAN (wavelet activations)
2024.06  FastKAN (computational efficiency)
2024.07  Graph KAN (graph data)
2024.08  KAN 2.0 (scientific applications)
2024.09  CKAN (Chebyshev polynomials)
2024.10  Ada-KAN (adaptive)
2024.11  RNN-KAN (sequence modeling)

2. Wav-KAN: Wavelet activation functions

2.1 Core idea

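Wav-KAN (Wavelet Kolmogorov-Arnold Networks) uses wavelet functions as activation functions instead of B-splines. The goal is better interpretability and support for multi-scale analysis [2].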

Why wavelets?

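  1. Multi-scale analysis: wavelets are inherently multi-scale, because each dilation of the mother wavelet probes a different resolution
  2. Sparse representation: for piecewise-smooth signals, most wavelet coefficients are close to zero
  3. Time-frequency localization: wavelets are localized in both time (or space) and frequency
  4. Interpretability: each wavelet basis function has a clear meaning, namely a particular scale at a particular position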

2.2 Wavelet basics

Haar wavelet

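The Haar wavelet is the simplest orthogonal wavelet:

$$
\psi(t) = \begin{cases} 1, & 0 \le t < \tfrac{1}{2} \\ -1, & \tfrac{1}{2} \le t < 1 \\ 0, & \text{otherwise} \end{cases}
$$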
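The sparsity claim can be made concrete with one level of the orthonormal Haar transform. This level splits a signal into pairwise averages and pairwise differences. The differences are the discrete counterpart of inner products with shifted copies of ψ. For a piecewise-constant signal, only the pair that straddles the jump produces a nonzero detail coefficient, and the transform preserves energy. The sketch below is minimal and uses pure Python. The function name `haar_dwt_level` is illustrative.

```python
import math

def haar_dwt_level(signal):
    """One level of the orthonormal Haar DWT: returns (approximation, detail)."""
    if len(signal) % 2:
        raise ValueError("signal length must be even")
    s = 1 / math.sqrt(2)
    pairs = list(zip(signal[0::2], signal[1::2]))
    approx = [s * (a + b) for a, b in pairs]   # scaled pairwise sums (low-pass)
    detail = [s * (a - b) for a, b in pairs]   # scaled pairwise differences (Haar wavelet)
    return approx, detail

# Piecewise-constant signal whose single jump falls inside the third pair
signal = [1.0] * 5 + [3.0] * 3
approx, detail = haar_dwt_level(signal)
print("approx:", approx)
print("detail:", detail)
```

Applying `haar_dwt_level` again to `approx` yields the next, coarser scale. This recursion is what gives a wavelet decomposition its multi-scale structure.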

Daubechies wavelets

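Daubechies wavelets are smoother than the Haar wavelet and are also compactly supported. The Daubechies wavelet with $N$ vanishing moments is denoted $\text{db}N$, and its filters have $2N$ taps. For example, the 4-tap D4 filter is db2.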
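The WaveletBasis code below hard-codes four coefficients for `'db4'`: 0.4830, 0.8365, 0.2241 and -0.1294. These are the Daubechies D4 scaling (low-pass) filter. The pure-Python check below confirms the filter's defining properties:

- the coefficients sum to √2;
- the filter has unit energy;
- the filter is orthogonal to its shift by two;
- the derived wavelet (high-pass) filter has exactly two vanishing moments.

Here, "two vanishing moments" means the zeroth and first moments vanish and the second does not. The closed form used for the coefficients is the standard one. The names `h`, `g` and `moment` are illustrative.

```python
import math

# Daubechies D4 scaling filter in closed form (rounds to 0.4830, 0.8365, 0.2241, -0.1294)
r3 = math.sqrt(3)
h = [c / (4 * math.sqrt(2)) for c in (1 + r3, 3 + r3, 3 - r3, 1 - r3)]

# Quadrature-mirror wavelet (high-pass) filter: g_k = (-1)^k * h_{N-1-k}
g = [(-1) ** k * h[len(h) - 1 - k] for k in range(len(h))]

def moment(filt, p):
    """p-th discrete moment: sum_k k^p * filt[k]."""
    return sum((k ** p) * c for k, c in enumerate(filt))

print([round(c, 4) for c in h])                  # low-pass coefficients
print(sum(h), sum(c * c for c in h))             # close to sqrt(2) and 1
print([moment(g, p) for p in range(3)])          # first two close to 0, third is not
```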

import torch
import torch.nn as nn
import numpy as np
 
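class WaveletBasis(nn.Module):
    """
    Wavelet basis (simplified)
    
    Applies a fixed wavelet filter along the feature dimension, with the
    filter taps spaced `scale` features apart (a trous-style dilation)
    """
    def __init__(self, wavelet_type='db4', scale=1):
        super().__init__()
        self.wavelet_type = wavelet_type
        self.scale = scale
        
        # Precompute the filter coefficients
        self.register_buffer('wavelet_coeffs', self._get_wavelet_coeffs())
    
    def _get_wavelet_coeffs(self):
        """Return the filter coefficients of the selected wavelet"""
        if self.wavelet_type == 'haar':
            # Haar wavelet (unnormalized difference filter)
            return torch.tensor([1.0, -1.0])
        
        elif self.wavelet_type == 'db4':
            # Daubechies D4 scaling filter: 4 taps, 2 vanishing moments
            # (the same filter is called 'db2' in PyWavelets)
            return torch.tensor([0.4830, 0.8365, 0.2241, -0.1294])
        
        elif self.wavelet_type == 'sym4':
            # Symlet-4 decomposition low-pass filter: 8 taps, 4 vanishing moments
            return torch.tensor([-0.0758, -0.0296, 0.4976, 0.8037,
                                 0.2979, -0.0992, -0.0126, 0.0322])
        
        else:
            raise ValueError(f"Unknown wavelet type: {self.wavelet_type}")
    
    def forward(self, x):
        """
        Filter x of shape (batch, in_features) along the feature dimension
        """
        coeffs = self.wavelet_coeffs
        
        # Simplified circular filtering: out_k = sum_i c_i * x_{k - i*scale}
        output = torch.zeros_like(x)
        
        for i, c in enumerate(coeffs):
            shifted = torch.roll(x, shifts=i * self.scale, dims=-1)
            output = output + c * shifted
        
        return output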
 
 
class WavKANLayer(nn.Module):
    """
    Wav-KAN layer
    
    A KAN layer that uses wavelet activation functions
    """
    def __init__(self, in_features, out_features, 
                 num_scales=4, wavelet_type='db4'):
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.num_scales = num_scales
        self.wavelet_type = wavelet_type
        
        # One WaveletBasis per scale (scale index 1..num_scales)
        self.wavelet_bases = nn.ModuleList([
            WaveletBasis(wavelet_type, scale+1) 
            for scale in range(num_scales)
        ])
        
        # Learnable wavelet (detail) coefficients: one per edge and scale
        self.wavelet_coeffs = nn.Parameter(
            torch.randn(out_features, in_features, num_scales)
        )
        
        # Scaling-function (approximation) coefficients: one per edge
        self.scale_coeffs = nn.Parameter(
            torch.randn(out_features, in_features)
        )
        
        # Initialization
        self._init_weights()
    
    def _init_weights(self):
        nn.init.kaiming_normal_(self.wavelet_coeffs)
        nn.init.normal_(self.scale_coeffs, std=0.1)
    
    def forward(self, x):
        """
        Args:
            x: (batch, in_features)
        Returns:
            y: (batch, out_features)
        """
        # Clamp inputs to [0, 1] (assumes they were normalized beforehand)
        x = x.clamp(0, 1)
        
        # Multi-scale wavelet activations
        wavelet_outputs = []
        for scale in range(self.num_scales):
            # Wavelet response at this scale
            psi = self.wavelet_bases[scale](x)  # (batch, in)
            
            # Weight every edge, then sum over inputs
            weighted = psi.unsqueeze(1) * self.wavelet_coeffs[:, :, scale]  # (batch, out, in)
            wavelet_outputs.append(weighted.sum(dim=-1))  # (batch, out)
        
        # Scaling-function (low-frequency) part
        scale_output = torch.einsum('bi,oi->bo', x, self.scale_coeffs)
        
        # Combine the outputs of all scales
        wavelet_stack = torch.stack(wavelet_outputs, dim=-1)  # (batch, out, scales)
        wavelet_sum = wavelet_stack.sum(dim=-1)  # (batch, out)
        
        return scale_output + wavelet_sum
 
 
class WavKAN(nn.Module):
    """
    Complete Wav-KAN model
    """
    def __init__(self, layer_dims, num_scales=4, wavelet_type='db4'):
        super().__init__()
        
        self.layers = nn.ModuleList()
        for i in range(len(layer_dims) - 1):
            self.layers.append(
                WavKANLayer(layer_dims[i], layer_dims[i+1], 
                           num_scales, wavelet_type)
            )
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
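The WaveletBasis above applies a fixed filter across neighboring input features, which is a simplification. The Wav-KAN paper [2] instead builds each edge activation from a mother wavelet evaluated on a scaled and translated input, roughly φ(x) = w · ψ((x − τ)/s), where the scale s and the translation τ are learnable.

The dependency-free sketch below follows that parameterization. It assumes the Mexican-hat (Ricker) wavelet as ψ, which is one of several mother wavelets the paper compares. The function names and the nested-list parameter layout are illustrative only.

```python
import math

def mexican_hat(t):
    """Mexican-hat (Ricker) mother wavelet, normalized to unit L2 norm."""
    return 2 / (math.sqrt(3) * math.pi ** 0.25) * (1 - t * t) * math.exp(-t * t / 2)

def wavelet_edge(x, weight, scale, translation):
    """Edge activation phi(x) = weight * psi((x - translation) / scale)."""
    return weight * mexican_hat((x - translation) / scale)

def wav_kan_layer(xs, weights, scales, translations):
    """One layer: y_o = sum_i phi_{o,i}(x_i); parameters are nested lists [o][i]."""
    return [
        sum(wavelet_edge(x, w, s, t) for x, w, s, t in zip(xs, w_row, s_row, t_row))
        for w_row, s_row, t_row in zip(weights, scales, translations)
    ]

# 2 inputs -> 2 outputs, with a separate (weight, scale, translation) per edge
y = wav_kan_layer(
    [0.2, -0.5],
    weights=[[1.0, 0.5], [-1.0, 2.0]],
    scales=[[1.0, 0.5], [2.0, 1.0]],
    translations=[[0.0, 0.0], [0.5, -0.5]],
)
print(y)
```

Because each edge has its own scale and translation, each learned activation can be read off directly as "a bump of width s centered at τ". This is the source of the interpretability argument for wavelets.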

2.3 Wav-KAN vs B-Spline KAN

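| Feature | B-Spline KAN | Wav-KAN |
|---|---|---|
| Basis function | B-spline | Wavelet |
| Multi-scale analysis |  |  |
| Sparse representation |  |  |
| Computational complexity |  |  |
| Interpretability |  | High (wavelets have a physical meaning) |
| Typical use | General | Signals/images |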

3. FastKAN: An efficient variant

3.1 Motivation and challenges

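The main computational bottleneck of the original KAN is evaluating its B-spline activation functions. B-splines are defined by a recursion (Cox-de Boor), and they must be evaluated on every edge. FastKAN addresses this by replacing them with a cheaper approximation [3].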
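The approximation FastKAN relies on can be checked numerically: a cubic B-spline basis function is close in shape to a Gaussian. The sketch below assumes unit knot spacing and matches the Gaussian's variance to the B-spline's. That matching rule is our choice for illustration and not necessarily the width FastKAN uses.

```python
import math

def cubic_bspline(x):
    """Centered cardinal cubic B-spline (unit knot spacing, support [-2, 2])."""
    ax = abs(x)
    if ax < 1:
        return 2 / 3 - ax ** 2 + ax ** 3 / 2
    if ax < 2:
        return (2 - ax) ** 3 / 6
    return 0.0

def gaussian(x, var):
    """Gaussian density with mean 0 and the given variance."""
    return math.exp(-x * x / (2 * var)) / math.sqrt(2 * math.pi * var)

# The cubic B-spline is the convolution of four unit boxes, each with
# variance 1/12, so its variance is 4 * (1/12) = 1/3.
VAR = 1 / 3
xs = [i / 100 for i in range(-300, 301)]
max_err = max(abs(cubic_bspline(x) - gaussian(x, VAR)) for x in xs)
print(f"peak B-spline value:       {cubic_bspline(0.0):.4f}")
print(f"max |B-spline - Gaussian|: {max_err:.4f}")
```

A Gaussian needs a single `exp` per evaluation, with no recursion over spline orders. That is why an RBF layer can be substantially cheaper than a B-spline layer with a similar number of basis functions.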

3.2 RBF network approximation

FastKAN approximates the B-splines with **radial basis functions (RBFs)**:

class FastKANLayer(nn.Module):
    """
    FastKAN layer
    
    Replaces each per-edge B-spline with a weighted sum of Gaussian RBFs
    """
    def __init__(self, in_features, out_features, 
                 num_grids=5, smooth_threshold=0.1):
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.num_grids = num_grids
        
        # RBF centers: start on a uniform grid over [0, 1], one copy per input feature
        grid = torch.linspace(0, 1, num_grids)
        self.rbf_centers = nn.Parameter(grid.repeat(in_features, 1))  # (in, num_grids)
        
        # RBF widths: start at the grid spacing
        h = 1.0 / max(num_grids - 1, 1)
        self.rbf_scales = nn.Parameter(torch.full((in_features, num_grids), h))
        
        # Base-activation weights
        self.base_weight = nn.Parameter(torch.randn(out_features, in_features))
        
        # Combination weights: one coefficient per (output, input, RBF)
        self.combine_weight = nn.Parameter(torch.randn(out_features, in_features, num_grids))
        
        # Smoothing threshold (reserved for adaptive schemes; not used in forward)
        self.smooth_threshold = smooth_threshold
    
    def rbf_gaussian(self, x):
        """
        Gaussian RBFs, evaluated separately for every input feature
        
        φ_{c,s}(x_i) = exp(-(x_i - c)² / (2s²))
        """
        # x: (batch, in) -> (batch, in, 1); centers and scales: (in, num_grids)
        diff = x.unsqueeze(-1) - self.rbf_centers
        
        # Gaussian RBF per (sample, input, center)
        rbf = torch.exp(-diff ** 2 / (2 * self.rbf_scales ** 2 + 1e-8))  # (batch, in, num_grids)
        
        return rbf
    
    def forward(self, x):
        """
        Args:
            x: (batch, in_features), values in [0, 1]
        Returns:
            y: (batch, out_features)
        """
        x = x.clamp(0, 1)
        
        # RBF activations
        rbf_act = self.rbf_gaussian(x)  # (batch, in, num_grids)
        
        # y_o = sum_i sum_g w[o, i, g] * φ_g(x_i)
        combine = torch.einsum('big,oig->bo', rbf_act, self.combine_weight)
        
        # Base activation: SiLU applied to the input (not to the weights)
        base = torch.einsum('bi,oi->bo', torch.nn.functional.silu(x), self.base_weight)
        
        return combine + base
 
 
class FastKAN(nn.Module):
    """
    Fast-KAN
    
    完整实现
    """
    def __init__(self, layer_dims, num_grids=5):
        super().__init__()
        
        self.layers = nn.ModuleList()
        for i in range(len(layer_dims) - 1):
            self.layers.append(
                FastKANLayer(layer_dims[i], layer_dims[i+1], num_grids)
            )
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

3.3 Chebyshev KAN (CKAN)

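CKAN uses Chebyshev polynomials as its basis functions. On [-1, 1] these polynomials are bounded by 1 and orthogonal with respect to the weight $1/\sqrt{1-x^2}$, which gives them good numerical behavior: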

class ChebyshevKANLayer(nn.Module):
    """
    Chebyshev KAN layer
    
    Uses Chebyshev polynomials as basis functions
    """
    def __init__(self, in_features, out_features, degree=5):
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.degree = degree
        
        # Chebyshev coefficients (learnable): degree + 1 per edge
        self.coeffs = nn.Parameter(
            torch.randn(out_features, in_features, degree + 1)
        )
        
        # Base weights
        self.base_weight = nn.Parameter(torch.randn(out_features, in_features))
        
        # Initialization
        self._init_coeffs()
    
    def _init_coeffs(self):
        # Start from small values
        nn.init.normal_(self.coeffs, std=0.01)
        nn.init.normal_(self.base_weight, std=0.1)
    
    def chebyshev_polynomials(self, x):
        """
        Compute T_0..T_degree for every input feature
        
        T_0(x) = 1
        T_1(x) = x
        T_{n+1}(x) = 2xT_n(x) - T_{n-1}(x)
        """
        # x: (batch, in_features), values in [-1, 1]
        T = [torch.ones_like(x)]  # T_0 = 1
        
        if self.degree >= 1:
            T.append(x)  # T_1 = x
        
        # Recurrence: T_{n+1} = 2xT_n - T_{n-1}
        for n in range(1, self.degree):
            T.append(2 * x * T[n] - T[n - 1])
        
        return torch.stack(T, dim=-1)  # (batch, in_features, degree + 1)
    
    def forward(self, x):
        """
        Args:
            x: (batch, in_features), any real values
        Returns:
            y: (batch, out_features)
        """
        # Squash inputs into [-1, 1], the domain of the Chebyshev polynomials
        x_cheb = torch.tanh(x)
        
        # Chebyshev basis values
        T = self.chebyshev_polynomials(x_cheb)  # (batch, in, degree + 1)
        
        # Weighted sum over inputs and degrees
        cheb_out = torch.einsum('bid,oid->bo', T, self.coeffs)
        
        # Base activation: SiLU applied to the input, then linear weights
        base = torch.einsum('bi,oi->bo', torch.nn.functional.silu(x), self.base_weight)
        
        return cheb_out + base
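The boundedness claim follows from the identity $T_n(\cos\theta) = \cos(n\theta)$, which gives $T_n(x) = \cos(n \arccos x)$ on [-1, 1]. The pure-Python check below confirms two things about the three-term recurrence used in `chebyshev_polynomials`: it matches this closed form, and it never leaves [-1, 1]. The helper name `chebyshev_T` is illustrative.

```python
import math

def chebyshev_T(n_max, x):
    """T_0(x)..T_{n_max}(x) via the three-term recurrence."""
    T = [1.0, x]
    for n in range(1, n_max):
        T.append(2 * x * T[n] - T[n - 1])
    return T[: n_max + 1]

max_dev = 0.0   # largest |recurrence - cos(n * arccos x)|
max_abs = 0.0   # largest |T_n(x)| seen on the grid
for i in range(201):
    x = -1 + 2 * i / 200
    for n, t in enumerate(chebyshev_T(5, x)):
        max_dev = max(max_dev, abs(t - math.cos(n * math.acos(x))))
        max_abs = max(max_abs, abs(t))
print(max_dev, max_abs)
```

Monomials $x^n$ on [0, 1] become nearly collinear as n grows, whereas Chebyshev values stay well spread. As a result, the learned coefficients remain well conditioned even at higher degrees.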

4. Graph KAN: Graph-structured data

4.1 Core idea

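Graph KAN extends the KAN idea to graph-structured data, using learnable KAN functions for node representation learning in graph neural networks [4].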

4.2 Design motivation

Conventional GNNs use:

  • MLPs with fixed nonlinear activations
  • GAT-style attention-weighted neighborhood aggregation

Graph KAN introduces:

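  • Learnable edge activations: the message on each edge is produced by a learnable KAN function of the edge's endpoint features (and, optionally, its edge features). The KAN parameters are shared across edges, so edges differ through their inputs rather than through separate functions.
  • Structure-aware representations: the graph topology determines which node pairs the KAN activation is applied to.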
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

# Assumption: KANLayer stands for any KAN layer with signature
# (in_features, out_features); here it reuses ChebyshevKANLayer from
# Section 3.3, which must be defined in the same module
KANLayer = ChebyshevKANLayer
 
 
class GraphKANConv(MessagePassing):
    """
    Graph KAN convolution layer
    
    Embeds a KAN activation in the message-passing framework
    """
    def __init__(self, in_features, out_features, 
                 num_edges_hidden=5, aggr='add', edge_dim=0):
        super().__init__(aggr=aggr)
        
        self.in_features = in_features
        self.out_features = out_features
        self.num_edges_hidden = num_edges_hidden
        
        # Edge transform (KAN activation) on [edge_attr, x_i, x_j];
        # edge_dim is the width of edge_attr (0 when there are no edge features)
        self.edge_kan = KANLayer(edge_dim + in_features * 2, num_edges_hidden)
        
        # Map the edge representation to a message
        self.edge_to_msg = nn.Linear(num_edges_hidden, out_features)
        
        # Node update
        self.node_update = nn.Sequential(
            nn.Linear(in_features + out_features, out_features),
            nn.LayerNorm(out_features)
        )
        
        # Dropout on messages
        self.edge_dropout = nn.Dropout(0.1)
    
    def forward(self, x, edge_index, edge_attr=None):
        """
        Args:
            x: (num_nodes, in_features) node features
            edge_index: (2, num_edges) edge indices
            edge_attr: (num_edges, edge_dim) edge features (optional)
        Returns:
            out: (num_nodes, out_features) updated node features
        """
        # Message passing
        messages = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        
        # Update node features
        combined = torch.cat([x, messages], dim=-1)
        out = self.node_update(combined)
        
        return out
    
    def message(self, x_i, x_j, edge_attr):
        """
        Build messages
        
        Edge endpoint features -> KAN activation -> message
        """
        # Concatenate target (x_i) and source (x_j) node features
        if edge_attr is not None:
            # With edge features (requires edge_dim == edge_attr.size(-1))
            edge_input = torch.cat([edge_attr, x_i, x_j], dim=-1)
        else:
            # Without edge features: node features only
            edge_input = torch.cat([x_i, x_j], dim=-1)
        
        # KAN activation
        kan_act = self.edge_kan(edge_input)  # (num_edges, num_edges_hidden)
        
        # Convert to messages
        messages = self.edge_to_msg(kan_act)  # (num_edges, out_features)
        
        # Dropout
        messages = self.edge_dropout(messages)
        
        return messages
 
 
class GraphKAN(nn.Module):
    """
    Graph KAN model
    """
    def __init__(self, in_features, hidden_features, out_features, 
                 num_layers=3, dropout=0.5):
        super().__init__()
        
        self.num_layers = num_layers
        
        # Input layer
        self.input_conv = GraphKANConv(in_features, hidden_features)
        
        # Hidden layers
        self.hidden_convs = nn.ModuleList([
            GraphKANConv(hidden_features, hidden_features)
            for _ in range(num_layers - 2)
        ])
        
        # Output layer
        self.output_conv = GraphKANConv(hidden_features, out_features)
        
        self.dropout = dropout
        self.activation = nn.ReLU()
    
    def forward(self, x, edge_index):
        """
        Args:
            x: (num_nodes, in_features)
            edge_index: (2, num_edges)
        Returns:
            out: (num_nodes, out_features)
        """
        # Input layer
        x = self.activation(self.input_conv(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Hidden layers
        for conv in self.hidden_convs:
            x = self.activation(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Output layer
        x = self.output_conv(x, edge_index)
        
        return x
 
 
class GraphKANWithEdgeFeatures(nn.Module):
    """
    Graph KAN with edge features
    
    For graphs whose edges carry rich attributes
    """
    def __init__(self, node_features, edge_features, hidden_features, 
                 out_features, num_layers=3):
        super().__init__()
        
        # Node feature encoder
        self.node_encoder = nn.Linear(node_features, hidden_features)
        
        # Edge feature encoder
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features)
        )
        
        # Graph KAN layers; edge_dim matches the encoded edge features
        self.convs = nn.ModuleList([
            GraphKANConv(hidden_features, hidden_features, edge_dim=hidden_features)
            for _ in range(num_layers)
        ])
        
        # Classification/regression head
        self.head = nn.Linear(hidden_features, out_features)
    
    def forward(self, x, edge_index, edge_attr):
        """
        Args:
            x: (num_nodes, node_features)
            edge_index: (2, num_edges)
            edge_attr: (num_edges, edge_features)
        Returns:
            out: (num_nodes, out_features)
        """
        # Encode features
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph KAN layers
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_attr))
        
        # Output
        out = self.head(x)
        
        return out
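Stripped of PyTorch Geometric, the message-passing step in GraphKANConv has two parts:

1. Compute a message from each edge's endpoint features with a learnable KAN function.
2. Sum the messages at each target node.

The dependency-free sketch below makes several assumptions:

- node features are scalars;
- the KAN message is the sum of two univariate Chebyshev expansions (one per endpoint, applied after `tanh`);
- the coefficients are shared by all edges.

All names are illustrative. The edge list follows PyG's default source-to-target convention: `(j, i)` means node j sends to node i.

```python
import math

def cheb_basis(x, degree):
    """Chebyshev basis T_0..T_degree evaluated at tanh(x), which lies in (-1, 1)."""
    t = math.tanh(x)
    T = [1.0, t]
    for n in range(1, degree):
        T.append(2 * t * T[n] - T[n - 1])
    return T[: degree + 1]

def kan_edge_message(x_i, x_j, coeffs_i, coeffs_j):
    """KAN-style message: one learnable univariate function per endpoint, summed."""
    deg = len(coeffs_i) - 1
    phi_i = sum(c * b for c, b in zip(coeffs_i, cheb_basis(x_i, deg)))
    phi_j = sum(c * b for c, b in zip(coeffs_j, cheb_basis(x_j, deg)))
    return phi_i + phi_j

def aggregate(x, edges, coeffs_i, coeffs_j):
    """Sum-aggregate messages over directed edges (j, i): j sends to i."""
    out = [0.0] * len(x)
    for j, i in edges:
        out[i] += kan_edge_message(x[i], x[j], coeffs_i, coeffs_j)
    return out

x = [0.5, -1.0, 2.0]
edges = [(0, 1), (2, 1), (1, 0)]   # node 2 receives no messages
coeffs_i = [0.1, 0.0, 0.0]         # constant function of the target feature
coeffs_j = [0.0, 1.0, 0.0]         # tanh of the source feature
print(aggregate(x, edges, coeffs_i, coeffs_j))
```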

4.3 Graph KAN 的优势

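| Feature | GCN/GAT | Graph KAN |
|---|---|---|
| Edge weights | Static / simple attention | Learnable nonlinear functions |
| Edge-feature handling | Simple concatenation | KAN activation |
| Expressiveness |  |  |
| Computational complexity |  |  |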

5. Other notable variants

5.1 Ada-KAN: Adaptive variant

class AdaKANLayer(nn.Module):
    """
    Adaptive KAN layer
    
    Mixes a library of fixed activation functions, with mixing weights
    predicted from the input
    """
    def __init__(self, in_features, out_features, 
                 activation_types=('linear', 'relu', 'gelu', 'sin', 'exp')):
        super().__init__()
        
        # A tuple default avoids a shared mutable default argument
        self.activation_types = list(activation_types)
        self.num_types = len(self.activation_types)
        
        # Per-edge weight for each activation type
        self.type_weights = nn.Parameter(
            torch.randn(out_features, in_features, self.num_types)
        )
        
        # Activation library ('exp' can overflow for large inputs)
        self.activation_funcs = {
            'linear': lambda x: x,
            'relu': nn.functional.relu,
            'gelu': nn.functional.gelu,
            'sin': torch.sin,
            'exp': torch.exp,
            'log': lambda x: torch.log(x.clamp(min=1e-8)),
        }
        
        # Selector network (hidden width at least 1)
        hidden = max(1, in_features // 2)
        self.selector = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.num_types),
            nn.Softmax(dim=-1)
        )
    
    def forward(self, x):
        """
        Adaptively weighted combination of activation functions
        
        The selection weights are computed from the batch mean, so every
        sample in a batch shares the same mixture
        """
        # Selection weights
        selection = self.selector(x.mean(dim=0, keepdim=True))  # (1, num_types)
        
        # Evaluate every activation type
        outputs = []
        for i, act_name in enumerate(self.activation_types):
            act_func = self.activation_funcs[act_name]
            act_out = act_func(x)  # (batch, in)
            
            # Weight each edge and sum over inputs
            weighted = torch.einsum('bi,oi->bo', 
                                   act_out, 
                                   self.type_weights[:, :, i])
            outputs.append(weighted)
        
        # Weighted sum over activation types
        all_outputs = torch.stack(outputs, dim=-1)  # (batch, out, num_types)
        selection = selection.unsqueeze(1)  # (1, 1, num_types)
        
        out = (all_outputs * selection).sum(dim=-1)  # (batch, out)
        
        return out
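For a single scalar edge, AdaKANLayer computes Σ_k p_k · w_k · f_k(x), where p = softmax(selector logits). The sketch below is dependency-free and uses a hypothetical three-function library; the names `ACTIVATIONS`, `softmax` and `ada_edge` are illustrative. It shows how the selector interpolates between activation types: uniform logits average them, while a dominant logit effectively picks one.

```python
import math

# Hypothetical activation library for one scalar edge (a subset of AdaKANLayer's)
ACTIVATIONS = [
    ('linear', lambda v: v),
    ('relu', lambda v: max(v, 0.0)),
    ('sin', math.sin),
]

def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ada_edge(x, type_weights, selector_logits):
    """sum_k softmax(logits)_k * w_k * f_k(x) for a single edge."""
    probs = softmax(selector_logits)
    return sum(p * w * f(x)
               for p, w, (_, f) in zip(probs, type_weights, ACTIVATIONS))

# Uniform selection vs. a selector that strongly prefers ReLU
print(ada_edge(-1.0, [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]))
print(ada_edge(-2.0, [1.0, 1.0, 1.0], [-50.0, 50.0, -50.0]))
```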

5.2 RNN-KAN: Sequence modeling

# Assumption: KANLayer stands for any KAN layer with signature
# (in_features, out_features), e.g. ChebyshevKANLayer from Section 3.3
KANLayer = ChebyshevKANLayer


class RNNKANCell(nn.Module):
    """
    RNN-KAN cell
    
    A recurrent unit for sequence modeling built from KAN layers
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # Input-to-hidden KAN
        self.input_kan = KANLayer(input_size, hidden_size)
        
        # Hidden-to-hidden KAN
        self.hidden_kan = KANLayer(hidden_size, hidden_size)
        
        # LSTM-style gates
        self.input_gate = nn.Sequential(
            nn.Linear(input_size + hidden_size, hidden_size),
            nn.Sigmoid()
        )
        self.forget_gate = nn.Sequential(
            nn.Linear(input_size + hidden_size, hidden_size),
            nn.Sigmoid()
        )
        self.output_gate = nn.Sequential(
            nn.Linear(input_size + hidden_size, hidden_size),
            nn.Sigmoid()
        )
    
    def forward(self, x, h_prev):
        """
        Args:
            x: (batch, input_size) input at the current step
            h_prev: (batch, hidden_size) hidden state from the previous step
        Returns:
            h: (batch, hidden_size) current hidden state
        """
        combined = torch.cat([x, h_prev], dim=-1)
        
        # Gates
        i = self.input_gate(combined)
        f = self.forget_gate(combined)
        o = self.output_gate(combined)
        
        # KAN activations
        kan_in = self.input_kan(x)
        kan_hidden = self.hidden_kan(h_prev)
        
        # State update: unlike an LSTM there is no separate cell state;
        # the previous hidden state h_prev plays that role
        c = f * h_prev + i * torch.tanh(kan_in + kan_hidden)
        
        # Hidden-state output
        h = o * torch.tanh(c)
        
        return h
 
 
class RNNKAN(nn.Module):
    """
    RNN KAN
    
    完整的循环神经网络
    """
    def __init__(self, input_size, hidden_size, num_layers=2):
        super().__init__()
        
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        self.cells = nn.ModuleList([
            RNNKANCell(
                input_size if i == 0 else hidden_size,
                hidden_size
            )
            for i in range(num_layers)
        ])
    
    def forward(self, x, h0=None):
        """
        Args:
            x: (batch, seq_len, input_size)
            h0: (num_layers, batch, hidden_size)
        Returns:
            outputs: (batch, seq_len, hidden_size)
            hn: (num_layers, batch, hidden_size)
        """
        batch_size, seq_len, _ = x.shape
        
        # Initialize hidden states
        if h0 is None:
            hn = [torch.zeros(batch_size, self.hidden_size, device=x.device)
                  for _ in range(self.num_layers)]
        else:
            hn = [h0[i] for i in range(self.num_layers)]
        
        outputs = []
        
        for t in range(seq_len):
            xt = x[:, t, :]
            
            for layer in range(self.num_layers):
                ht = self.cells[layer](xt, hn[layer])
                hn[layer] = ht
                xt = ht
            
            outputs.append(ht)
        
        outputs = torch.stack(outputs, dim=1)  # (batch, seq_len, hidden)
        hn = torch.stack(hn, dim=0)  # (num_layers, batch, hidden)
        
        return outputs, hn
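RNNKANCell outputs h = o ⊙ tanh(c), with o ∈ (0, 1). Every hidden unit therefore stays strictly inside (-1, 1), however large the KAN outputs become. The scalar, dependency-free sketch below uses the same update equations. The two KAN layers are replaced by hypothetical fixed univariate functions, and the gate parameters are likewise made up for illustration.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

# Hypothetical stand-ins for input_kan and hidden_kan
def phi_in(x):
    return 0.8 * math.sin(3 * x) + 2.0 * x

def phi_h(h):
    return 1.5 * h - 0.5 * h ** 3

# Hypothetical gate parameters: (weight on x, weight on h_prev, bias)
GATES = {'i': (1.0, 0.5, 0.0), 'f': (0.3, 1.0, 1.0), 'o': (-0.7, 0.2, 0.5)}

def gate(name, x, h_prev):
    wx, wh, b = GATES[name]
    return sigmoid(wx * x + wh * h_prev + b)

def rnn_kan_step(x, h_prev):
    """Scalar version of RNNKANCell.forward."""
    i, f, o = (gate(n, x, h_prev) for n in ('i', 'f', 'o'))
    c = f * h_prev + i * math.tanh(phi_in(x) + phi_h(h_prev))
    return o * math.tanh(c)

h = 0.0
states = []
for x in [5.0, -3.0, 10.0, 0.1, -8.0, 2.5]:
    h = rnn_kan_step(x, h)
    states.append(h)
print(states)
```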

6. Summary comparison of variants

6.1 Computational efficiency

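| Variant | Computational complexity | Memory | Relative speed |
|---|---|---|---|
| Original KAN |  |  | 1x |
| FastKAN |  |  | 2-3x |
| CKAN |  |  | 2-3x |
| Wav-KAN |  |  | 0.8-1.2x |
| Ada-KAN |  |  | 0.5-1x |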

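The complexity and memory figures depend on a different hyperparameter for each variant:

- the grid size (original KAN);
- the number of RBFs (FastKAN);
- the Chebyshev degree (CKAN);
- the number of wavelet scales (Wav-KAN);
- the number of activation types (Ada-KAN).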
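Per-layer parameter counts follow directly from the parameter shapes declared in the layer classes of this article. The small helper below counts only `nn.Parameter` tensors and excludes buffers such as the fixed wavelet filters. The function names are illustrative. Parameter count is only a proxy: runtime also depends on how expensive each basis function is to evaluate.

```python
def params_linear(n_in, n_out):
    """Dense layer: weight matrix + bias."""
    return n_in * n_out + n_out

def params_fastkan(n_in, n_out, num_grids):
    """combine_weight + base_weight + RBF centers and widths."""
    return n_out * n_in * num_grids + n_out * n_in + 2 * n_in * num_grids

def params_chebyshev(n_in, n_out, degree):
    """coeffs (degree + 1 per edge) + base_weight."""
    return n_out * n_in * (degree + 1) + n_out * n_in

def params_wavkan(n_in, n_out, num_scales):
    """wavelet_coeffs (one per edge and scale) + scale_coeffs."""
    return n_out * n_in * num_scales + n_out * n_in

def params_adakan(n_in, n_out, num_types):
    """type_weights + two-layer selector MLP with hidden width max(1, n_in // 2)."""
    hidden = max(1, n_in // 2)
    selector = (n_in * hidden + hidden) + (hidden * num_types + num_types)
    return n_out * n_in * num_types + selector

n = 64
for name, count in [
    ('Linear', params_linear(n, n)),
    ('FastKAN (num_grids=8)', params_fastkan(n, n, 8)),
    ('CKAN (degree=5)', params_chebyshev(n, n, 5)),
    ('Wav-KAN (num_scales=4)', params_wavkan(n, n, 4)),
    ('Ada-KAN (5 types)', params_adakan(n, n, 5)),
]:
    print(f"{name:24s} {count:6d}")
```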

6.2 Expressiveness

Variant / Nonlinear expressiveness / Multi-scale analysis / Interpretability
Original KAN★★★★★★★★★★
FastKAN★★★★★★★★
CKAN★★★★★★★★★★
Wav-KAN★★★★★★★★★★★★★★
Graph KAN★★★★★★★★★
Ada-KAN★★★★★★★★★★★

6.3 Use cases

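| Scenario | Recommended variant |
|---|---|
| Scientific formula discovery | KAN 2.0, CKAN |
| Signal/image processing | Wav-KAN |
| Learning on graph data | Graph KAN |
| Large-scale training | FastKAN |
| Sequence modeling | RNN-KAN |
| High expressiveness required | Ada-KAN |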

7. Practical recommendations

7.1 Choosing a variant

def select_kan_variant(task_type, data_scale, interpretability_required):
    """
    Choose a suitable KAN variant
    
    Args:
        task_type: e.g. 'regression', 'classification', 'graph', 'sequence';
            names containing 'signal' or 'image' count as signal/image tasks
        data_scale: 'small', 'medium', 'large'
        interpretability_required: bool
    
    Returns:
        variant: str
    """
    if task_type == 'graph':
        return 'GraphKAN'
    
    if task_type == 'sequence':
        return 'RNN-KAN'
    
    if interpretability_required:
        if 'signal' in task_type or 'image' in task_type:
            return 'Wav-KAN'
        # On small data the original B-spline KAN is affordable
        if data_scale == 'small':
            return 'KAN (original)'
        return 'CKAN'
    
    # No interpretability requirement: FastKAN at any data scale
    return 'FastKAN'
 
 
# Example
variant = select_kan_variant(
    task_type='scientific_regression',
    data_scale='medium',
    interpretability_required=True
)
print(f"Recommended variant: {variant}")  # CKAN

7.2 Hyperparameter settings

KAN_VARIANT_CONFIGS = {
    'original': {
        'grid_size': 5,
        'spline_order': 3,
        'lr': 1e-3,
        'epochs': 200
    },
    'FastKAN': {
        'num_grids': 8,
        'smooth_threshold': 0.1,
        'lr': 1e-3,
        'epochs': 150
    },
    'CKAN': {
        'degree': 5,
        'lr': 5e-4,
        'epochs': 200
    },
    'Wav-KAN': {
        'num_scales': 4,
        'wavelet_type': 'db4',
        'lr': 1e-3,
        'epochs': 180
    },
    'GraphKAN': {
        'num_edges_hidden': 5,
        'aggr': 'add',
        'lr': 1e-3,
        'epochs': 300
    }
}
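With a plain dictionary, a hyperparameter typo fails silently: passing `num_grids` to CKAN, for example, would simply be ignored. The sketch below uses a hypothetical helper, `make_config`. It copies a variant's defaults, applies overrides, and rejects unknown variants or keys. It embeds a two-entry excerpt of the table above so that it runs on its own.

```python
import copy

# Two-entry excerpt of KAN_VARIANT_CONFIGS so this sketch is self-contained
CONFIGS = {
    'CKAN': {'degree': 5, 'lr': 5e-4, 'epochs': 200},
    'Wav-KAN': {'num_scales': 4, 'wavelet_type': 'db4', 'lr': 1e-3, 'epochs': 180},
}

def make_config(configs, variant, **overrides):
    """Copy configs[variant] and apply overrides, rejecting unknown variants/keys."""
    if variant not in configs:
        raise KeyError(f"unknown variant {variant!r}; expected one of {sorted(configs)}")
    cfg = copy.deepcopy(configs[variant])   # never mutate the shared defaults
    unknown = set(overrides) - set(cfg)
    if unknown:
        raise KeyError(f"unknown hyperparameters for {variant}: {sorted(unknown)}")
    cfg.update(overrides)
    return cfg

cfg = make_config(CONFIGS, 'CKAN', lr=1e-3)
print(cfg)
```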

8. Future research directions

8.1 Current limitations

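  1. Computational efficiency: all variants still face high computational cost
  2. Scaling: on large datasets they are less efficient than Transformers
  3. Theoretical depth: the theory of their expressive power is still incomplete
  4. Hardware support: there is no dedicated hardware acceleration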

8.2 Promising directions

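  1. Hardware-aware design: optimizing for GPUs/TPUs
  2. Sparsification: combining KANs with pruning techniques
  3. Fusion with other architectures: e.g. KAN-Transformer hybrids
  4. Automated design: using AutoML for KAN architecture search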
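For sparsification, one simple criterion scores each edge by the mean absolute value of its activation over sample inputs and drops edges whose score falls below a threshold. This criterion is an assumption chosen for illustration. The dependency-free sketch below applies it to hypothetical edge functions. The names `edge_importance` and `prune_mask` are illustrative.

```python
import math

def edge_importance(phi, samples):
    """Mean absolute activation of one edge function over sample inputs."""
    return sum(abs(phi(x)) for x in samples) / len(samples)

def prune_mask(edge_funcs, samples, threshold):
    """edge_funcs[o][i] is the activation on edge i -> o; returns a 0/1 keep-mask."""
    return [[1 if edge_importance(phi, samples) >= threshold else 0 for phi in row]
            for row in edge_funcs]

samples = [-1.0, -0.5, 0.0, 0.5, 1.0]
edge_funcs = [
    [math.sin, lambda x: 1e-3 * x],     # second edge is nearly inactive
    [lambda x: x * x, lambda x: 0.0 * x],
]
print(prune_mask(edge_funcs, samples, threshold=0.01))
```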


Footnotes

  1. Liu, Z., et al. (2024). “KAN: Kolmogorov-Arnold Networks”. arXiv:2404.19756.

  2. Wav-KAN authors. (2024). “Wav-KAN: Wavelet Kolmogorov-Arnold Networks”. arXiv:2405.12832.

  3. FastKAN authors. (2024). “FastKAN: Efficient KAN via Radial Basis Functions”.

  4. Graph KAN authors. (2024). “Graph Kolmogorov-Arnold Networks”.