简单记录一下Mamba相关几篇paper的创新和改进过程,首先SSM把RNN和CNN结合起来,可以像RNN一样推理(更快),也可以像CNN一样训练。然而SSM依然无法像Attention一样捕捉长程依赖,于是S4通过改进SSM中的矩阵运算将运算复杂度降低,这样就可以实现一个更大规模的输入。Mamba(S6)进一步解决了输入选择性、并行结算、结构简化的问题。
一、SSM
定义:
- 输入序列x
- 隐藏状态h
- 输出序列y
h'(t)=Ah(t)+Bx(t) \\ y(t)=Ch(t)+Dx(t)
图解计算流程:
二、S4
Structured State Space for Sequences
SSM和CNN、RNN一样无法有效捕捉长程依赖,解决方法是使用HiPPO矩阵替代A矩阵
矩阵分解:使用两个长度为N的矩阵PQ就可完成A矩阵相关运算,复杂度由O(N^2)降低到O(N)
A=VVV-PQ^T=V(A-(V^*P)(V^*Q)^*)V^*
三、S6(Mamba)
S4与SSM的问题:无选择性。ABC矩阵在训练完成时就会固定,不论输入时什么,都会通过完全一样的A,这样导致ABC无法根据输入做针对性推理。
如果ABC会根据输入变化,那么就无法将SSM转化为标准卷积过程。于是Mamba的motivation就是:
- 参数化矩阵:AB矩阵变为输入数据驱动
- 硬件感知的并行运算算法
- 更简单的SSM结构
先接晒下上面各个数据维度(B,L,N)的含义:
- B: Batch Size
- L:序列长度
- D:每个时间step的序列特征维度
- N:隐状态的维度
不直接把AB参数化成(B,L,D,N)的尺寸,一方面会造成参数量增大,另一方面,通过上下文推导带有delta项的运算,因此只需要参数化成如上的尺寸即可实现AB的参数化。
硬件感知算法:算法并行化(选择性扫面算法)
定义了一种 Add 运算符,假设运算操作顺序与关联矩阵A无关,会发现每个xt乘以的矩阵都源于xt本身(矩阵BC是x通过线性层得到的),该运算满足交换律和结合律,虽然不是卷积但也是一种并行运算。
举个例子,对于y3
y_3=C\bar{A_3}\bar{A_2}\bar{A_1}\bar{B_0}x_0+C\bar{A_3}\bar{A_2}\bar{B_1}x_1+C\bar{A_3}\bar{B_2}x_2+\bar{B_3}x_3
更简单的SSM结构:一路是数据过Conv和SSM Block,另一路充当一个门控,对数据进行筛选