Training Neural ODEs Using Fully Discretized Simultaneous Optimization
and use IPOPT—a solver for large-scale nonlinear optimization—to simultaneously optimize
collocation coefficients and neural network parameters. Using the Van der Pol Oscillator as
a case study, we demonstrate faster convergence compared to traditional training methods.
Furthermore, we introduce a decomposition framework utilizing Alternating Direction Method
of Multipliers (ADMM) to effectively coordinate sub-models among data batches. Our results
show significant potential for (collocation-based) simultaneous Neural ODE training pipelines.
scheduling and control applications owing to their flexibility and representation power, e.g., as scale-bridging models (Tsay and Baldea, 2019). Neural networks can have various architectures and consist of learnable weights and biases that are optimized during training by minimizing a specified loss function. This is commonly achieved using gradient-based algorithms, which iteratively update the model parameters by taking steps in the negative direction of the gradient of the loss function (Sarker, 2021).

Neural Ordinary Differential Equations (Neural ODEs) (Chen et al., 2018) bridge neural networks with dynamical systems modeling, leveraging existing knowledge of ODEs. These models extend traditional networks to model unknown continuous-time dynamics by parameterizing the evolution of system states as a differential equation:

    dx(t)/dt = fθ(x(t), t),    (1)

where x(t) represents the state at time t, and fθ is the neural network parameterized by θ. Compared to more standard recurrent or convolutional neural networks, Neural ODEs are flexible and can incorporate arbitrary time spacings. The framework has found numerous applications, including process control (Luo et al., 2023), reaction modeling (Sorourifar et al., 2023), and parameter estimation (Bradley and Boukouvala, 2021; Dua and Dua, 2012).

⋆ Support from a BASF/Royal Academy of Engineering Senior Research Fellowship is gratefully acknowledged.

Predictions from the model (1) are obtained by solving the corresponding initial value problem (IVP):

    x(T) = x(t0) + ∫_{t0}^{T} fθ(x(t), t) dt = ODESolve(x(t0), fθ, t0, T),    (2)

where ODESolve is a numerical IVP solver, and t0, T are the beginning and end of the integration interval, respectively. Training is typically based on the accuracy of the predictions x(T), requiring the numerical solution of (2) and backpropagation of gradients through the IVP solver at every iteration. These requirements lead to long training times of Neural ODEs (Lehtimäki et al., 2024).

Given the above, this work uses spectral numerical methods, specifically collocation, for the time integration of differential equations in Neural ODE training. Spectral methods offer several advantages: they are global as they approximate over the entire domain, display exponential convergence for smooth problems, and have better accuracy with a small number of points (Boyd, 2000). Spectral methods remain less explored than sequential methods in the context of Neural ODEs and have to date mostly been limited to approximating derivative targets (Roesch et al., 2021) or to non-simultaneous training (Quaglino et al., 2020). The novelty of this paper is that we show collocation can be employed in a simultaneous optimization approach, i.e., the system dynamics are solved as equality constraints rather than by iterative simulation, for fast and stable Neural ODE training. Furthermore, we show that the proposed method may produce more parsimonious models and is amenable to batching via ADMM.
2. NEURAL ODES FOR TIME SERIES

Neural ODEs can be applied in various contexts, e.g., in generative modeling or as implicit layers in larger models. We focus on the basic, control-relevant setting of modeling time-series data comprising observations {yi, ti}_{i=0}^{N−1}, where yi ∈ R^d is the data vector at time ti, and T = t_{N−1} is the end of the time (and integration) interval. Our objective is to learn a parametric ODE model that, when integrated from an initial condition, results in a continuous trajectory y(t) approximating the observed data:

    y′(t) = fθ(y(t), t),  y(t0) = y0,  t ≥ t0,    (3)

where y′(t) is the time derivative of the system state y(t) and fθ is a neural network parameterized by θ. A trained Neural ODE model may also predict beyond the observed data interval [t0, T], provided the ODE solution remains valid and fθ satisfies conditions such as Lipschitz continuity.

During training, the solution Ŷ is computed by numerically solving the IVP:

    Ŷ = ODESolve(fθ, y0, t) ∈ R^{N×d},    (4)

where t = {ti}_{i=0}^{N−1} is the vector of time points matching the observations. The model parameters, θ, are learned by minimizing a loss function that captures the discrepancy between the predictions and observed data. The mean squared error (MSE) loss function is a common choice for continuous-output regression:

    Lθ(Ŷ, Y) = (1/N) ∥Ŷ − Y∥_F²,    (5)

where Y ∈ R^{N×d} is the matrix of observed values and ∥·∥_F is the Frobenius norm. In summary, the goal of Neural ODE training can be formulated as computing θ such that (3) holds while minimizing (5).
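To make the notation concrete, the following minimal NumPy sketch shows one possible fθ (a single-hidden-layer tanh network taking the state and time as inputs) together with the loss (5). The architecture, sizes, and names are illustrative assumptions, not the configuration used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d, hidden = 2, 16   # illustrative state dimension and hidden width

# theta: weights and biases of a one-hidden-layer tanh network
theta = {
    "W1": 0.1 * rng.standard_normal((hidden, d + 1)),  # +1 input for time t
    "b1": np.zeros(hidden),
    "W2": 0.1 * rng.standard_normal((d, hidden)),
    "b2": np.zeros(d),
}

def f_theta(y, t, theta):
    """Neural network right-hand side f_theta(y(t), t) from Eq. (3)."""
    z = np.concatenate([y, [t]])
    h = np.tanh(theta["W1"] @ z + theta["b1"])
    return theta["W2"] @ h + theta["b2"]

def mse_loss(Y_hat, Y):
    """Eq. (5): (1/N) times the squared Frobenius norm of the prediction error."""
    return np.sum((Y_hat - Y) ** 2) / Y.shape[0]
```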
2.1 Sequential ODE Solvers

In a typical Neural ODE training pipeline, we ensure (3) holds by solving the ODE system in every iteration. In this sequential approach, the solver computes Ŷ iteratively, e.g., by time stepping, until the end of the interval is reached at T. The simplest example is Euler's method, while more commonly used schemes are the Runge-Kutta methods, used as default solvers in torchdiffeq within PyTorch and Diffrax within JAX. We can generalize a step of a sequential numerical scheme as:

    y(t + h) = y(t) + h · Φ(f, y(t), t, h),

where h is the step size, f is the derivative function, and Φ denotes the (typically explicit) function that approximates the change in y over the interval from t to t + h.
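For concreteness, a minimal NumPy sketch of the step map Φ, instantiated as explicit Euler and classical RK4 and rolled out over the observation times to form Ŷ as in (4). The toy right-hand side f stands in for the neural network fθ; all names are illustrative.

```python
import numpy as np

def f(y, t):
    # toy right-hand side standing in for the neural network f_theta(y, t)
    return np.array([y[1], -y[0]])

def phi_euler(f, y, t, h):
    # explicit Euler: Phi(f, y, t, h) = f(y, t)
    return f(y, t)

def phi_rk4(f, y, t, h):
    # classical fourth-order Runge-Kutta increment function
    k1 = f(y, t)
    k2 = f(y + 0.5 * h * k1, t + 0.5 * h)
    k3 = f(y + 0.5 * h * k2, t + 0.5 * h)
    k4 = f(y + h * k3, t + h)
    return (k1 + 2 * k2 + 2 * k3 + k4) / 6.0

def odesolve(f, y0, ts, phi=phi_rk4):
    """Sequential time stepping y(t+h) = y(t) + h * Phi(f, y(t), t, h),
    returning one row per observation time (cf. Eq. (4))."""
    Y = [np.asarray(y0, dtype=float)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = t1 - t0
        Y.append(Y[-1] + h * phi(f, Y[-1], t0, h))
    return np.stack(Y)

ts = np.linspace(0.0, 10.0, 101)
Y_hat = odesolve(f, [0.0, 1.0], ts)        # shape (101, 2)
```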
While this framework enables using tailored simulation methods for (3), using sequential ODE solvers for training poses several challenges. First, numerical errors can accumulate at each integration step, resulting in substantial global errors. Although adaptive solvers help control these errors, they add computational overhead and may still behave unpredictably on unseen data. Second, in addition to simulation CPU times, storing intermediate solutions for backpropagation requires significant memory. While the adjoint method (Chen et al., 2018) ensures a constant memory cost, it can substantially prolong training time.

3. SPECTRAL METHODS

As an alternative to the above, spectral numerical methods approximate the ODE solution as a linear combination of basis functions, e.g., trigonometric or orthogonal polynomials. The function coefficients are fitted over the integration domain, offering high accuracy and convergence rates for smooth problems (Boyd, 2000).

3.1 Collocation with Lagrange Interpolation

Collocation is a class of spectral methods in which an ODE is enforced at a set of discrete points, termed the collocation grid, which we introduce as ξ = {ξi}_{i=0}^{N−1} in [t0, T]. Under this framework, the approximate solution is

    ỹ(t) = Σ_{i=0}^{N−1} βi ϕi(t),

where {ϕi(t)}_{i=0}^{N−1} represent the basis functions, {βi}_{i=0}^{N−1} are the coefficients to be determined, and t ∈ [t0, T].

We employ the barycentric form of Lagrange polynomials as basis functions due to its numerical stability and adaptability to diverse functions, including non-periodic behaviors (Berrut and Trefethen, 2004). The use of Lagrange polynomials also offers an implementation simplification due to the interpolation property, which ensures that the coefficients coincide with the true state values at the collocation grid, such that βi = yi at each collocation point. As a result, the approximated solution becomes:

    ỹ(t) = Σ_{i=0}^{N−1} yi ℓi(t),    (6)

where {ℓi(t)}_{i=0}^{N−1} are the Lagrange basis functions and {yi}_{i=0}^{N−1} are the unknown coefficients, which are also the state values at each point ξi. We assume y to be unknown, given the presence of noise in real-life systems. By treating all yi as coefficients (Berrut and Trefethen, 2004) and differentiating the interpolation formula (6), we obtain:

    ỹ′(t) = Σ_{i=0}^{N−1} yi ℓ′i(t).    (7)

Substituting (6) and (7) into the Neural ODE (3) and evaluating at each collocation point ξi yields:

    Σ_{j=0}^{N−1} yj ℓ′j(ξi) = fθ(Σ_{j=0}^{N−1} yj ℓj(ξi), ξi),  i = 0, . . . , N − 1.    (8)

This results in a system of nonlinear equations with respect to the unknowns {yj}_{j=0}^{N−1}. The system can be expressed in matrix form:

    DY = Fθ(Y, ξ),    (9)

where:
• D ∈ R^{N×N} is the differentiation matrix with elements Dij = ℓ′j(ξi).
• Y = [y0, . . . , y_{N−1}]⊤ ∈ R^{N×d} is the matrix of unknown coefficients (true state values).
• Fθ(Y, ξ) = [fθ(ỹ(ξ0), ξ0), . . . , fθ(ỹ(ξ_{N−1}), ξ_{N−1})]⊤ ∈ R^{N×d} contains the Neural ODE evaluated at each collocation point.
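In code, satisfying (9) amounts to driving a residual to zero. A minimal sketch (the differentiation matrix D is constructed below; f_theta is assumed to be any vector field of the form f(y, t)):

```python
import numpy as np

def collocation_residual(Y, xi, D, f_theta):
    """Residual of Eq. (9): D @ Y - F_theta(Y, xi), with shape (N, d).

    Y  : (N, d) candidate state values at the collocation points,
    xi : (N,) collocation grid, D : (N, N) differentiation matrix."""
    F = np.stack([f_theta(Y[i], xi[i]) for i in range(len(xi))])
    return D @ Y - F
```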
Using the barycentric formula (Berrut and Trefethen, 2004), the differentiation matrix is defined as:

    Dij = (wj / wi) · 1 / (ξi − ξj),  if i ≠ j,
    Dii = − Σ_{k=0, k≠i}^{N−1} Dik,

where the weights are computed as:

    wi = 1 / ∏_{k=0, k≠i}^{N−1} (ξi − ξk).

The selection of the collocation grid significantly impacts the accuracy of the method. To mitigate errors caused by Runge's phenomenon, Chebyshev nodes of the second kind in [−1, 1] are often used:

    ξi = cos(iπ / (N − 1)),  i = 0, . . . , N − 1.    (10)

We refer the interested reader to Young (2019) for a comprehensive discussion of collocation grids.
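A minimal NumPy sketch of these quantities (Chebyshev nodes on [−1, 1], barycentric weights, and the differentiation matrix), with a small polynomial self-check. The affine mapping of the grid to [t0, T] is omitted, and the function names are illustrative.

```python
import numpy as np

def cheb_nodes(N):
    """Chebyshev nodes of the second kind on [-1, 1], Eq. (10)."""
    return np.cos(np.arange(N) * np.pi / (N - 1))

def barycentric_weights(xi):
    """w_i = 1 / prod_{k != i} (xi_i - xi_k)."""
    N = len(xi)
    w = np.empty(N)
    for i in range(N):
        w[i] = 1.0 / np.prod(np.delete(xi[i] - xi, i))
    return w

def differentiation_matrix(xi):
    """D_ij = (w_j / w_i) / (xi_i - xi_j) for i != j; D_ii = -sum_{k != i} D_ik."""
    w = barycentric_weights(xi)
    N = len(xi)
    D = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            if i != j:
                D[i, j] = (w[j] / w[i]) / (xi[i] - xi[j])
        D[i, i] = -np.sum(D[i, :])   # off-diagonal row sum (diagonal still zero)
    return D

# Self-check: Lagrange interpolation differentiates low-order polynomials exactly
xi = cheb_nodes(8)
D = differentiation_matrix(xi)
print(np.allclose(D @ xi**3, 3 * xi**2))   # True
```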
4. PROPOSED METHODOLOGY

So far, we have converted the continuous ODE problem into a discrete algebraic system (9) by incorporating collocation and Lagrange interpolation. Our goal is to minimize the loss function (5) while enforcing that the collocation-estimated derivatives, computed as DY, match the neural-network-predicted derivatives Fθ(Y, ξ) at the collocation points. The challenge arises because both the true state values Y and the parameters θ of the neural network are unknown.

4.1 Simultaneous Approach

Our proposed approach is to incorporate the collocation system (9) as equality constraints in a single nonlinear optimization framework, where the objective function captures the discrepancy between the observed and optimized states, e.g., the MSE. By solving for Y and θ simultaneously, we effectively train the neural network from observed data while enforcing the collocation constraints.

The simultaneous approach in the context of collocation-based dynamic optimization is further explored in the works of Tjoa and Biegler (1991) and Kameswaran and Biegler (2006), where it is applied to the problem of parameter estimation in differential equation systems.

4.2 Implementation

To implement this methodology, we utilize the Interior Point OPTimizer (IPOPT), which is well-suited for solving continuous, large-scale nonlinear optimization problems (Biegler and Zavala, 2009). For the software implementation, we call IPOPT through the open-source Pyomo algebraic modeling language. Recent research (Ceccon et al., 2022) demonstrates how neural networks can be represented as constraints within the Pyomo framework.

We initialize two optimization variable groups within Pyomo: state variables and neural network parameters.

• State Variables, Y∗, represent the system's state at the collocation time points. To expedite training, the state variables can also be initialized using smoothed observed data. These variables aim to approximate the true values Y from the observed values Yobs.
• Neural Network Parameters, θ, include the weights and biases of the neural network.

The objective function captures the difference between the observed data and the state variables approximated by the collocation equations, instead of the output of ODESolve as in (4). We express the objective function as a combination of the MSE loss and regularization terms:

    L(Y∗, Yobs) = (1/N) ∥Y∗ − Yobs∥_F² + λ ∥θ∥_2²,    (11)

where
• Y∗ ∈ R^{N×d} is the matrix of estimated state variables (the variables being optimized).
• Yobs ∈ R^{N×d} is the matrix of observed values.
• θ is the vector of neural network parameters.
• ∥·∥_2 and ∥·∥_F are the Euclidean (L2) and Frobenius norms, respectively.
• λ is the regularization parameter.

We enforce consistency between the neural network and the derivative of Y∗ at each collocation point ξi:

    Σ_{j=0}^{N−1} y∗j ℓ′j(ξi) = fθ(y∗i, ξi).

Here, ℓ′j(ξi) is the element Dij of the differentiation matrix, so the left-hand side approximates the derivative of the optimized state values. The right-hand side is the output of the neural network.

4.3 Problem Formulation

We formulate the optimization problem as follows:

    min_{Y∗, θ}  L(Y∗, Yobs)
    subject to:
      equality constraints:  DY∗ = Fθ(Y∗, ξ),
      bounds:  y∗_L ≤ Y∗ ≤ y∗_U,  θ_L ≤ θ ≤ θ_U,

where:
• L is the loss function described in (11).
• C : DY∗ = Fθ(Y∗, ξ) is the matrix of equality constraints.
• Y∗ and θ represent the decision variables.
• y∗_L, y∗_U and θ_L, θ_U are the respective lower and upper bounds.

After the model is solved to optimality, the neural network can be used as the RHS of an ODE in a sequential or collocation-based solver in the post-training context.
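The formulation above can be prototyped directly in Pyomo. The sketch below is a minimal, hypothetical implementation under several assumptions: the network is a single-hidden-layer tanh MLP written out as Pyomo expressions (rather than via the approach of Ceccon et al. (2022)), variable bounds are omitted, and all names (build_model, hidden, lam, etc.) are illustrative rather than the paper's actual code.

```python
import numpy as np
import pyomo.environ as pyo

def build_model(xi, Y_obs, D, Y_init=None, hidden=8, lam=1e-4):
    """Simultaneous collocation training of a small Neural ODE (sketch).

    xi: (N,) collocation points, Y_obs: (N, d) observations,
    D: (N, N) barycentric differentiation matrix,
    Y_init: optional (N, d) smoothed data (e.g., LOESS) to initialize Y*.
    """
    N, d = Y_obs.shape
    Y_init = Y_obs if Y_init is None else Y_init
    rng = np.random.default_rng(0)

    m = pyo.ConcreteModel()
    m.I = pyo.RangeSet(0, N - 1)       # collocation points
    m.S = pyo.RangeSet(0, d - 1)       # state dimensions
    m.H = pyo.RangeSet(0, hidden - 1)  # hidden units
    m.K = pyo.RangeSet(0, d)           # network inputs: d states + time

    # State variables Y*, initialized from (smoothed) observations
    m.y = pyo.Var(m.I, m.S, initialize=lambda m, i, s: float(Y_init[i, s]))

    # Neural network parameters theta: one hidden tanh layer, inputs (y_i, xi_i)
    m.W1 = pyo.Var(m.H, m.K, initialize=lambda m, h, k: float(rng.normal(scale=0.1)))
    m.b1 = pyo.Var(m.H, initialize=0.0)
    m.W2 = pyo.Var(m.S, m.H, initialize=lambda m, s, h: float(rng.normal(scale=0.1)))
    m.b2 = pyo.Var(m.S, initialize=0.0)

    def f_theta(m, i, s):
        hid = [pyo.tanh(sum(m.W1[h, k] * m.y[i, k] for k in m.S)
                        + m.W1[h, d] * float(xi[i]) + m.b1[h]) for h in m.H]
        return sum(m.W2[s, h] * hid[h] for h in m.H) + m.b2[s]

    # Collocation equality constraints: (D Y*)_{is} = f_theta(y*_i, xi_i)_s
    m.colloc = pyo.Constraint(
        m.I, m.S,
        rule=lambda m, i, s: sum(D[i, j] * m.y[j, s] for j in m.I) == f_theta(m, i, s))

    # Objective (11): MSE between Y* and the observations plus L2 regularization
    reg = (sum(m.W1[h, k] ** 2 for h in m.H for k in m.K)
           + sum(m.b1[h] ** 2 for h in m.H)
           + sum(m.W2[s, h] ** 2 for s in m.S for h in m.H)
           + sum(m.b2[s] ** 2 for s in m.S))
    m.obj = pyo.Objective(
        expr=sum((m.y[i, s] - float(Y_obs[i, s])) ** 2 for i in m.I for s in m.S) / N
             + lam * reg,
        sense=pyo.minimize)
    return m

# m = build_model(xi, Y_obs, D)
# pyo.SolverFactory("ipopt").solve(m, tee=True)   # simultaneous solve with IPOPT
```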
4.4 Alternating Direction Method of Multipliers (ADMM)

One potential disadvantage of the above simultaneous framework is that the entire dataset must be handled in a single optimization problem, while many training pipelines divide data into batches to alleviate the computational or memory burden. We propose using the Alternating Direction Method of Multipliers (ADMM) to enable multi-batching by coordinating the training of separate submodels. For two 'batches,' the problem can be written as:

    min_{θ1, θ2}  L1(θ1, Y1) + L2(θ2, Y2)
    s.t.  θ1 = θ2,

where Y1, Y2 are the batches of data, θ1, θ2 are vectors containing the parameters of each sub-model, and L1, L2 are the loss functions. Notice that the 'linking' constraints θ1 = θ2 enforce a consensus model between the two data batches. ADMM decomposes the above problem without the linking constraints by updating the optimization parameters and a dual variable (Lagrange multiplier) in an iterative manner (Boyd et al., 2010). Without the constraints, the problem is effectively decomposed into independent subproblems min_{θi} Li(θi, Yi). The loss functions for the subproblems are reformulated as follows:

    L_ADMM,i = Li(θi, Yi) + (ρ/2) ∥θi − θ̄^(k) + ui/ρ∥²,  for i = 1, 2,

where θi are the parameters of submodel i, θ̄^(k) are the consensus parameters at the k-th ADMM iteration, ρ is a scalar penalty strength, and ui are the dual variables associated with subproblem i. For two submodels, the consensus weights in the k-th iteration are given by:

    θ̄^(k) = (θ1^(k) + θ2^(k)) / 2.

The dual variables for each submodel i are updated in each iteration using:

    ui^(k+1) = ui^(k) + ρ (θi^(k) − θ̄^(k)).
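A compact sketch of the resulting consensus ADMM loop. Here train_batch is a hypothetical helper that solves one augmented subproblem (e.g., the IPOPT model above with the penalty term added to its objective) and returns the optimal θi as a flat parameter vector.

```python
import numpy as np

def admm_consensus(batches, train_batch, n_params, rho=1.0, iters=20):
    """Consensus ADMM over data batches (sketch).

    train_batch(Y_i, theta_bar, u_i, rho) is assumed to minimize
    L_i(theta, Y_i) + rho/2 * ||theta - theta_bar + u_i / rho||^2
    and return the optimizer theta_i as a 1-D array of length n_params."""
    theta_bar = np.zeros(n_params)
    u = [np.zeros(n_params) for _ in batches]
    for k in range(iters):
        # 1) local updates: independent subproblems (parallelizable)
        thetas = [train_batch(Y_i, theta_bar, u_i, rho)
                  for Y_i, u_i in zip(batches, u)]
        # 2) consensus update: average of the submodel parameters
        theta_bar = np.mean(thetas, axis=0)
        # 3) dual updates: u_i <- u_i + rho * (theta_i - theta_bar)
        u = [u_i + rho * (theta_i - theta_bar)
             for u_i, theta_i in zip(u, thetas)]
    return theta_bar
```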
2010). For the simultaneous approach, the state variables Y∗ are initialized to the values of locally weighted polynomial regression (Cleveland and Devlin, 1988).

5.1 Case Study: Van der Pol Oscillator

The forced Van der Pol Oscillator is a 2-D ODE system that can be represented as two coupled first-order equations:

    u′ = v,                              u0 = 0,
    v′ = µ(1 − u²)v − u + A cos(ωt),     v0 = 1,

where u is the displacement, v is the velocity, ω is the angular frequency, µ is the damping parameter, and A is the external periodic force. For our experiments, we set the initial conditions as u0 = 0 and v0 = 1. The remaining parameters are chosen as µ = 1, A = 1 and ω = 1.

Training and Inference Procedure. After training with our proposed collocation-based framework, the learned Neural ODE is used in a standard ODE solver (JAX Diffrax) for forward simulation. Figure 1 illustrates the prediction on both training and test ranges, showing that the collocation-trained Neural ODE captures the underlying dynamics effectively.
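As an illustration of this inference step, a minimal Diffrax sketch that forward-simulates a vector field over a time grid. Here the true forced Van der Pol dynamics are used as a stand-in for the learned fθ, which would simply replace vdp below; the parameter values follow the case study, and everything else (time grid, step size) is an illustrative assumption.

```python
import jax.numpy as jnp
import diffrax

def vdp(t, y, args):
    # forced Van der Pol dynamics; a trained f_theta would be dropped in here
    mu, A, omega = args
    u, v = y
    return jnp.array([v, mu * (1.0 - u**2) * v - u + A * jnp.cos(omega * t)])

ts = jnp.linspace(0.0, 20.0, 200)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vdp),
    diffrax.Tsit5(),
    t0=ts[0], t1=ts[-1], dt0=0.01,
    y0=jnp.array([0.0, 1.0]),            # u0 = 0, v0 = 1
    args=(1.0, 1.0, 1.0),                # mu = 1, A = 1, omega = 1
    saveat=diffrax.SaveAt(ts=ts),
)
trajectory = sol.ys                       # shape (200, 2): u(t) and v(t)
```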