
Opening the Blackbox: Accelerating Neural Differential Equations by Regularizing Internal Solver Heuristics

Avik Pal 1 2, Yingbo Ma 2, Viral Shah 2, Christopher Rackauckas 2 3 4 5

1 Indian Institute of Technology Kanpur, 2 Julia Computing, 3 Massachusetts Institute of Technology, 4 Pumas AI, 5 University of Maryland Baltimore. Correspondence to: Avik Pal <[email protected]>, Christopher Rackauckas <[email protected]>.

Proceedings of the 38th International Conference on Machine Learning, PMLR 139, 2021. Copyright 2021 by the author(s).

Abstract

Democratization of machine learning requires architectures that automatically adapt to new problems. Neural Differential Equations (NDEs) have emerged as a popular modeling framework by removing the need for ML practitioners to choose the number of layers in a recurrent model. While we can control the computational cost by choosing the number of layers in standard architectures, in NDEs the number of neural network evaluations for a forward pass can depend on the number of steps of the adaptive ODE solver. But, can we force the NDE to learn the version with the least steps while not increasing the training cost? Current strategies to overcome slow prediction require high order automatic differentiation, leading to significantly higher training time. We describe a novel regularization method that uses the internal cost heuristics of adaptive differential equation solvers combined with discrete adjoint sensitivities to guide the training process towards learning NDEs that are easier to solve. This approach opens up the blackbox numerical analysis behind the differential equation solver's algorithm and directly uses its local error estimates and stiffness heuristics as cheap and accurate cost estimates. We incorporate our method without any change in the underlying NDE framework and show that our method extends beyond Ordinary Differential Equations to accommodate Neural Stochastic Differential Equations. We demonstrate how our approach can halve the prediction time and, unlike other methods which can increase the training time by an order of magnitude, we demonstrate similar reduction in training times. Together this showcases how the knowledge embedded within state-of-the-art equation solvers can be used to enhance machine learning.

Figure 1. Training and Prediction Performance of Regularized NDEs. We obtain an average training and prediction speedup of 1.45x and 1.84x respectively for our best model on supervised classification and time series problems.

1. Introduction

How many hidden layers should you choose in your recurrent neural network? Chen et al. (2018) showed that the answer could be found automatically by using a continuous reformulation, the neural ordinary differential equation, and allowing an adaptive ODE solver to effectively choose the number of steps to take. Since then the idea was generalized to other domains such as stochastic differential equations (Liu et al., 2019; Rackauckas et al., 2020b), but one fact remained: solving a neural differential equation is expensive, and training a neural differential equation is even more so. In this manuscript we show a generally applicable method to force the neural differential equation training process to choose the least expensive option. We open the blackbox and show how using the numerical heuristics baked inside of these sophisticated differential equation solver codes allows for identifying the cheapest equations without requiring extra computation.

Our main contributions include:

• We introduce a novel regularization scheme for neural differential equations based on the local error estimates and stiffness estimates. We observe that by white-boxing differential equation solvers to leverage pre-computed statistics about the neural differential equations, we can obtain faster training and prediction time while having a minimal effect on testing metrics.

• We compare our method with various regularization schemes (Kelly et al., 2020; Behl et al., 2020), which often use higher order derivatives and are difficult to incorporate within existing systems. We empirically show that regularization using cheap statistics can lead to as efficient predictions as the ones requiring higher order automatic differentiation (Kelly et al., 2020; Finlay et al., 2020) without the increased training time.

• We release our code1, implemented using the Julia Programming Language (Bezanson et al., 2017) and the SciML Software Suite (Rackauckas et al., 2019), with the intention of wider adoption of the proposed methods in the community.

1 https://github.com/avik-pal/RegNeuralODE.jl

Figure 2. Error and Stiffness Regularization Keeps Accuracy. We show the fits of the unregularized/regularized Neural ODE variants on the Spiral equation. However, the unregularized variant requires 1083.0 ± 57.55 NFEs while the one regularized using the stiffness and error estimates requires only 676.2 ± 68.20 NFEs, reducing prediction time by nearly 50%.

2. Background

2.1. Neural Ordinary Differential Equations

Ordinary Differential Equations (ODEs) are used to model the instantaneous rate of change dz(t)/dt of a state z(t). Initial Value Problems (IVPs) are a class of ODEs that involve finding the state at a later time t1, given the value z0 at time t0. This state,

z(t_1) = z_0 + \int_{t_0}^{t_1} f_\theta(z(t), t) \, dt,

generally cannot be computed analytically and requires numerical solvers. Lu et al. (2018) observed the similarity between fixed time-step discretization of ODEs and Residual Neural Networks (He et al., 2015). Chen et al. (2018) proposed the Neural ODE framework, which uses neural networks to model the ODE dynamics dz(t)/dt = fθ(z(t), t). Using adaptive time stepping allows the model to operate at a variable continuous depth depending on the inputs. Removal of the fixed depth constraint of Residual Networks provides a more expressive framework and offers several advantages in problems like density estimation (Grathwohl et al., 2018), irregularly spaced time series problems (Rubanova et al., 2019), etc.
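To ground the discussion, the following is a minimal sketch of a neural ODE forward pass in Julia, the language used for our experiments. The layer sizes, tolerances, and the `destats` field for the evaluation count are illustrative and depend on the OrdinaryDiffEq/Flux versions in use; this is not the exact experiment setup.

```julia
using OrdinaryDiffEq, Flux

# A small neural network defines the dynamics f_θ(z, t).
fθ = Chain(Dense(2 => 50, tanh), Dense(50 => 2))
dzdt(z, p, t) = fθ(z)

z0   = Float32[2.0, 0.0]
prob = ODEProblem(dzdt, z0, (0.0f0, 1.0f0))
sol  = solve(prob, Tsit5(); abstol = 1e-8, reltol = 1e-8)

# The adaptive solver chooses the steps, so the number of f_θ evaluations (the NFE)
# depends on how hard the learned dynamics are to integrate.
sol.destats.nf
```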
2.2. Neural Stochastic Differential Equations

Stochastic Differential Equations (SDEs) couple the effect of noise to a deterministic system of equations. SDEs are popularly used to model fluctuating stock prices, thermal fluctuations in physical systems, etc. In this paper, we only discuss SDEs with Diagonal Multiplicative Noise, though our method trivially extends to all other forms of SDEs.

Liu et al. (2019) propose an extension to Neural ODEs by stochastic noise injection in the form of Neural SDEs. Neural SDEs jointly train two neural networks fθ and gφ, such that the dynamics are dz(t) = fθ(z(t), t) dt + gφ(z(t), t) dW. Stochastic noise injection regularizes the training of continuous neural models and achieves significantly better robustness and generalization performance.

2.3. Regularizing Neural ODEs for Speed

Given that the map z(0) → z(1) does not uniquely define the dynamics, it is possible to regularize the training process to learn differential equations that can be solved using fewer evaluations of fθ. In the case of continuous normalizing flows (CNF), the ordinary differential equation:

\frac{dz(t)}{dt} = f_\theta(z(t), t)    (1)

\frac{dy(t)}{dt} = -\mathrm{tr}\left(\frac{df_\theta}{dz}\right)    (2)

where y(t) evolves a log-density (Chen et al., 2018). The FFJORD method improves the speed of CNF evaluations by approximating tr(df_\theta/dz) via the Hutchinson trace estimator, i.e. tr(df_\theta/dz) = E[\epsilon^T \frac{df_\theta}{dz} \epsilon] where \epsilon \sim N(0, 1) (Hutchinson, 1989; Grathwohl et al., 2018). Subsequent research showed that this trace estimator could be used to regularize the Frobenius norm of the Jacobian \|\frac{df_\theta}{dz}\|_F using the same \epsilon^T \frac{df_\theta}{dz} product (Finlay et al., 2020). While \epsilon^T \frac{df_\theta}{dz} is computationally expensive as it requires a reverse mode automatic differentiation evaluation in the model (leading to higher order differentiation), in the specific case of FFJORD this term is already required and thus this estimate is a computationally-free regularizer.

It was later shown that this form of regularization can be extended beyond FFJORD by using higher order automatic differentiation (Kelly et al., 2020). This was done by regularizing a heuristic for the local error estimate, namely

R_K(\theta) = \int_{t_0}^{t_f} \left\| \frac{d^K z(t)}{dt^K} \right\|_2^2 dt.

The authors showed Taylor-mode automatic differentiation improves the efficiency of calculating this estimator to an O(k^2) cost where k is the order of derivative that is required, though this still implies that obtaining the 5 derivatives required is a significant computational increase. In fact, the authors noted that "when we train with adaptive solvers we do not improve overall training time", and in fact giving a 1.7x slower training time. In this manuscript we show that this is all the way up to 10x on the PhysioNet challenge problem.

Here we show how to arrive at a similar regularization heuristic that is applicable to all neural ODE applications with suitable adaptive ODE solvers and requires no higher order automatic differentiation. We will show that this form of regularization is able to significantly improve training times and generalizes to other architectures like neural SDEs.
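For concreteness, the Hutchinson-style estimator discussed above can be sketched in a few lines of Julia. The function and variable names here are illustrative rather than any specific library API; each sample costs one reverse-mode vector-Jacobian product.

```julia
using Zygote, LinearAlgebra, Statistics

# Hutchinson-style estimate of tr(df/dz): tr(J) = E[εᵀ J ε] for ε with zero mean
# and identity covariance. Each sample costs one reverse-mode pass (a VJP).
function hutchinson_trace(f, z; nsamples = 8)
    mean(1:nsamples) do _
        ε = randn(eltype(z), length(z))
        _, back = Zygote.pullback(f, z)
        vjp = back(ε)[1]          # εᵀ (df/dz) via one vector-Jacobian product
        dot(vjp, ε)
    end
end

f(z) = [-z[1] + 2z[2], -z[2]^3]   # toy dynamics standing in for f_θ
hutchinson_trace(f, [1.0, 0.5])
```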
2.4. Adaptive Time Stepping using Local Error Estimates

Runge-Kutta methods (Runge, 1895; Kutta, 1901) are widely used for numerically approximating the solutions of ordinary differential equations. They are given by a tableau of coefficients {A, c, b} where the stages s are combined to produce an estimate for the update at t + h:

k_s = f\Big(t + c_s h, \; z(t) + \sum_{i=1}^{s} a_{si} k_i\Big), \qquad z(t + h) = z(t) + h \sum_{i=1}^{s} b_i k_i    (3)

For adaptivity, many Runge-Kutta methods include an alternative linear combiner \tilde{b}_i such that \tilde{z}(t + h) = z(t) + h \sum_{i=1}^{s} \tilde{b}_i k_i gives rise to an alternative solution, typically with one order less convergence (Wanner & Hairer, 1996; Fehlberg, 1968; Dormand & Prince, 1980; Tsitouras, 2011). A classic result from Richardson extrapolation shows that E = \|\tilde{z}(t + h) - z(t + h)\| is an estimate of the local truncation error (Ascher & Petzold, 1998; Hairer et al., 1993). The goal of adaptive step size methods is to choose a maximal step size h for which this error estimate is below user requested error tolerances. Given the absolute tolerance atol and relative tolerance rtol, the solver satisfies the following constraint for determining the time stepping:

E \le atol + \max(|z(t)|, |z(t + h)|) \cdot rtol    (4)

The proportion of the error against the tolerance is thus:

q = \frac{E}{atol + \max(|z_n|, |z_{n+1}|) \cdot rtol}    (5)

If q < 1 then the proposed time step h is accepted, else it is rejected and reduced. In either case, a proportional error control scheme (P-control) proposes h_{new} = \eta q h, while a standard PI-controller of explicit adaptive Runge-Kutta methods can be shown to be equivalent to using:

h_{new} = \eta \, q_{n-1}^{\alpha} \, q_n^{\beta} \, h    (6)

where \eta is the safety factor, q_{n-1} denotes the error proportion of the previous step, and (\alpha, \beta) are the tunable PI gain hyperparameters (Wanner & Hairer, 1996). Similar embedded methods error estimation schemes have also been derived for stochastic Runge-Kutta integrators of SDEs (Rackauckas & Nie, 2017; 2020).
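The accept/reject rule and PI controller above can be written down in a few lines. The following is a minimal sketch with a scalar state and none of the safety clamps or order-dependent exponents that real solvers add; the gain values in the example call are illustrative, and sign conventions for (α, β) differ across implementations.

```julia
# Minimal sketch of Equations (4)-(6): compute the error proportion q, accept the
# step when q < 1, and propose the next step size with a PI controller.
function propose_step(E, z, znew, h, qprev; α, β, abstol = 1e-6, reltol = 1e-3, η = 0.9)
    tol    = abstol + max(abs(z), abs(znew)) * reltol   # tolerance threshold from Equation (4)
    q      = E / tol                                    # Equation (5)
    accept = q < 1
    hnew   = η * qprev^α * q^β * h                      # Equation (6)
    return accept, hnew, q
end

propose_step(1e-7, 1.0, 1.1, 0.05, 0.5; α = 0.1, β = -0.2)   # illustrative gains only
```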
2.5. Stiffness Estimation

While there is no precise definition of stiffness, the definition used in practice is "stiff equations are problems for which explicit methods don't work" (Wanner & Hairer, 1996; Shampine & Gear, 1979). A simplified stiffness index is given by:

S = \max \|Re(\lambda_i)\|    (7)

where \lambda_i are the eigenvalues of the local Jacobian matrix. We note that various measures of stiffness have been introduced over the years, all being variations of conditioning of the pseudospectra (Shampine & Thompson, 2007; Higham & Trefethen, 1993). The difficulty in defining a stiffness metric is that in each case, some stiff systems like the classic Robertson chemical kinetics or excited Van der Pol equation may violate the definition, meaning all such definitions are (useful) heuristics. In particular, it was shown that for explicit Runge-Kutta methods satisfying c_x = c_y for some internal step, the term

\|\lambda\| \approx \frac{\big\| f\big(t + c_x h, \sum_{i=1}^{s} a_{xi} k_i\big) - f\big(t + c_y h, \sum_{i=1}^{s} a_{yi} k_i\big) \big\|}{\big\| \sum_{i=1}^{s} a_{xi} k_i - \sum_{i=1}^{s} a_{yi} k_i \big\|}    (8)

serves as an estimate of S (Shampine, 1977). Since each of these terms is already required in the Runge-Kutta updates of Equation 3, this gives a computationally-free estimate. This estimate is thus found throughout widely used explicit Runge-Kutta implementations, such as by the dopri method (found in suites like SciPy and Octave) to automatically exit when stiffness is detected (Wanner & Hairer, 1996), and by switching methods which automatically change explicit Runge-Kutta methods to methods more suitable for stiff equations (Rackauckas & Nie, 2019).
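As a concrete illustration of Equation (8), the estimate is just a difference quotient between two internal stage evaluations that the solver has already formed; the names below (gx, gy for the two stage arguments) are illustrative rather than any particular solver's internals.

```julia
using LinearAlgebra

# Sketch of the embedded stiffness estimate from Equation (8): for two internal
# stages sharing the same abscissa (c_x = c_y), the ratio below approximates the
# magnitude of the dominant local Jacobian eigenvalue without forming the Jacobian.
function stiffness_estimate(f, t, h, c, gx, gy)
    # gx, gy: the two internal stage arguments already formed during the
    # Runge-Kutta step (combinations of z(t) and the stages k_i).
    norm(f(t + c * h, gx) - f(t + c * h, gy)) / norm(gx - gy)
end

f(t, z) = -50.0 .* z                                            # toy linear dynamics
stiffness_estimate(f, 0.0, 0.01, 0.5, [1.0, 2.0], [1.01, 1.99]) # ≈ 50
```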
3. Method

3.1. Regularizing Local Error and Stiffness Estimates

Section 2.4 describes how larger local error estimates E lead to reduced step sizes and thus a higher overall cost in the neural ODE training and predictions. Given this, we propose regularizing the neural ODE training process by the total local error in order to learn neural ODEs with as large step sizes as possible. Thus we define the regularizing term:

R_E = \sum_j E_j |h_j|    (9)

summing over j the time steps of the solution. This is done by accumulating the E_j from the internals of the time stepping process at the end of each step. We note that this is similar to the regularization proposed in (Kelly et al., 2020), namely:

R_K = \int_{t_0}^{t_1} \left\| \frac{d^K z(t)}{dt^K} \right\| dt    (10)

where integrating over the K th derivatives is proportional to the principal (largest) truncation error term of the Runge-Kutta method (Hairer et al., 1993). However, this formulation requires high order automatic differentiation (which then is layered with reverse-mode automatic differentiation), which can be an expensive computation (Zhang et al., 2008), while Equation 9 requires no differentiation.

Similarly, the stiffness estimates at each step can be summed as:

R_S = \sum_j S_j    (11)

giving a computational heuristic for the total stiffness of the equation. Notably, both of these estimates E_j and S_j are already computed during the course of a standard explicit Runge-Kutta solution, making the forward pass calculation of the regularization term computationally free.
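A minimal sketch of Equations (9) and (11) is given below: the per-step error and stiffness estimates are simply accumulated as the integration proceeds. The `error_estimate` and `stiffness_estimate` field names are placeholders for whatever the solver exposes internally, not a specific library API.

```julia
# Accumulate the regularizers from the solver's own per-step statistics.
function regularizers(step_history)
    R_E = sum(s.error_estimate * abs(s.h) for s in step_history)   # Equation (9)
    R_S = sum(s.stiffness_estimate for s in step_history)          # Equation (11)
    return R_E, R_S
end

# Two recorded steps, with NamedTuples standing in for the solver internals:
steps = [(error_estimate = 1e-7, stiffness_estimate = 12.0, h = 0.05),
         (error_estimate = 3e-7, stiffness_estimate = 15.0, h = 0.08)]
regularizers(steps)
```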
3.2. Adjoints of Internal Solver Estimates

Notice that E_j = \sum_{i=1}^{s} (b_i - \tilde{b}_i) k_i cannot be constructed directly from the z(t_j) trajectory of the ODE's solution. More precisely, the k_i terms are not defined by the continuous ODE but instead by the chosen steps of the solver method. Continuous adjoint methods for neural ODEs (Chen et al., 2018; Zhuang et al., 2021) only define derivatives in terms of the ODE quantities. This is required in order to exploit properties such as allowing different steps in reverse and reversibility for reduced memory, and in constructing solvers requiring fewer NFEs (Kidger et al., 2020). Indeed, computing the adjoint of each stage variable k_i can be done, but this is known as discrete sensitivity analysis and is known to be equivalent to automatic differentiation of the solver (Zhang & Sandu, 2014). Thus to calculate the derivative of the solution simultaneously with the derivatives of the solver states, we used direct automatic differentiation of the differential equation solvers for performing the experiments (Innes, 2018). We note that discrete adjoints are known to be more stable than continuous adjoints (Zhang & Sandu, 2014) and in the context of neural ODEs have been shown to stabilize the training process leading to better fits (Gholami et al., 2019; Onken & Ruthotto, 2020). While more memory intensive than some forms of the continuous adjoint, we note that checkpointing methods can be used to reduce the peak memory (Dauvergne & Hascoët, 2006). We note that this is equivalent to backpropagation of a fixed time step discretization if the step sizes are chosen in advance, and verify in the example code that no additional overhead is introduced.
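The discrete-adjoint viewpoint above amounts to writing the solver as ordinary code and differentiating straight through its steps. The sketch below does this for a fixed-step RK4 loop on a toy linear system; it illustrates the idea and is not the adaptive solvers or AD configuration used in the experiments.

```julia
using Zygote

# A fixed-step RK4 "solver" written as plain Julia code. Reverse-mode AD through
# this loop is exactly backpropagation through the solver's own operations,
# i.e. a discrete adjoint.
function rk4_solve(f, z0, θ, tspan; nsteps = 20)
    t, z = tspan[1], z0
    h = (tspan[2] - tspan[1]) / nsteps
    for _ in 1:nsteps
        k1 = f(z, θ, t)
        k2 = f(z .+ (h / 2) .* k1, θ, t + h / 2)
        k3 = f(z .+ (h / 2) .* k2, θ, t + h / 2)
        k4 = f(z .+ h .* k3, θ, t + h)
        z  = z .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += h
    end
    return z
end

f(z, θ, t) = θ .* z                                   # toy dynamics with parameters θ
loss(θ) = sum(abs2, rk4_solve(f, [1.0, 2.0], θ, (0.0, 1.0)))
Zygote.gradient(loss, [-0.5, 0.3])                    # gradients by differentiating the solver
```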
4. Experiments

In this section, we consider the effectiveness of regularizing Neural Differential Equations (NDEs) on their training and prediction timings. We consider the following baselines while evaluating our models:

1. Vanilla Neural (O/S)DE with discrete sensitivities.

2. STEER: Temporal Regularization for Neural ODE models by stochastic sampling of the end time during training (Behl et al., 2020).

3. TayNODE: Regularizing the K th order derivatives of the Neural ODEs (Kelly et al., 2020)2.

2 We use the original code formulation of the TayNODE in order to ensure usage of the specially-optimized Taylor-mode automatic differentiation technique (Bettencourt et al., 2019) in the training process. Given the large size of the neural networks, most of the compute time lies in optimized BLAS kernels which are the same in both implementations, meaning we do not suspect the library to be a major factor in timing differences beyond the AD specifics.

We test our regularization on four tasks – supervised image classification (Section 4.1.1) and time series interpolation (Section 4.1.2) using Neural ODE, and fitting a Neural SDE (Section 4.2.1) and supervised image classification using Neural SDE (Section 4.2.2). We use DiffEqFlux (Rackauckas et al., 2019) and Flux (Innes et al., 2018) for our experiments.

4.1. Neural Ordinary Differential Equations

In the following experiments, we use a Runge-Kutta 5(4) solver (Tsitouras, 2011) with absolute and relative tolerances of 1.4 × 10−8 to solve the ODEs. To measure the prediction time, we use a test batch size equal to the training batch size.

4.1.1. Supervised Classification

Training Details We train a Neural ODE and a Linear Classifier to map flattened MNIST images to their corresponding labels. Our model uses a two layered neural network fθ1 as the ODE dynamics, followed by a linear classifier gθ2, identical to the architecture used in Kelly et al. (2020).

z_{\theta_1}(x, t) = \tanh(W_1 [x; t] + B_1)    (12)

f_{\theta_1}(x, t) = \tanh(W_2 [z_{\theta_1}(x, t); t] + B_2)    (13)

g_{\theta_2}(x, t) = \sigma(W_3 x + B_3)    (14)

where the parameters W_1 ∈ R^{100×785}, B_1 ∈ R^{100}, W_2 ∈ R^{784×101}, B_2 ∈ R^{784}, W_3 ∈ R^{10×784}, and B_3 ∈ R^{10}. We use a batch size of 512 and train the model for 75 epochs using Momentum (Qian, 1999) with a learning rate of 0.1 and mass of 0.9, and a learning rate inverse decay of 10−5 per iteration. For Error Estimate Regularization, we perform exponential annealing of the regularization coefficient from 100.0 to 10.0 over 75 epochs. For Stiffness Regularization, we use a constant coefficient of 0.0285.
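Under the stated parameter shapes, the architecture of Equations (12)-(14) can be sketched with Flux as below. This is illustrative only: the activation σ in Equation (14) is written generically in the text (softmax is assumed here), and layer constructor syntax varies across Flux versions.

```julia
using Flux

# ODE dynamics f_θ1 from Equations (12)-(13): the state and the scalar time t are
# concatenated before each affine layer.
z_layer = Dense(785 => 100, tanh)    # W1 ∈ R^{100×785}, input [x; t]
f_layer = Dense(101 => 784, tanh)    # W2 ∈ R^{784×101}, input [z_θ1(x, t); t]

function dynamics(x, t)
    z = z_layer(vcat(x, t))          # Equation (12)
    f_layer(vcat(z, t))              # Equation (13)
end

# Linear classifier g_θ2 from Equation (14), applied to the ODE solution.
classifier = Dense(784 => 10)        # W3 ∈ R^{10×784}
g(x) = softmax(classifier(x))        # σ assumed to be softmax over the 10 logits

dynamics(rand(Float32, 784), 0.5f0)  # one evaluation of the ODE right-hand side
```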
Baselines For the STEER baseline, we train the models by stochastically sampling the end time point from U(T − b, T + b) where T = 1.0 and b = 0.53. We observe no training improvement but there is a minor improvement in prediction time. For the TayNODE baseline, we train the model with a reduced batch size of 1004, λ = 3.02 × 10−3, and regularizing 3rd order derivatives.

3 b = 0.25 was also considered but final results were comparable.

4 Batch size was reduced to ensure we reach a comparable train/test accuracy as the other trained models.

Results Figure 3 visualizes the training accuracy and number of function evaluations over training. Table 1 summarizes the metrics from the trained baseline and proposed models – Error Estimate Regularized Neural ODE (ERNODE) and Stiffness Regularized Neural ODE (SRNODE). Additionally, we perform ablation studies by composing various regularization strategies.

Method Train Accuracy (%) Test Accuracy (%) Train Time (hr) Prediction Time (s) NFE
Vanilla NODE 100.0 ± 0.00 97.94 ± 0.02 0.98 ± 0.03 0.094 ± 0.010 253.0 ± 3.46
STEER 100.0 ± 0.00 97.94 ± 0.03 1.31 ± 0.07 0.092 ± 0.002 265.0 ± 3.46
TayNODE 98.98 ± 0.06 97.89 ± 0.00 1.19 ± 0.07 0.079 ± 0.007 080.3 ± 0.43
SRNODE (Ours) 100.0 ± 0.00 98.08 ± 0.15 1.24 ± 0.06 0.094 ± 0.003 259.0 ± 3.46
ERNODE (Ours) 99.71 ± 0.28 97.32 ± 0.06 0.82 ± 0.02 0.060 ± 0.001 177.0 ± 0.00
STEER + SRNODE 100.0 ± 0.00 97.88 ± 0.06 1.55 ± 0.27 0.101 ± 0.009 275.0 ± 12.5
STEER + ERNODE 99.91 ± 0.02 97.61 ± 0.11 1.37 ± 0.11 0.086 ± 0.018 197.0 ± 9.17
SRNODE + ERNODE 99.98 ± 0.03 97.77 ± 0.05 1.37 ± 0.04 0.081 ± 0.006 221.0 ± 17.3

Table 1. MNIST Image Classification using Neural ODE. Using ERNODE obtains a training and prediction speedup of 16.33% and 37.78% respectively, at only 0.6% reduced prediction accuracy. SRNODE doesn't help in isolation but is effective when combined with ERNODE to reduce the prediction time by 14.44% while incurring a reduced test accuracy of only 0.17%.

Figure 3. Number of Function Evaluations and Training Accuracy for Supervised MNIST Classification. Regularizing using ERNODE is the most consistent way to reduce the overall number of function evaluations. Using SRNODE alongside ERNODE stabilizes the training at the cost of increased prediction time.

4.1.2. Time Series Interpolation

Training Details We use the Latent ODE (Chen et al., 2018) model with an RNN encoder to learn the trajectories of ICU patients from the Physionet Challenge 2012 Dataset (Silva et al., 2012). We use the preprocessed data provided by Kelly et al. (2020) to ensure consistency in results. For every independent run, we perform an 80 : 20 split of the data for training and evaluation.

Our model architecture is similar to the encoder-decoder models used in Rubanova et al. (2019). We use a 20-dimensional latent state and a 40-dimensional hidden state for the recognition model. Our ODE dynamics is given by a 4-layered neural network with 50 units and tanh activation. We train our models for 300 epochs with a batch size of 512 using Adamax (Kingma & Ba, 2014) with a learning rate of 0.01 and an inverse decay of 10−5. We minimize the negative log likelihood of the predictions and perform KL annealing with a coefficient of 0.99.

For Error Estimate Regularization, we perform exponential annealing of the regularization coefficient from 1000.0 to 100.0 over 300 epochs. We note that using R_E = \sum_j E_j^2, instead of R_E = \sum_j E_j |h_j|, yields similar results with a constant regularization coefficient of 100.0. For Stiffness Regularization, we use a constant coefficient of 0.285.
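For reference, one simple way to realize the exponential annealing described above is a geometric interpolation of the coefficient between its start and end values; the exact schedule used in the experiments is not specified beyond "exponential annealing", so the form below is an assumption.

```julia
# Geometric decay of the regularization coefficient from λ_start to λ_end over `epochs`.
anneal(λ_start, λ_end, epoch, epochs) = λ_start * (λ_end / λ_start)^(epoch / epochs)

anneal(1000.0, 100.0, 150, 300)   # ≈ 316.2 at the halfway point
```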
Baselines For the STEER baseline, we stochastically sample the timestep to evaluate the difference between interpolated and ground truth data. Essentially, for the interval (t_i, t_{i+1}), we evaluate the model at U(t_{i+1} − (t_{i+1} − t_i)/2, t_{i+1} + (t_{i+1} − t_i)/2) and compare with the truth at t_{i+1}. We sample end points after every iteration of the model. STEER reduces the training time but has no significant effect on the prediction time. TayNODE was trained by regularizing the 2nd order derivatives with a coefficient of 0.01 for 300 epochs and a batch size of 512. TayNODE had an exceptionally high training time, ∼ 7× compared to the unregularized baseline.

Results Figure 4 shows the training MSE loss and the NFE counts for the considered models. Table 2 summarizes the metrics and wall clock timings for the baselines, proposed regularizers and their compositions with previously proposed regularizers. We observe that SRNODE provides the most significant speedup while ERNODE attains similar losses at slightly higher training and prediction times.

Method Train Loss (×10−3) Test Loss (×10−3) Train Time (hr) Prediction Time (s) NFE
Vanilla NODE 3.48 ± 0.00 3.55 ± 0.00 1.75 ± 0.39 0.53 ± 0.12 733.0 ± 84.29
STEER 3.43 ± 0.02 3.48 ± 0.01 1.62 ± 0.26 0.54 ± 0.06 699.0 ± 141.1
TayNODE 4.21 ± 0.02 4.21 ± 0.01 12.3 ± 0.32 0.22 ± 0.02 167.3 ± 11.93
SRNODE (Ours) 3.52 ± 1.44 3.58 ± 0.05 0.87 ± 0.09 0.20 ± 0.01 273.0 ± 0.000
ERNODE (Ours) 3.51 ± 0.00 3.57 ± 0.00 0.94 ± 0.13 0.21 ± 0.02 287.0 ± 17.32
STEER + SRNODE 3.67 ± 0.02 3.73 ± 0.02 0.89 ± 0.08 0.20 ± 0.01 271.0 ± 12.49
STEER + ERNODE 3.41 ± 0.02 3.48 ± 0.01 1.03 ± 0.25 0.24 ± 0.05 269.0 ± 33.05
SRNODE + ERNODE 3.48 ± 0.11 3.56 ± 0.03 1.12 ± 0.08 0.21 ± 0.01 263.0 ± 12.49

Table 2. Physionet Time Series Interpolation. All the regularized variants of Latent ODE (except STEER) have comparable prediction times. Additionally, the training time is reduced by 36% − 50% on using one of our proposed regularizers, while TayNODE increases the training time by 7x. Overall, SRNODE has the best training and prediction timings while incurring an increased 0.85% test loss.

Figure 4. Number of Function Evaluations and Training Loss for Physionet Time Series Interpolation. Regularized and unregularized variants of the model have very similar trajectories for the training loss. We do notice a significant difference in the NFE plot. Using either Error Estimate Regularization or Stiffness Regularization is able to bound the NFE to < 300, compared to ∼ 700 for STEER or the unregularized Latent ODE.

4.2. Neural Stochastic Differential Equations

In these experiments, we use SOSRI/SOSRI2 (Rackauckas & Nie, 2020) to solve the Neural SDEs. The wall clock timings represent runs on a CPU.

4.2.1. Fitting the Spiral Differential Equation

Training Details In this experiment, we consider training a Neural SDE to mimic the dynamics of the Spiral Stochastic Differential Equation with Diagonal Noise (DSDE). The Spiral DSDE is prescribed by the following equations:

du_1 = -\alpha u_1^3 \, dt + \beta u_2^3 \, dt + \gamma u_1 \, dW, \qquad du_2 = -\beta u_1^3 \, dt - \alpha u_2^3 \, dt + \gamma u_2 \, dW    (15)

where \alpha = 0.1, \beta = 2.0, and \gamma = 0.2. We generate data across 10000 trajectories at 30 uniformly spaced points between t ∈ [0, 1] (Figure 5). We parameterize our drift and diffusion functions using neural networks f_\theta and g_\phi via:

f_\theta(x, t) = W_2 \tanh(W_1 x^3 + B_1) + B_2, \qquad g_\phi(x, t) = W_3 x + B_3    (16)

where the parameters W_1 ∈ R^{50×2}, B_1 ∈ R^{50}, W_2 ∈ R^{2×50}, B_2 ∈ R^2, W_3 ∈ R^{2×2}, and B_3 ∈ R^2. For fitting the drift and diffusion functions to the simulated data, we used a generalized method of moments loss function (Lück & Wolf, 2016; Jeisman, 2006). Our objective is to train these parameters to minimize the L2 distance between the mean (\mu) and variance (\sigma^2) of predicted and real data. Let \hat{\mu}_i and \hat{\sigma}_i^2 denote the means and variances respectively of the multiple predicted trajectories.

L(u_0; \theta, \phi) = \sum_{i=1}^{30} \left[ (\mu_i - \hat{\mu}_i)^2 + (\sigma_i^2 - \hat{\sigma}_i^2)^2 \right] + \lambda_r R_E    (17)
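To make the setup concrete, the sketch below simulates the Spiral DSDE of Equation (15) with a diagonal-noise SDEProblem and computes the per-time-point means and variances that enter the loss in Equation (17). Solver choice, trajectory count, and variable names here are illustrative, not the exact experiment configuration.

```julia
using DifferentialEquations, Statistics

α, β, γ = 0.1, 2.0, 0.2
drift(u, p, t)     = [-α * u[1]^3 + β * u[2]^3, -β * u[1]^3 - α * u[2]^3]
diffusion(u, p, t) = [γ * u[1], γ * u[2]]            # diagonal multiplicative noise

prob = SDEProblem(drift, diffusion, [1.0, 1.0], (0.0, 1.0))
sols = solve(EnsembleProblem(prob), SOSRI();
             trajectories = 100, saveat = range(0, 1; length = 30))

# Moments entering Equation (17): mean and variance over trajectories at each
# saved time point (assumed array layout: state × time × trajectory).
data  = Array(sols)
means = dropdims(mean(data; dims = 3); dims = 3)
vars  = dropdims(var(data; dims = 3); dims = 3)
# The loss then sums (μ - μ̂)² + (σ² - σ̂²)² over the 30 time points, plus λ_r * R_E.
```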
The models were trained using the AdaBelief optimizer (Zhuang et al., 2020) with a learning rate of 0.01 for 250 iterations. We generate 100 trajectories for each iteration to compute the \hat{\mu}_i and \hat{\sigma}_i^2 values.

Results Table 3 summarizes the final results for the trained models for 3 different random seeds. We notice that even for this "toy" problem, we can marginally improve training time while incurring a minimal penalty on the final loss.

Method Mean Squared Loss Train Time (s) Prediction Time (s) NFE
Vanilla NSDE 0.0217 ± 0.0088 178.95 ± 20.22 0.07553 ± 0.0186 528.67 ± 6.11
SRNSDE (Ours) 0.0204 ± 0.0091 166.42 ± 14.51 0.07250 ± 0.0017 502.00 ± 4.00
ERNSDE (Ours) 0.0227 ± 0.0090 173.43 ± 04.18 0.07552 ± 0.0008 502.00 ± 4.00

Table 3. Spiral SDE. The ERNSDE attains a relative loss of 4% compared to the vanilla Neural SDE while reducing the training time and number of function evaluations. Using SRNSDE reduces both the training and prediction times by 7% and 4% respectively.

Figure 5. Fitting a Neural SDE on Spiral SDE Data. Regularizing has minimal effect on the learned dynamics with reduced training and prediction cost.

4.2.2. Supervised Classification

Training Details We train a Neural SDE model to map flattened MNIST images to their corresponding labels. Our diffusion function uses a two layered neural network fθ2 and the drift function is a linear map gθ3. We use two additional linear maps – aθ1 mapping the flattened image to the hidden dimension and bθ4 mapping the output of the Neural SDE to the logits.

a_{\theta_1}(x, t) = W_1 x + B_1    (18)

f_{\theta_2}(x, t) = W_3 \tanh(W_2 x + B_2) + B_3    (19)

g_{\theta_3}(x, t) = W_4 x + B_4    (20)

b_{\theta_4}(x, t) = W_5 x + B_5    (21)

where the parameters W_1 ∈ R^{32×784}, B_1 ∈ R^{32}, W_2 ∈ R^{32×64}, B_2 ∈ R^{64}, W_3 ∈ R^{32×64}, B_3 ∈ R^{32}, W_4 ∈ R^{10×32}, and B_5 ∈ R^{10}. We use a batch size of 512 and train the model for 40 epochs using Adam (Kingma & Ba, 2014) with a learning rate of 0.01 and an inverse decay of 10−5 per iteration. While making predictions we use the mean logits across 10 trajectories. For Error Estimate and Stiffness Regularization, we use constant coefficients of 10.0 and 0.1 respectively.

Results Figure 6 shows the variation in NFE and Training Error during training. Table 4 summarizes the final metrics and timings for all the trained models. We observe that SRNSDE doesn't improve the training/prediction time, similar to the MNIST Neural ODE experiment in Section 4.1.1. However, ERNSDE gives us a training and prediction speedup of 33.7% and 52.02% respectively, at the cost of 0.7% reduced test accuracy.

Figure 6. Number of Function Evaluations and Training Error for Supervised MNIST Classification using Neural SDE. ERNSDE reduces the NFE below 300 with minimal error change while the unregularized version has NFE ∼ 400.

Method Train Accuracy (%) Test Accuracy (%) Train Time (hr) Prediction Time (s) NFE
Vanilla NSDE 98.97 ± 0.11 96.95 ± 0.11 6.32 ± 0.19 15.07 ± 0.93 411.33 ± 6.11
SRNSDE (Ours) 98.79 ± 0.12 96.80 ± 0.07 8.54 ± 0.37 14.50 ± 0.40 382.00 ± 4.00
ERNSDE (Ours) 98.16 ± 0.11 96.27 ± 0.35 4.19 ± 0.04 07.23 ± 0.14 184.67 ± 2.31

Table 4. MNIST Image Classification using Neural SDE. ERNSDE obtains a training and prediction speedup of 33.7% and 52.02% respectively, at only 0.7% reduced prediction accuracy.
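As a small illustration of the mean-logits prediction rule used above, logits from several stochastic forward passes are averaged before taking the argmax; `model` below is a random stand-in for the trained a_θ1 → Neural SDE → b_θ4 pipeline.

```julia
using Statistics

model(x) = randn(10)                                     # placeholder for the stochastic pipeline's logits
predict(x; ntraj = 10) = argmax(mean(model(x) for _ in 1:ntraj))

predict(rand(784))                                       # predicted class index in 1:10
```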
5. Discussion

Numerical analysis has had over a century of theoretical developments leading to efficient adaptive methods for solving many common nonlinear equations such as differential equations. Here we demonstrate that by using the knowledge embedded within the heuristics of these methods we can accelerate the training process of neural ODEs.

We note that on the larger sized PhysioNet and MNIST examples we saw significant speedups while on the smaller differential equation examples we saw only minor performance improvements. This showcases how the NFE becomes a better estimate of the total compute time as the cost of the ODE f (and SDE g) increases when the model size increases.

This result motivates efforts in differentiable programming (Wang et al., 2018; Abadi & Plotkin, 2019; Rackauckas et al., 2020a) which enable direct differentiation of solvers, since utilizing the solver's heuristics may be crucial in the development of advanced techniques. This idea could be straightforwardly extended not only to other forms of differential equations, but also to other "implicit layer" machine learning methods. For example, Deep Equilibrium Models (DEQ) (Bai et al., 2019) model the system as the solution to an implicit function via a nonlinear solver like Broyden's or Newton's method. Heuristics like the ratio of the residuals have commonly been used as a convergence criterion and as a work estimate for the difficulty of solving a particular nonlinear equation (Wanner & Hairer, 1996), and thus could similarly be used to regularize for learning DEQs whose forward passes are faster to solve. Similarly, optimization techniques such as BFGS (Kelley, 1999) contain internal estimates of the Hessian which can be used to regularize the stiffness of "optimization as layers" machine learning architectures like OptNet (Amos & Kolter, 2017). However, in these cases we note that continuous adjoint techniques have a significant computational advantage over discrete adjoint methods because the continuous adjoint method can be computed directly at the point of the solution while discrete adjoints would require differentiating through the iteration process. Thus, while a similar regularization would exist in these contexts, in the case of differential equations the continuous and discrete adjoints share the same computational complexity, which is not the case in methods which iterate to convergence. Further study of these applications would be required in order to ascertain their effectiveness in accelerating the training process, though by extrapolation one may guess that at least the forward pass would be accelerated.

6. Limitations

While these experiments have demonstrated major performance improvements, it is pertinent to point out the limitations of the method. One major point to note is that this only applies to learning neural ODEs for maps z(0) ↦ z(1) as is used in machine learning applications of the architecture (Chen et al., 2018). Indeed, a neural ODE as an "implicit layer" for predictions in machine learning does not require identification of dynamical mechanisms. However, if the purpose is to learn the true governing dynamics of a physical system from time series data, this form of regularization would bias the result, dampening higher frequency responses and leading to an incorrect system identification. Approaches which embed neural networks into solvers could be used in such cases (Shen et al., 2020; Poli et al., 2020). Indeed, we note that such HyperEuler approaches could be combined with the ERNODE regularization on machine learning prediction problems, which could be a fruitful avenue of research. Lastly, we note that while either the local error or stiffness regularization was effective on each chosen equation, neither was effective on all equations, and at this time there does not seem to be a clear a priori indicator as to which regularization is necessary for a given problem. While it seems the error regularization was more effective on the image classification tasks while the stiffness regularization was more effective on the time series task, we believe more experiments will be required in order to ascertain whether this is a common phenomenon, possibly worthy of theoretical investigation.

7. Conclusion

Our studies reveal that error estimate regularization provides a consistent way to improve the training/prediction time of neural differential equations. In our experiments, we see an average improvement of 1.4x training time and 1.8x prediction time on using error estimate regularization. Overall we provide conclusive evidence that cheap and accurate cost estimates obtained by white-boxing differential equation solvers can be as effective as expensive higher-order regularization strategies. Together these results demonstrate a generalizable idea for how to combine differentiable programming with algorithm heuristics to improve training speeds in a way that cannot be done with continuous adjoint techniques. Thus, even if a derivative can be defined for a given piece of code, our approach shows that differentiating the solver can still have major advantages because of the solver's internal details in terms of stability and performance.
8. Acknowledgements

This material is based upon work supported by the National Science Foundation under grant no. OAC-1835443, grant no. SII-2029670, grant no. ECCS-2029670, grant no. OAC-2103804, and grant no. PHY-2021825. We also gratefully acknowledge the U.S. Agency for International Development through Penn State for grant no. S002283-USAID. The information, data, or work presented herein was funded in part by the Advanced Research Projects Agency-Energy (ARPA-E), U.S. Department of Energy, under Award Number DE-AR0001211 and DE-AR0001222. The views and opinions of authors expressed herein do not necessarily state or reflect those of the United States Government or any agency thereof. This material was supported by The Research Council of Norway and Equinor ASA through Research Council project "308817 - Digital wells for optimal production and drainage". Research was sponsored by the United States Air Force Research Laboratory and the United States Air Force Artificial Intelligence Accelerator and was accomplished under Cooperative Agreement Number FA8750-19-2-1000. The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of the United States Air Force or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation herein.

References

Abadi, M. and Plotkin, G. D. A simple differentiable programming language. Proceedings of the ACM on Programming Languages, 4(POPL):1–28, 2019.

Amos, B. and Kolter, J. Z. OptNet: Differentiable optimization as a layer in neural networks. In International Conference on Machine Learning, pp. 136–145. PMLR, 2017.

Ascher, U. M. and Petzold, L. R. Computer methods for ordinary differential equations and differential-algebraic equations, volume 61. SIAM, 1998.

Bai, S., Kolter, J. Z., and Koltun, V. Deep equilibrium models. arXiv preprint arXiv:1909.01377, 2019.

Behl, H., Ghosh, A., Dupont, E., Torr, P., and Namboodiri, V. STEER: Simple temporal regularization for neural ODEs. pp. 1–13. Neural Information Processing Systems Foundation, Inc., 2020.

Bettencourt, J., Johnson, M. J., and Duvenaud, D. Taylor-mode automatic differentiation for higher-order derivatives in JAX. 2019.

Bezanson, J., Edelman, A., Karpinski, S., and Shah, V. B. Julia: A fresh approach to numerical computing. SIAM Review, 59(1):65–98, 2017. doi: 10.1137/141000671.

Chen, R. T., Rubanova, Y., Bettencourt, J., and Duvenaud, D. K. Neural ordinary differential equations. In Advances in Neural Information Processing Systems, pp. 6571–6583, 2018.

Dauvergne, B. and Hascoët, L. The data-flow equations of checkpointing in reverse automatic differentiation. In International Conference on Computational Science, pp. 566–573. Springer, 2006.

Dormand, J. R. and Prince, P. J. A family of embedded Runge-Kutta formulae. Journal of Computational and Applied Mathematics, 6(1):19–26, 1980.

Fehlberg, E. Classical fifth-, sixth-, seventh-, and eighth-order Runge-Kutta formulas with stepsize control. National Aeronautics and Space Administration, 1968.

Finlay, C., Jacobsen, J.-H., Nurbekyan, L., and Oberman, A. M. How to train your neural ODE. arXiv preprint arXiv:2002.02798, 2020.

Gholami, A., Keutzer, K., and Biros, G. ANODE: Unconditionally accurate memory-efficient gradients for neural ODEs. arXiv preprint arXiv:1902.10298, 2019.

Grathwohl, W., Chen, R. T., Bettencourt, J., Sutskever, I., and Duvenaud, D. FFJORD: Free-form continuous dynamics for scalable reversible generative models. arXiv preprint arXiv:1810.01367, 2018.

Hairer, E., Norsett, S., and Wanner, G. Solving Ordinary Differential Equations I: Nonstiff Problems, volume 8. 1993. ISBN 978-3-540-56670-0. doi: 10.1007/978-3-540-78862-1.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition, 2015.

Higham, D. J. and Trefethen, L. N. Stiffness of ODEs. BIT Numerical Mathematics, 33(2):285–303, 1993.

Hutchinson, M. F. A stochastic estimator of the trace of the influence matrix for Laplacian smoothing splines. Communications in Statistics - Simulation and Computation, 18(3):1059–1076, 1989.

Innes, M. Don't unroll adjoint: Differentiating SSA-form programs. arXiv preprint arXiv:1810.07951, 2018.

Innes, M., Saba, E., Fischer, K., Gandhi, D., Rudilosso, M. C., Joy, N. M., Karmali, T., Pal, A., and Shah, V. Fashionable modelling with Flux. CoRR, abs/1811.01457, 2018. URL http://arxiv.org/abs/1811.01457.

Jeisman, J. I. Estimation of the parameters of stochastic differential equations. PhD thesis, Queensland University of Technology, 2006.
Kelley, C. T. Iterative methods for optimization. SIAM, 1999.

Kelly, J., Bettencourt, J., Johnson, M. J., and Duvenaud, D. Learning differential equations that are easy to solve. arXiv preprint arXiv:2007.04504, 2020.

Kidger, P., Chen, R. T. Q., and Lyons, T. "Hey, that's not an ODE": Faster ODE adjoints with 12 lines of code, 2020.

Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization, 2014. URL http://arxiv.org/abs/1412.6980. Published as a conference paper at the 3rd International Conference for Learning Representations, San Diego, 2015.

Kutta, W. Beitrag zur näherungsweisen Integration totaler Differentialgleichungen. Z. Math. Phys., 46:435–453, 1901.

Liu, X., Xiao, T., Si, S., Cao, Q., Kumar, S., and Hsieh, C.-J. Neural SDE: Stabilizing neural ODE networks with stochastic noise, 2019.

Lu, Y., Zhong, A., Li, Q., and Dong, B. Beyond finite layer neural networks: Bridging deep architectures and numerical differential equations. In International Conference on Machine Learning, pp. 3276–3285. PMLR, 2018.

Lück, A. and Wolf, V. Generalized method of moments for estimating parameters of stochastic reaction networks. BMC Systems Biology, 10(1):1–12, 2016.

Onken, D. and Ruthotto, L. Discretize-optimize vs. optimize-discretize for time-series regression and continuous normalizing flows. arXiv preprint arXiv:2005.13420, 2020.

Poli, M., Massaroli, S., Yamashita, A., Asama, H., and Park, J. Hypersolvers: Toward fast continuous-depth models. arXiv preprint arXiv:2007.09601, 2020.

Qian, N. On the momentum term in gradient descent learning algorithms. Neural Networks, 12(1):145–151, 1999.

Rackauckas, C. and Nie, Q. Adaptive methods for stochastic differential equations via natural embeddings and rejection sampling with memory. Discrete and Continuous Dynamical Systems. Series B, 22(7):2731, 2017.

Rackauckas, C. and Nie, Q. Confederated modular differential equation APIs for accelerated algorithm development and benchmarking. Advances in Engineering Software, 132:1–6, 2019.

Rackauckas, C. and Nie, Q. Stability-optimized high order methods and stiffness detection for pathwise stiff stochastic differential equations. In 2020 IEEE High Performance Extreme Computing Conference (HPEC), pp. 1–8. IEEE, 2020.

Rackauckas, C., Innes, M., Ma, Y., Bettencourt, J., White, L., and Dixit, V. DiffEqFlux.jl - A Julia library for neural differential equations. CoRR, abs/1902.02376, 2019. URL http://arxiv.org/abs/1902.02376.

Rackauckas, C., Edelman, A., Fischer, K., Innes, M., Saba, E., Shah, V. B., and Tebbutt, W. Generalized physics-informed learning through language-wide differentiable programming. 2020a.

Rackauckas, C., Ma, Y., Martensen, J., Warner, C., Zubov, K., Supekar, R., Skinner, D., Ramadhan, A., and Edelman, A. Universal differential equations for scientific machine learning, 2020b.

Rubanova, Y., Chen, R. T., and Duvenaud, D. Latent ODEs for irregularly-sampled time series. arXiv preprint arXiv:1907.03907, 2019.

Runge, C. Über die numerische Auflösung von Differentialgleichungen. Mathematische Annalen, 46(2):167–178, 1895.

Shampine, L. F. Stiffness and nonstiff differential equation solvers, II: Detecting stiffness with Runge-Kutta methods. ACM Trans. Math. Softw., 3(1):44–53, March 1977. ISSN 0098-3500. doi: 10.1145/355719.355722. URL https://doi.org/10.1145/355719.355722.

Shampine, L. F. and Gear, C. W. A user's view of solving stiff ordinary differential equations. SIAM Review, 21(1):1–17, 1979.

Shampine, L. F. and Thompson, S. Stiff systems. Scholarpedia, 2(3):2855, 2007.

Shen, X., Cheng, X., and Liang, K. Deep Euler method: Solving ODEs by approximating the local truncation error of the Euler method. arXiv preprint arXiv:2003.09573, 2020.

Silva, I., Moody, G., Scott, D. J., Celi, L. A., and Mark, R. G. Predicting in-hospital mortality of ICU patients: The PhysioNet/Computing in Cardiology Challenge 2012. In 2012 Computing in Cardiology, pp. 245–248. IEEE, 2012.

Tsitouras, C. Runge-Kutta pairs of order 5(4) satisfying only the first column simplifying assumption. Computers & Mathematics with Applications, 62(2):770–775, 2011. ISSN 0898-1221. doi: 10.1016/j.camwa.2011.06.002. URL http://www.sciencedirect.com/science/article/pii/S0898122111004706.
Wang, F., Decker, J., Wu, X., Essertel, G., and Rompf, T. Backpropagation with callbacks: Foundations for efficient and expressive differentiable programming. Advances in Neural Information Processing Systems, 31:10180–10191, 2018.

Wanner, G. and Hairer, E. Solving Ordinary Differential Equations II. Springer Berlin Heidelberg, 1996.

Zhang, H. and Sandu, A. FATODE: A library for forward, adjoint, and tangent linear integration of ODEs. SIAM Journal on Scientific Computing, 36(5):C504–C523, 2014.

Zhang, H., Xue, Y., Zhang, C., and Dong, L. Computing the high order derivatives with automatic differentiation and its application in Chebyshev's method. In 2008 Fourth International Conference on Natural Computation, volume 1, pp. 304–308. IEEE, 2008.

Zhuang, J., Tang, T., Ding, Y., Tatikonda, S., Dvornek, N., Papademetris, X., and Duncan, J. AdaBelief optimizer: Adapting stepsizes by the belief in observed gradients. Conference on Neural Information Processing Systems, 2020.

Zhuang, J., Dvornek, N. C., Tatikonda, S., and Duncan, J. MALI: A memory efficient and reverse accurate integrator for neural ODEs. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=blfSjHeFM_e.
