KAN论文笔记

KAN(Kolmogorov-Arnold Networks)相比于MLP的改进在于激活函数的可学习,可以用更少量的参数来做更深层次的拟合,极大减少了网络参数,增强了网络的可解释性,KAN缺点在于训练过于缓慢,这篇笔记主要是对论文主干的翻译,我们先看看KAN的论文。(Paper有48页,慢慢啃)

参考链接:

论文:[2404.19756] KAN: Kolmogorov-Arnold Networks (arxiv.org)

代码(Python实现):https://github.com/KindXiaoming/pykan

〇、摘要

基于Kolmogorov-Arnold表示理论提出了KANs,有望作为多层感知机(MLP)的替代。MLP的每个神经元节点有固定的激活函数,而KAN拥有可学习的激活函数。KAN完全没有线性权重,每个权重参数都被一个参数化的单变量样条函数替代。论文会展示这个简单的改变会使得KAN在精度和可解释性上的显著提升。在精度方面,数据拟合和偏微分方程求解上,更小的KANs可以比与其大得多的MLPs表现相当。理论和实践说明KAN比MLP具有更快的神经缩放规则。在可解释性方面,KAN可以凭直觉可视化并且和人类用户轻易进行交互。通过两个数学和物理的例子说明,KAN可以被用于帮助科学家发现数学和物理定律。总而言之,KANs有希望替代MLP,有机会环节现在深度学习对MLP的依赖。

一、简介

MLP大家都熟悉,就不贴原文说明了。主要是MLP固定了激活函数,KAN具有可学习的激活函数

KAN和MLP一样具有全接连的结构,每个节点的激活函数是一个简单的样条函数,节点输出就是将所有信号求和,除激活函数外没有其它非线性操作。这里可能有会有人每个节点的激活函数不同会使得MLP太贵,这里论文特别做出解释,这里激活函数是参数化成样条函数的,KAN实际的计算并不会比MLP高。比如,对于PDE(偏微分方程)求解,2层10宽的KAN具有比4层100宽的MLP有100倍的准确性,KAN的参数效率是MLP的100倍。

尽管KAN有优雅的数学解释,它也只不过是MLP和splines激活的结合。他们相互取长补短。spline可以精确处理低维度函数,容易做局部的调整,可以在不同分辨率之间做切换。然而splines有严重的维度灾难问题,因为它无法使用一些组合结构。MLP具有特征学习能力,维度灾难问题更少,它在低维度上不如spline精确,因为它无法优化单变量函数。

准确地学习一个函数,一个模型不应该只学习组成结构(外自由度),还应该近似好单变量函数(内部自由度)。KAN同时具有MLP外部自由度和splines内部自由度,因此KAN不仅能学习特征,还可以把学习的特征优化到更加准确,例如我们给一个高维函数:

f(x_1,...,x_N)=\exp(\frac{1}{N}\Sigma_{i=1}^N sin^2(x_i))

样条会因为维度灾难,在N很大的时候失败,MLP可以学习该函数结构的潜在表示,但是在使用ReLU函数激活时做exp和sin函数近似时效率很低。而KAN具有二者的优点。

下面这张图右上角就是一个精度和可解释性的例子,对于exp(J0(20x)+y^2))三个神经元的激活函数分别做了J0、square、exp操作,而在J0函数的几何精度上,两层KAN超过了五层MLP。

KAN命名来自两位数学家Kolmogorov和Arnold。文章结构也在下图展示出来:

二、KAN

2.1 Kolmogorov-Arnold表示理论

Vladimir Arnold和Andrey Kolmogorov证明了:如果f是有限界上的多变量连续函数,那么f可以写作有限个单变量连续函数的二值求和操作。更加具体一些,对于平滑函数f,在[0,1]实数域上

f(\textbf{x})=f(x_1,...,x_n)=\Sigma_{q=1}^{2n+1}\Phi_q(\Sigma_{p=1}^n\phi_{p,q}(x_p))

ϕq,p : [0,1] → R and Φq : R → R. 我们可以从下图来直观了解下上面公式所表示的计算过程以及公式各个符号的含义,其可表示为一个两层网络,p对应输入id,q对应输出id,其中phi倍参数化为B样条的基数。

右图是说B样条参数化可以在不同粒度上切换。

里面加法是唯一的多变量函数,其它函数可以写作单变量函数的和。有人可能会天真的认为这对机器学习来说是个好消息:学习一个高维函数可以归结为学习一个多项式的一维函数。然而,这些一维函数可以是非光滑甚至分形的,所以它们在实践中可能是不可学习的。正因如此,Kolmogorov-Arnold表示定理,基本上在机器学习中不会使用,理论上被认为是合理的,但实际上毫无用处。

而论文对Kolmogorov-Arnold定理在机器学习中的使用比较乐观。首先,我们需要不拘泥于原文中上面的方程,上面的方程有两个非线性层,以及隐藏了一个小的数据项2n+1,论文会泛化该网络到任意深度和宽度。

2.2 KAN的架构

假设我们有一个数据几何任务需要找到满足所有数据对的映射方程f,2.1节的公式告诉我们完成该任务时可以找到近似的单变量函数phi(p,q)和Phi(q)。于是KAN受此启发按照2.1中的公式来进行参数化。通过学习 B-spline 基函数的可学习因子来完成参数化为1D B-spline曲线的学习。现在KAN的原型的计算图可以描述成图0.1那样了。n时输入维度,2n+1是隐藏层维度。

