Neural ODE的介绍与数学推导

Aroma

引入

从ResNet模型来看

ResNet, RNN, normalizing flows等模型都通过隐向量构建变换。在ResNet中,隐向量与它经残差块变换后的输出相加得到新的隐向量,即;在RNN中,共用类似表达,但对于所有输入的变换都相同(所有输入共享参数),.
我们可以从中(主要是ResNet)概括出前向传播中隐藏状态的更新模式,为 其中,表示时刻(时间序列中的第个值)的隐向量,
上式可以变形为差分方程的形式:

从ODE来看

给定一个常微分方程 及初值条件( 满足 ),我们能否解出任意时刻的函数值,即解出全部的
理论上是可以的,对方程两边积分,可得
其中难以计算解析解的只有定积分项,可以用数值积分方法(也可以叫做ODE Solver)给出近似解。
Euler method就是一种朴素的数值积分方法,可以用它引入NODEs,因此我们简单介绍这种近似方法。Euler method从初始点开始,使用微分的思想,沿着梯度的方向近似解出短时间后的函数值,不断循环前进,即:
每次循环中已知当前点的值。
数学表示Euler method的递推公式为 其中.
理论上,根据此递推关系可以逐步求出所有的近似值。

当然,近似的误差会被逐步放大,因此Euler method不够精确,实际更常用Runge-Kutta method等求数值积分。

引入Neural ODE

ResNet中隐藏状态的更新公式
Euler method的递推公式
对应的微分方程
可以发现,(1)和(2)这两个公式十分相似。 式(1)的隐藏状态依赖于离散的时间;式(2)中待求解的未知变量依赖于连续时间的离散采样,二者都是关于离散时间的差分方程。考虑到(2)是对(3)的近似求解方法,(1)与(2)的相似性实际上暗示(1)与(3)之间存在关联,也就是神经网络模型和常微分方程之间存在关联。
这种关联说明:“ResNet, RNN等模型的前向传播”类似于“使用Euler method近似求解ODE”。
进一步地,我们会想到,(前向传播中的)模型是否可以视为待求解的微分方程?求解微分方程的过程是否可以视为模型的前向传播?这种想法很有吸引力,因为求解ODE的方法很多,我们可以换用其他ODE Solver,或许可以增加前向传播的效率。进一步地,如果前向传播可视为求微分方程数值解,能否设计一种基于ODE数值方法的反向传播方法?
Nerual ODE就是在这种思想下被设计出来的,它使用ODE Solver的方法进行前向和反向传播。更有用的特点是,它可以更好表征连续(时间/空间)观测值,正如微分方程(3)中的,它可以表征,而是连续时间的函数。
Neural ODE形式上为
其中,表示神经网络本体,为数据(时刻的值,多个构成一组序列,作为数据集),是神经网络的参数。

  • 从ODE的角度看,Neural ODE将微分方程(3)中的替换为神经网络,以构成的集合为数据集进行训练。
    实际上,模型建模的是的梯度场(参照(3)式),也即:Neural ODE对数据的梯度场参数化为神经网络。
  • 从神经网络的角度看,它似乎结合了ResNet和RNN的部分想法。感性地看,
    • ResNet的残差块是离散的,Neural ODE尝试把它们“连续化”,这个“连续化”有两个方面:
      • 引入无穷个层(与的某个子集一一映射),且层间“距离”很小;
      • 不同层之间的差异很小;
    • 不同的层之间共享参数,这类似于RNN的思想。
      “连续化”是针对输入数据所在空间的,这使得不同时间间隔的数据可作为训练数据。RNN能处理离散的时间序列,例如把一句话作为时间序列数据,“时间”就意味着单词(token)是句子的第几个单词,而RNN难以处理连续时间序列,例如“通过小球的时空坐标求解加速度场”这个问题,我们只能等间隔地采样数据对来训练模型。Neural ODE就可以处理这个问题(求加速度场),并且不需要等间隔采样,无论如何采样作为数据集,模型都可以表征。

Neural ODE

Neural ODE对输入值的梯度进行建模
其中,表示神经网络本体,作为为数据(时刻的值,多个构成一组序列,作为数据集),是神经网络的参数。
接下来最重要的问题是,如何训练和测试Neural ODE,也就是在问,Neural ODE的forward process和backward process是怎样的?

forward process

forward process的输入输出:(2.0:1)中(的参数)已知,神经网络本身不变化,作为输入,需要计算的值。
既然已知,这其实就是一个ODE(可以与(1.2:1)式比照着看),利用数值积分方法就可以算出来。
实用场景下,会选择能够利用GPU进行加速的ODE Solver进行前向计算。

这里是更具体一点的推导(其实和前面的推导几乎一样):
假设我们已知,想知道的值。对(2.0:1)两侧积分,得到
训练数据集中任意选数据对都可以作为,然后选用合适的数值积分方法求积分项,就可以得到结果。

backward process

首先列出Neural ODE中隐藏状态更新公式(将(5)中的符号替换为,仍表示神经网络):
backward process的输入输出:给定损失函数,已知神经网络(的参数)和数据集,需要计算损失关于神经网络参数的梯度,根据这个梯度更新的值。

Back Propagation Through Time

