Neural Ordinary Differential Equations: Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud

Neural ordinary differential equations (ODEs) provide a framework for modeling temporal data and continuous normalizing flows using ODEs whose dynamics are parameterized by neural networks. By interpreting deep models such as ResNets as discretizations of ODEs, the framework can adapt step sizes and leverage black-box ODE solvers during training, yielding more accurate and memory-efficient computation than fixed discretizations. It enables continuous-time modeling with constant memory cost and adapts computation to the complexity of each instance.


Neural Ordinary Differential Equations

Ricky T. Q. Chen*, Yulia Rubanova*, Jesse Bettencourt*, David Duvenaud


University of Toronto
Background: Ordinary Differential Equations (ODEs)
- Model the instantaneous change of a state:

      dz(t)/dt = f(z(t), t)        (explicit form)

- Solving an initial value problem (IVP) corresponds to integration:

      z(T) = z(0) + ∫_0^T f(z(t), t) dt        (solution is a trajectory)

- The Euler method approximates the solution with small steps h:

      z(t + h) = z(t) + h f(z(t), t)
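For concreteness, here is a minimal forward-Euler integrator in Python; euler_solve and the example dynamics are illustrative, not code from the paper.

```python
import numpy as np

def euler_solve(f, z0, t0, t1, n_steps=100):
    """Approximate z(t1) for dz/dt = f(z, t), z(t0) = z0, with forward Euler."""
    h = (t1 - t0) / n_steps              # fixed step size
    z, t = np.asarray(z0, dtype=float), t0
    for _ in range(n_steps):
        z = z + h * f(z, t)              # Euler update: z(t + h) ≈ z(t) + h f(z(t), t)
        t += h
    return z

# Example: dz/dt = -z has the exact solution z(T) = z(0) * exp(-T).
print(euler_solve(lambda z, t: -z, z0=[1.0], t0=0.0, t1=1.0))   # ≈ 0.366 (exact: 0.368)
```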


Residual Networks interpreted as an ODE Solver
- Hidden units look like:

      h_{t+1} = h_t + f(h_t, θ_t)

- Final output is the composition of these residual updates: h_T = h_0 + Σ_t f(h_t, θ_t).

- This can be interpreted as an Euler discretization of an ODE.

- In the limit of smaller steps:

      dh(t)/dt = f(h(t), t, θ)

Haber & Ruthotto (2017). Weinan E (2017).
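A small PyTorch sketch of this correspondence (the ResidualBlock architecture is an illustrative assumption): the residual update h + f(h) is exactly one forward-Euler step with step size 1.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """h_{t+1} = h_t + f(h_t, θ_t): one forward-Euler step of size 1."""
    def __init__(self, dim):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, dim))

    def forward(self, h):
        return h + self.f(h)             # identity skip connection = Euler update

# Stacking T blocks composes T Euler steps along a discretized "depth/time" axis.
h = torch.randn(8, 32)
for block in [ResidualBlock(32) for _ in range(4)]:
    h = block(h)
```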


Deep Learning as Discretized Differential Equations
Many deep learning networks can be interpreted as ODE solvers.

Network                            Fixed-step Numerical Scheme
ResNet, RevNet, ResNeXt, etc.      Forward Euler (Lu et al. 2017; Chang et al. 2018)
PolyNet                            Approximation to Backward Euler (Zhu et al. 2018)
FractalNet                         Runge-Kutta
DenseNet                           Runge-Kutta

But:
(1) What are the underlying dynamics?
(2) Adaptive-step-size solvers provide better error handling.
“Neural” Ordinary Differential Equations

Instead of y = F(x), solve y = z(T)
given the initial condition z(0) = x.

Parameterize the dynamics with a neural network:

      dz(t)/dt = f(z(t), t, θ)

Solve the dynamics using any black-box ODE solver.
- Adaptive step size.
- Error estimate.
- O(1) memory learning.
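A minimal sketch using the authors' torchdiffeq package; the ODEFunc architecture and dimensions are illustrative assumptions, not the models from the paper.

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint   # adjoint method: O(1)-memory backward pass

class ODEFunc(nn.Module):
    """Parameterized dynamics dz(t)/dt = f(z(t), t, θ)."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 64), nn.Tanh(), nn.Linear(64, dim))

    def forward(self, t, z):
        return self.net(z)

func = ODEFunc(dim=2)
z0 = torch.randn(16, 2)                 # initial condition z(0) = x
t = torch.tensor([0.0, 1.0])            # integrate from t = 0 to T = 1
zT = odeint(func, z0, t)[-1]            # y = z(T), computed by an adaptive black-box solver
```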
Backprop without knowledge of the ODE Solver
Ultimately we want to optimize some loss L(z(T)).

Naive approach: know the solver, and backprop through the solver.
- Memory-intensive.
- The family of “implicit” solvers performs an inner optimization.

Our approach: adjoint sensitivity analysis (reverse-mode autodiff).
- Pontryagin (1962).
+ Automatic differentiation.
+ O(1) memory in the backward pass.
Continuous-time Backpropagation
Define the adjoint state a(t) = ∂L/∂z(t).

Residual network (discrete):
- Forward:  z_{t+1} = z_t + f(z_t, θ)
- Backward: a_t = a_{t+1} + a_{t+1} ∂f(z_t, θ)/∂z
- Params:   dL/dθ = Σ_t a_{t+1} ∂f(z_t, θ)/∂θ

Adjoint method (continuous):
- Forward:  z(t_1) = z(t_0) + ∫_{t_0}^{t_1} f(z(t), t, θ) dt
- Backward: the adjoint state follows the adjoint DiffEq
      da(t)/dt = -a(t)^T ∂f(z(t), t, θ)/∂z,
  solved backwards in time from a(t_1) = ∂L/∂z(t_1).
- Params:   dL/dθ = -∫_{t_1}^{t_0} a(t)^T ∂f(z(t), t, θ)/∂θ dt
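A conceptual sketch of the right-hand side of the augmented ODE that the adjoint method integrates backwards in time, assuming a PyTorch dynamics module func and its parameter tensors params (names are hypothetical; a production implementation such as torchdiffeq's also manages the solver, batching, and intermediate states).

```python
import torch

def adjoint_dynamics(func, t, z, a, params):
    """RHS of the augmented state (z(t), a(t) = dL/dz(t), running dL/dθ).

    Vector-Jacobian products give a^T ∂f/∂z and a^T ∂f/∂θ without forming Jacobians.
    """
    with torch.enable_grad():
        z = z.detach().requires_grad_(True)
        f = func(t, z)
        vjp_z, *vjp_params = torch.autograd.grad(
            f, (z, *params), grad_outputs=a, allow_unused=True
        )
    # dz/dt = f,  da/dt = -a^T ∂f/∂z,  d(dL/dθ)/dt = -a^T ∂f/∂θ
    return f, -vjp_z, [(-g if g is not None else None) for g in vjp_params]

# Example use (reusing func and z0 from the earlier sketch):
#   a = torch.ones_like(z0)                        # stand-in for dL/dz(T)
#   dz, da, dp = adjoint_dynamics(func, torch.tensor(1.0), z0, a, tuple(func.parameters()))
```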
A Differentiable Primitive for AutoDiff

Forward:  y = ODESolve(z(t_0), f, t_0, t_1, θ)

Backward: one more call to the ODE solver, run backwards in time on the augmented
(adjoint) system, yields ∂L/∂z(t_0) and ∂L/∂θ.

Don't need to store layer activations for the reverse pass - just follow the
dynamics in reverse!

Reversible networks (Gomez et al. 2018) also only require O(1) memory, but
require very specific neural network architectures with partitioned dimensions.
Reverse versus Forward Cost

- Empirically, the reverse pass is roughly half as expensive as the forward pass.

- Adapts to instance difficulty.

- The number of function evaluations can be viewed as the number of layers in a neural net.

NFE = Number of Function Evaluations.


Dynamics Become Increasingly Complex

- Dynamics become more demanding to compute during training.

- Adapts computation time according to the complexity of the diffeq.

In contrast, Chang et al. (ICLR 2018) explicitly add layers during training.
Continuous-time RNNs for Time Series Modeling
- We often want arbitrary measurement times, i.e., irregular time intervals.
- Can do VAE-style inference with a latent ODE.
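A rough sketch of the latent-ODE decoding step at irregular observation times, assuming torchdiffeq and illustrative module sizes; the full model in the paper adds an RNN encoder and a VAE objective.

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint

class LatentODEFunc(nn.Module):
    """Dynamics of the latent state z(t)."""
    def __init__(self, latent_dim=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, 32), nn.Tanh(), nn.Linear(32, latent_dim))

    def forward(self, t, z):
        return self.net(z)

func = LatentODEFunc()
z0 = torch.randn(16, 4)                              # latent initial state (from the encoder)
t_obs = torch.tensor([0.0, 0.3, 0.35, 1.2, 2.0])     # arbitrary, irregular measurement times
z_t = odeint(func, z0, t_obs)                        # latent trajectory at exactly those times
x_pred = nn.Linear(4, 1)(z_t)                        # decode each latent state into an observation
```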
ODEs vs Recurrent Neural Networks (RNNs)

- RNNs learn very stiff dynamics and have exploding gradients.

- Whereas ODEs are guaranteed to be smooth.
Continuous Normalizing Flows
Instantaneous change of variables (iCOV):

- For a Lipschitz continuous function f:

      d log p(z(t))/dt = -tr(∂f/∂z(t))

- In other words,

      log p(z(t_1)) = log p(z(t_0)) - ∫_{t_0}^{t_1} tr(∂f/∂z(t)) dt

- Compare with the discrete change of variables for an invertible F:

      log p(y) = log p(x) - log |det ∂F/∂x|
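A sketch of the combined dynamics for z(t) and log p(z(t)), computing the trace of the Jacobian exactly with one autograd call per dimension (only feasible for small z; the function name and signature are hypothetical).

```python
import torch

def cnf_dynamics(func, t, z):
    """Returns dz/dt = f(z, t) and d log p(z(t))/dt = -tr(∂f/∂z), per the
    instantaneous change of variables."""
    with torch.enable_grad():
        z = z.detach().requires_grad_(True)
        f = func(t, z)
        trace = torch.zeros(z.shape[0])
        for i in range(z.shape[1]):                  # exact trace: one backward pass per dimension
            trace = trace + torch.autograd.grad(
                f[:, i].sum(), z, create_graph=True
            )[0][:, i]
    return f, -trace
```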
Continuous Normalizing Flows
[Figure: densities learned on 1D and 2D data - Data vs. Discrete-NF vs. CNF.]
Is the ODE being correctly solved?
Stochastic Unbiased Log Density

Can further reduce time complexity using stochastic estimators.

Grathwohl et al. (2019)
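A sketch of a Hutchinson-style stochastic trace estimator of the kind FFJORD uses: tr(A) = E[ε^T A ε] for zero-mean, unit-covariance noise ε (here Rademacher). Names and signatures are illustrative.

```python
import torch

def hutchinson_trace(func, t, z, n_samples=1):
    """Unbiased estimate of tr(∂f/∂z); each sample costs one vector-Jacobian product
    instead of one backward pass per dimension."""
    with torch.enable_grad():
        z = z.detach().requires_grad_(True)
        f = func(t, z)
        est = torch.zeros(z.shape[0])
        for _ in range(n_samples):
            eps = torch.empty_like(z).bernoulli_(0.5) * 2 - 1      # Rademacher ±1 noise
            vjp = torch.autograd.grad(f, z, grad_outputs=eps, retain_graph=True)[0]
            est = est + (vjp * eps).sum(dim=1)                     # ε^T (∂f/∂z) ε
    return est / n_samples
```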


FFJORD - Stochastic Continuous Flows
[Figure: model samples on MNIST and CIFAR10.]

Grathwohl et al. (2019)


Variational Autoencoders with FFJORD
ODE Solving as a Modeling Primitive
Adaptive-step solvers with O(1) memory backprop.

github.com/rtqichen/torchdiffeq

Future directions we’re currently working on:

- Latent Stochastic Differential Equations.


- Network architectures suited for ODEs.
- Regularization of dynamics to require fewer evaluations.
Co-authors:

Yulia Rubanova Jesse Bettencourt David Duvenaud

Thanks!
Extra Slides
Latent Space Visualizations
• Released an implementation of reverse-mode autodiff through black-box ODE solvers.

• Solves a system of size 2D + K + 1.

• In contrast, a forward-mode implementation solves a system of size D^2 + KD.

• TensorFlow has Dormand-Prince-Shampine Runge-Kutta 5(4) implemented, but uses naive autodiff for backpropagation.
How much precision is needed?
Explicit Error Control

- More fine-grained control than low-precision floats.

- Cost scales with instance difficulty.

NFE = Number of Function Evaluations.
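In torchdiffeq this control is exposed as solver tolerances; a sketch reusing func, z0, and t from the earlier example (tolerance values are arbitrary):

```python
from torchdiffeq import odeint

# Tighter tolerances mean more function evaluations (higher NFE) and lower numerical error.
z_coarse = odeint(func, z0, t, rtol=1e-3, atol=1e-4)
z_fine   = odeint(func, z0, t, rtol=1e-7, atol=1e-9)
```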


Computation Depends on Complexity of Dynamics

- Time cost is dominated by evaluation of the dynamics f.

NFE = Number of Function Evaluations.


Why not use an ODE solver as a modeling primitive?
- Solving an ODE is expensive.
Future Directions
- Stochastic differential equations and random ODEs; these can approximate stochastic gradient descent.
- Scaling up ODE solvers with machine learning.
- Partial differential equations.
- Graphics, physics, simulations.