接下来论文将KAN泛化到更宽和更深。因为Kolmogorov-Arnold表示只显示其有两层。参考MLP,一旦定义好一个层,我们可以通过不断堆叠来使得网络更深,于是论文也定义了KAN layer:

\Phi=\{\phi_{q,p}\}   \ \ \ \ \ \ p=1,2,...n_{in}, \ \ \ \ \ \ q=1,2,...,n_{out}

phi函数拥有可学习的参数,下文会详细介绍。在Kolmogov-Arnold理论中,内层phi的n_in = n, n_out=2n+1, 外层Phi的n_in=2n+1, n_out=1,于是Kolmogorov-Arnold表示可以倍简单表示成两个KAN layer。 于是我们通过不断堆叠KAN layer可以使得网络更深。

论文接下来借助MLP中的各种记号来定义了KAN中的各个参数,i表示输入,j表示输出,l表示layer,总体来说就是用phi替代了激活函数,并且phi中有可学习参数。

(这一段都是公式,也比较简单就直接贴原文看了,这一段我们需要知道Phi_l是一个函数矩阵,后面会用到。)

接下来关注实现细节,KAN layer看起来极其简单,是平凡的好优化的。论文介绍了几个关键技巧:残差激活函数,初始规模,更新spline grids。

残差激活函数,我们定义了b函数,有些像残差连接,激活函数是spline和b函数的和。spline还能输可以倍参数化成B-spline的线性组合,其中c是可训练的。原则上w是不必要的:但是为了更好的控制激活层的幅度大小,论文还是引入了因子。

\phi(x)=w(b(x)+spline(x))  \\
b(x)=silu(x)=\frac{x}{1+e^{-x}}  \\
spline(x)=\Sigma_{i}c_iB_i(x)

初始scales:每个激活函数都被初始化成 spline(x)≈0,w使用Xavier初始化。

spline grids的更新:为了解决训练期间splines激活值超过固定区间的问题,每个grid根据输入值被动态激活。

参数数量:深度L,宽度n,spline的阶数k(通常使用k=3)在G个区间上(有G+1个个点),那么总共就有 O(N^2 L(G+k))~O(N^2 LG)个参数。而MLP则有O(N^2 L)个参数。

2.3 KAN的近似能力和标度法制

通常更深的表达会体现更平滑的激活函数的优势,我们举一个四变量函数的例子:

f(x_1,x_2,x_3,x_4)=\exp(\sin(x_1^2+x_2^2)+\sin(x_3^2+x_4^2))

这个例子可以被表示成[4,2,1,1]的三层KAN。

定理2.1 (近似理论,KAT)假定一个函数f(x)表示成

f=(\Phi_{L-1}\circ\Phi_{L-2}\circ...\circ\Phi_1\circ\Phi_0)\text{x}

其中每个Phi都是k+1次连续可微函数,于是存在一个和f及其表示有关的常量C,用网络大小G表示近似边界,存在k阶B样条函数\Phi^G_{l,i,j},对于任意0≤m≤k,我们找到边界:

||f-(\Phi_{L-1}^G\circ\Phi_{L-2}^G\circ...\circ\Phi_1^G\circ\Phi_0^G)\text{x}||_{C^m}<=CG^{-k-1+m}

这里我们使用记号C^m范数来衡量m阶倒数的大小:

||g||_{C^m}=\max \sup |D^\beta g(x)|

定理的证明就不翻译了,总之是说KAN通过有限个网格大小可以很好的近似函数,残差旅和维度无关,这样可以一定程度上对抗维度灾难。

神经尺度规律是随着模型参数的增加,测试loss减少的现象

2.4 网格拓展(Grid Extension)

通常,spline grid区间越多,拟合越准确,这个特点被kan很好的继承下来。并且先训练一个小一些的KAN,它的参数规模可以通过调整grid size很容易的拓展成大一些的KAN。

上面的图2.2描述了如果进行grid extentsion。首先对k阶b样条[a,b]区间内,一个粗粒度的格点有G1个区间,通过网格拓展扩大到G2个区间,这个过程通过最小二乘吗优化细粒度函数和粗粒度函数之间的距离来得到细粒度函数的初始参数:

{c'_j}=\argmin _{c1_j}\textbf{E}(\sum_{j=0}^{G_2+k-1}c'_jB'_j(x)-\sum_{i=0}^{G1+k-1}c_iB_i(x))^2

图2.3坐立一个例子,通过不断增加grid数量,发现了overfitting和underfitting现象,说明增加grid参数增加了网络参数规模。在grid>1000时训练很慢出现了bad landscape loss。同时更小的kan在网络扩展上表现更好,这里对比的时[2,5,1]和[2,1,1]规模的kan。左下图总结了随着grid size增加,loss均方差减小的规律,KAN比MLP更容易做网络大小的缩放。

KAN的结果可以做符号回归,尝试还原出函数的符号表达。大致有这样几个过程,稀疏化训练、剪枝、设置符号函数、进一步训练、输出符号方程。

三、KAN的准确率

KAN和MLP的对比

不同函数的对比

Feynman数据集

解偏微分方程

四、KAN的可解释性

演示几个简单的符号任务

如何选择KAN和MLP

发表评论