Neural ODE的介绍与数学推导

引入
从ResNet模型来看
ResNet, RNN, normalizing flows等模型都通过隐向量构建变换。在ResNet中,隐向量
我们可以从中(主要是ResNet)概括出前向传播中隐藏状态的更新模式,为
上式可以变形为差分方程的形式:
从ODE来看
给定一个常微分方程
理论上是可以的,对方程两边积分,可得
其中难以计算解析解的只有定积分项,可以用数值积分方法(也可以叫做ODE Solver)给出近似解。
Euler method就是一种朴素的数值积分方法,可以用它引入NODEs,因此我们简单介绍这种近似方法。Euler method从初始点
每次循环中已知当前点
数学表示Euler method的递推公式为
理论上,根据此递推关系可以逐步求出所有
当然,近似的误差会被逐步放大,因此Euler method不够精确,实际更常用Runge-Kutta method等求数值积分。
引入Neural ODE
ResNet中隐藏状态的更新公式
Euler method的递推公式
对应的微分方程
可以发现,(1)和(2)这两个公式十分相似。 式(1)的隐藏状态
这种关联说明:“ResNet, RNN等模型的前向传播”类似于“使用Euler method近似求解ODE”。
进一步地,我们会想到,(前向传播中的)模型是否可以视为待求解的微分方程?求解微分方程的过程是否可以视为模型的前向传播?这种想法很有吸引力,因为求解ODE的方法很多,我们可以换用其他ODE Solver,或许可以增加前向传播的效率。进一步地,如果前向传播可视为求微分方程数值解,能否设计一种基于ODE数值方法的反向传播方法?
Nerual ODE就是在这种思想下被设计出来的,它使用ODE Solver的方法进行前向和反向传播。更有用的特点是,它可以更好表征连续(时间/空间)观测值,正如微分方程(3)中的
Neural ODE形式上为
其中,
- 从ODE的角度看,Neural ODE将微分方程(3)中的
替换为神经网络,以 构成的集合为数据集进行训练。
实际上,模型 建模的是 的梯度场(参照(3)式),也即:Neural ODE对数据的梯度场参数化为神经网络。 - 从神经网络的角度看,它似乎结合了ResNet和RNN的部分想法。感性地看,
- ResNet的残差块是离散的,Neural ODE尝试把它们“连续化”,这个“连续化”有两个方面:
- 引入无穷个层(与
的某个子集一一映射),且层间“距离”很小; - 不同层之间的差异很小;
- 引入无穷个层(与
- 不同的层之间共享参数,这类似于RNN的思想。
“连续化”是针对输入数据所在空间的,这使得不同时间间隔的数据可作为训练数据。RNN能处理离散的时间序列,例如把一句话作为时间序列数据,“时间”就意味着单词(token)是句子的第几个单词,而RNN难以处理连续时间序列,例如“通过小球的时空坐标求解加速度场”这个问题,我们只能等间隔地采样 数据对来训练模型。Neural ODE就可以处理这个问题(求加速度场),并且不需要等间隔采样,无论如何采样 作为数据集,模型都可以表征。
- ResNet的残差块是离散的,Neural ODE尝试把它们“连续化”,这个“连续化”有两个方面:
Neural ODE
Neural ODE对输入值的梯度进行建模
其中,
接下来最重要的问题是,如何训练和测试Neural ODE,也就是在问,Neural ODE的forward process和backward process是怎样的?
forward process
forward process的输入输出:(2.0:1)中
既然
实用场景下,会选择能够利用GPU进行加速的ODE Solver进行前向计算。
这里是更具体一点的推导(其实和前面的推导几乎一样):
假设我们已知
训练数据集中任意选数据对都可以作为
backward process
首先列出Neural ODE中隐藏状态更新公式(将(5)中的
backward process的输入输出:给定损失函数
Back Propagation Through Time
第一种方法,也是最容易想到的,按照一般神经网络损失的反向传播(Back Propagation, BP)算法推导:
将损失函数套到(2.2:1)上得到
使用ODE Solver能够求出
既然损失函数可以求出,那么按照ODE Solver内部的autodiff(自动微分)就可以计算梯度
同样类似于RNN,BPTT算法的问题在于:如果输入的时间序列数据很长,就会有很多个
另外,这种方法显然没有利用ODE本身的性质,只是把它视为普通的神经网络,根据forward process去推backward process而已。
Adjoint Method
而第二种方法,称作伴随方法,利用了ODE本身的性质,通过求解另一个辅助ODE(称为伴随方程)来计算梯度。
将时间序列数据记作
文献中提到的
似乎有多种含义,它有时表示在 时刻的损失 ,或记作 ,表示从初始时刻 到 时刻(累积)的损失(一般人为规定 ,因为 时没有发生反向传播,不认为产生损失),与 有关;有时表示关于参数的总损失 ,与 无关,可以表示为 ,即从 反向传播到最初时刻 的总损失。根据误差梯度更新参数时,使用的 是与 无关的总损失 。
为forward process的ODE
能证明
接下来我们将扩展(2)式中的变量,以构造一个包含更多信息的伴随方程(后文的(6)式),将关心的项组合到一个变量
计算Jacobian matrix(自变量
定义
> 第二分量中,
参照伴随状态定义(1)和伴随方程(2),不难发现:由于
对应于扩展后变量(
上式等号右侧带入(5)式(i.e.,
其中每个分量都相等,即
其中,第一个等式就是扩展前的伴随方程,即(2)式;第二个等式(8)可用于求
将(8)两侧对整个区间
(10)式表示损失对参数的梯度,用于更新神经网络参数
(10)中
是vector-Jacobian product(VJP),计算时不需要显式求出Jacobian matrix,实际计算框架中对此计算有优化。
论文中给出的adjoint method流程(由伪代码改写)如下:
(输入:网络参数
1. 计算损失对时间梯度
2. 定义初始扩展状态
3. 根据
4. 反向时间求解ODE:
5. 输出:
接下来根据这些值更新参数即可,其中
Adjoint Method的证明
待证明命题
已知
则下面的等式成立
证明
(证明中的向量表示为行向量)
由于隐状态连续,可以写出时间上相差
由标准神经网络链式法则
写出连续隐状态下的链式法则
借助
接下来从(2)的左侧开始推导
接下来把
将(8)带入(7)中得到
最后一步利用
(9)即待证明的(2)式,证毕。
Neural ODE的特点
由于对所有时间序列数据都使用同一网络(类似于RNN),Neural ODE的参数量非常少。
在forward过程中不需要保存中间结果(只需保存最终时间的结果),故显存占用少,但backward需要重新计算(与forward计算量接近),因此时间开销更大。
所有隐藏状态
后记
论文:
R. T. Q. Chen, Y. Rubanova, J. Bettencourt, and D. Duvenaud, “Neural ordinary differential equations,” Dec. 14, 2019, arXiv: arXiv:1806.07366. doi: 10.48550/arXiv.1806.07366.
参考:
Neural ODE的引入方式参考youtube上Steve Brunton的视频
主要思路参照知乎上的这篇文章,并且根据原论文更详细地解释了一些数学推导
推荐:本Blog采用与原论文相同的方法推导Adjoint Method,而这篇Blog使用Lagrange乘子推导
关于代码层面的vector-Jacobian product计算优化,可参阅此Blog,与上一篇作者相同
- Title: Neural ODE的介绍与数学推导
- Author: Aroma
- Created at : 2025-06-26 00:00:00
- Updated at : 2025-07-08 00:00:00
- Link: https://recynie.github.io/2025-06-26/neural-ode/
- License: This work is licensed under CC BY-NC-SA 4.0.