第一种方法,也是最容易想到的,按照一般神经网络损失的反向传播(Back Propagation, BP)算法推导:
将损失函数套到(2.2:1)上得到
使用ODE Solver能够求出的值,于是将这一项替代为,体现ODE Solver接收作为输入,将的值(也即ODE 的解)输出。
既然损失函数可以求出,那么按照ODE Solver内部的autodiff(自动微分)就可以计算梯度,自然可以反向传播更新网络的参数。这类似于RNN中使用的BPTT算法。
同样类似于RNN,BPTT算法的问题在于:如果输入的时间序列数据很长,就会有很多个,再加上ODE Solver可能很复杂,那么forward process得到的计算图(computation graph)很复杂。由于BP要保存forward process中的部分计算结果(也就是activation)用于backward process,复杂的计算图就要求保存更多的activation,也就带来了更大的显存消耗。
另外,这种方法显然没有利用ODE本身的性质,只是把它视为普通的神经网络,根据forward process去推backward process而已。

Adjoint Method

而第二种方法,称作伴随方法,利用了ODE本身的性质,通过求解另一个辅助ODE(称为伴随方程)来计算梯度。
将时间序列数据记作,用表示任意时间,损失函数关于参数的总梯度为,需要用它进行backward process。

文献中提到的似乎有多种含义,它有时表示在时刻的损失,或记作,表示从初始时刻时刻(累积)的损失(一般人为规定,因为时没有发生反向传播,不认为产生损失),与有关;有时表示关于参数的总损失,与无关,可以表示为,即从反向传播到最初时刻的总损失。根据误差梯度更新参数时,使用的是与无关的总损失

为forward process的ODE 创建一个伴随状态(adjoint state)
能证明可以用一个新的ODE来表示,称作伴随方程(adjoint equation):(证明过程见此处

接下来我们将扩展(2)式中的变量,以构造一个包含更多信息的伴随方程(后文的(6)式),将关心的项组合到一个变量中。相应地,将扩展为,原来的的梯度,现在我们希望表示的梯度,即定义(并化简)
计算Jacobian matrix(自变量省略)

定义三个变量

> 第二分量中,无关,实际上就是前文提到的与时间相关损失上的梯度,可用于更新参数。

参照伴随状态定义(1)和伴随方程(2),不难发现:由于成立,的伴随状态,伴随方程(6)也因此成立。
对应于扩展后变量()的伴随方程为

上式等号右侧带入(5)式(i.e., )和(4)式化简,等号左侧将微分作用于三个分量,于是(6)式可化简为

其中每个分量都相等,即



其中,第一个等式就是扩展前的伴随方程,即(2)式;第二个等式(8)可用于求
将(8)两侧对整个区间积分,并令,得到损失的总梯度

(10)式表示损失对参数的梯度,用于更新神经网络参数

(10)中是vector-Jacobian product(VJP),计算时不需要显式求出Jacobian matrix,实际计算框架中对此计算有优化。

论文中给出的adjoint method流程(由伪代码改写)如下:
(输入:网络参数,初始时间,停止时间,最终状态,损失对最终状态的梯度
1. 计算损失对时间梯度
2. 定义初始扩展状态
3. 根据计算(计算vector-Jacobian乘积)
4. 反向时间求解ODE:
5. 输出:
接下来根据这些值更新参数即可,其中是最重要的,用于更新神经网络参数,其他梯度提供了损失对输入、初始时间、终止时间的敏感度。

Adjoint Method的证明

待证明命题
已知 ,定义伴随状态
则下面的等式成立

证明
(证明中的向量表示为行向量)
由于隐状态连续,可以写出时间上相差的隐状态的关系:

由标准神经网络链式法则

写出连续隐状态下的链式法则

借助定义(1),将链式法则用伴随状态表示为

接下来从(2)的左侧开始推导

接下来把处泰勒展开(一阶项),需要用到(3)

将(8)带入(7)中得到

最后一步利用的连续性,这需要假设“一阶连续可导”(通过定义(1)可知连续)。
(9)即待证明的(2)式,证毕。

Neural ODE的特点

由于对所有时间序列数据都使用同一网络(类似于RNN),Neural ODE的参数量非常少。
在forward过程中不需要保存中间结果(只需保存最终时间的结果),故显存占用少,但backward需要重新计算(与forward计算量接近),因此时间开销更大。
所有隐藏状态的shape相同,可能限制使用场景。

后记

论文:
R. T. Q. Chen, Y. Rubanova, J. Bettencourt, and D. Duvenaud, “Neural ordinary differential equations,” Dec. 14, 2019, arXiv: arXiv:1806.07366. doi: 10.48550/arXiv.1806.07366.
参考:

  • Neural ODE的引入方式参考youtube上Steve Brunton的视频

  • 主要思路参照知乎上的这篇文章,并且根据原论文更详细地解释了一些数学推导
    推荐:

  • 本Blog采用与原论文相同的方法推导Adjoint Method,而这篇Blog使用Lagrange乘子推导

  • 关于代码层面的vector-Jacobian product计算优化,可参阅此Blog,与上一篇作者相同

  • Title: Neural ODE的介绍与数学推导
  • Author: Aroma
  • Created at : 2025-06-26 00:00:00
  • Updated at : 2025-07-08 00:00:00
  • Link: https://recynie.github.io/2025-06-26/neural-ode/
  • License: This work is licensed under CC BY-NC-SA 4.0.