Continuous Adjoint Method
Many popular mathematical models, such as common hidden Markov models, utilize sequences of discrete states implicitly defined through forward difference equations,
$$u_{n+1} = u_n + \Delta_n(u_n, \psi),$$
to capture the regular evolution of a latent system; here u_n denotes the nth latent state of the system and ψ the model parameters. Typically these sequences are incorporated into larger models through discrete functionals that consume particular sequences and return scalar values,
$$J(\psi) = \sum_{n=0}^{N-1} j_n(u_n, \psi, n).$$
We can quantify the impact of the parameters, ψ, on these functionals by evaluating the total derivatives, dJ/dψ. The evaluation of these derivatives is complicated by the dependence of the sequences on the parameters enforced by the forward difference equations; the total derivative of a functional has to take into account both the explicit dependence of the j_n on ψ and also the implicit dependence mediated by the latent states u_n.
We can always compute each sensitivity, du_n/dψ, by propagating derivatives along the forward difference equations and constructing the corresponding sequence of sensitivities. This quickly becomes expensive, however, when there are many parameters that each require their own sensitivities. In order to better scale we need to bypass the superfluous computation of these intermediate derivatives and only propagate the minimal information needed to construct the total derivatives of the desired functionals.
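To make the bookkeeping concrete, the following is a minimal sketch of this forward sensitivity propagation for a toy scalar sequence; the particular forward difference Δ and summand j, and all parameter values, are hypothetical choices for illustration rather than anything from the text.

```python
# A toy forward difference equation u_{n+1} = u_n + Delta(u_n, psi) with a
# single latent state, a single parameter psi, and the functional
# J(psi) = sum_{n=0}^{N-1} j(u_n, psi).

def Delta(u, psi):        # hypothetical forward difference increment
    return psi * u * (1.0 - u)

def dDelta_du(u, psi):    # its partial derivative with respect to the state
    return psi * (1.0 - 2.0 * u)

def dDelta_dpsi(u, psi):  # its partial derivative with respect to psi
    return u * (1.0 - u)

def j(u, psi):            # hypothetical summand of the functional
    return (u - 0.5) ** 2

def dj_du(u, psi):
    return 2.0 * (u - 0.5)

def dj_dpsi(u, psi):
    return 0.0

def forward_sensitivity(u0, psi, N):
    """Propagate u_n and its sensitivity eta_n = du_n / dpsi together."""
    u, eta = u0, 0.0      # u_0 is taken to be independent of psi here
    J, dJ = 0.0, 0.0
    for n in range(N):
        J += j(u, psi)
        dJ += dj_dpsi(u, psi) + dj_du(u, psi) * eta
        # differentiate u_{n+1} = u_n + Delta(u_n, psi) through the recursion
        eta = eta + dDelta_du(u, psi) * eta + dDelta_dpsi(u, psi)
        u = u + Delta(u, psi)
    return J, dJ

print(forward_sensitivity(u0=0.1, psi=0.8, N=50))
```

Each additional parameter requires carrying its own sensitivity η_n alongside the states, which is exactly the cost the adjoint construction below avoids.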
In this paper we introduce a discrete adjoint technique that efficiently computes total
derivatives without explicitly calculating intermediate sensitivities. We begin by reviewing
the powerful continuous adjoint method for ordinary differential equations before deriving a
discrete analog. Finally we demonstrate how the method can be applied to hidden Markov
models.
Consider a system of states u(t) whose evolution is governed by a system of ordinary differential equations,
$$\frac{\mathrm{d}u}{\mathrm{d}t} = f(u, \psi, t),$$
with the initial condition
$$u(t = 0) = \upsilon(\psi).$$
A functional consumes the state trajectory and returns a single real number through an integration over time,
$$J(\psi) = \int_0^T \mathrm{d}t \, j(u, \psi, t).$$
Our goal is then to compute the total derivative of J with respect to the parameter ψ, taking into account not only the explicit dependence of j on ψ but also the implicit dependence through the influence of ψ on the evolution of the states u(t). For a thorough review of the possible strategies see Sections 2.6 and 2.7 of Hindmarsh and Serban (2020).
1.1 Adjoint Task Force
An immediate way to compute gradients of functionals like this is to explicitly compute the state sensitivities,
$$\eta = \frac{\mathrm{d}u}{\mathrm{d}\psi},$$
by integrating the forward sensitivity equations obtained by differentiating the original system,
$$\frac{\mathrm{d}\eta}{\mathrm{d}t} = \frac{\partial f}{\partial u} \cdot \eta + \frac{\partial f}{\partial \psi}, \qquad \eta(0) = \frac{\partial \upsilon}{\partial \psi},$$
alongside the states themselves.
Once we’ve solved for the state sensitivities we can construct the total derivative of the
desired functional through the chain rule,
$$\begin{aligned}
\frac{\mathrm{d}J}{\mathrm{d}\psi}(\psi) &= \frac{\mathrm{d}}{\mathrm{d}\psi} \int_0^T \mathrm{d}t \, j(u, \psi, t) \\
&= \int_0^T \mathrm{d}t \, \frac{\mathrm{d}}{\mathrm{d}\psi} j(u, \psi, t) \\
&= \int_0^T \mathrm{d}t \left[ \frac{\partial j}{\partial \psi} + \left( \frac{\partial j}{\partial u} \right)^{\dagger} \cdot \eta \right].
\end{aligned}$$
This approach becomes burdensome, however, once we consider multiple parameters and
hence multiple total derivatives, each of which requires integrating over its own trajectory
of sensitivities.
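As a concrete illustration of this direct approach, the following sketch integrates a toy one-dimensional system together with its sensitivity and the running values of J and dJ/dψ; the particular f, j, and parameter values are hypothetical choices for illustration, not anything prescribed by the text.

```python
from scipy.integrate import solve_ivp

# Toy system: du/dt = f(u, psi) = -psi * u with u(0) = 1 (independent of psi)
# and the functional J(psi) = int_0^T u(t)^2 dt, so j(u, psi, t) = u^2.
psi, T = 0.7, 2.0

def augmented_rhs(t, state):
    u, eta, J, dJ = state
    du   = -psi * u              # f(u, psi)
    deta = -psi * eta - u        # (df/du) * eta + df/dpsi
    dj   = u ** 2                # j(u, psi, t)
    ddj  = 2.0 * u * eta         # dj/dpsi + (dj/du) * eta
    return [du, deta, dj, ddj]

sol = solve_ivp(augmented_rhs, (0.0, T), [1.0, 0.0, 0.0, 0.0],
                rtol=1e-10, atol=1e-12)
u_T, eta_T, J, dJ_dpsi = sol.y[:, -1]
print(J, dJ_dpsi)
```

With K parameters the single sensitivity η becomes a full matrix of sensitivities and the augmented system grows accordingly, which is the scaling problem addressed below.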
Another way to work out the total derivative of the functional is to treat the influence of
the parameter on the state trajectory as constraints (Hannemann-Tamás, Muñoz and Marquardt,
2015),
$$0 = u(0) - \upsilon(\psi)$$
$$0 = \frac{\mathrm{d}u}{\mathrm{d}t} - f(u, \psi, t),$$
which are explicitly incorporated into the functional with Lagrange multipliers, µ and λ(t),
$$\begin{aligned}
J(\psi) &= \int_0^T \mathrm{d}t \, j(u, \psi, t) \\
&= 0 + \int_0^T \mathrm{d}t \, j(u, \psi, t) + 0 \\
&= \mu^{\dagger} \cdot \left[ u(0) - \upsilon(\psi) \right]
+ \int_0^T \mathrm{d}t \left[ j(u, \psi, t) + \lambda^{\dagger}(t) \cdot \left( \frac{\mathrm{d}u}{\mathrm{d}t} - f(u, \psi, t) \right) \right] \\
&\equiv \mathcal{L}(\psi).
\end{aligned}$$
As long as the constraints are satisfied this modified functional will equal our target functional for any values of the Lagrange multipliers.
Under these constraints we can compute the total derivative of the functional by instead
differentiating this modified functional. If we assume that everything is smooth then we
can exchange the order of integration and differentiation to give
$$\begin{aligned}
\frac{\mathrm{d}J}{\mathrm{d}\psi} &= \frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\psi} \\
&= \mu^{\dagger} \cdot \left[ \frac{\mathrm{d}u}{\mathrm{d}\psi}(0) - \frac{\mathrm{d}\upsilon}{\mathrm{d}\psi} \right]
+ \int_0^T \mathrm{d}t \left[ \frac{\mathrm{d}j}{\mathrm{d}\psi} + \lambda^{\dagger}(t) \cdot \left( \frac{\mathrm{d}}{\mathrm{d}\psi} \frac{\mathrm{d}u}{\mathrm{d}t} - \frac{\mathrm{d}f}{\mathrm{d}\psi} \right) \right] \\
&= \mu^{\dagger} \cdot \left[ \frac{\mathrm{d}u}{\mathrm{d}\psi}(0) - \frac{\partial \upsilon}{\partial \psi} \right]
+ \int_0^T \mathrm{d}t \left[ \frac{\partial j}{\partial \psi} + \left( \frac{\partial j}{\partial u} \right)^{\dagger} \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}
+ \lambda^{\dagger}(t) \cdot \left( \frac{\mathrm{d}}{\mathrm{d}t} \frac{\mathrm{d}u}{\mathrm{d}\psi} - \frac{\partial f}{\partial \psi} - \frac{\partial f}{\partial u} \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi} \right) \right].
\end{aligned}$$
Once again a boldfaced fraction is shorthand for a Jacobian matrix. For example,
$$\frac{\partial j}{\partial u} = \left( \frac{\partial j}{\partial u_1}, \ldots, \frac{\partial j}{\partial u_N} \right)^{\dagger}.$$
The benefit of this approach is that we can use the freedom in our Lagrange multipliers
to eliminate the expensive state sensitivities entirely! First we need to integrate the time
derivative of the sensitivities by parts to recover a pure sensitivity,
$$\int_0^T \mathrm{d}t \, \lambda^{\dagger}(t) \cdot \frac{\mathrm{d}}{\mathrm{d}t} \frac{\mathrm{d}u}{\mathrm{d}\psi}
= \lambda^{\dagger}(T) \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}(T)
- \lambda^{\dagger}(0) \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}(0)
- \int_0^T \mathrm{d}t \, \left( \frac{\mathrm{d}\lambda}{\mathrm{d}t} \right)^{\dagger} \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}.$$
Then we substitute this result into the total derivative and gather all the sensitivity terms
together,
$$\begin{aligned}
\frac{\mathrm{d}J}{\mathrm{d}\psi}
&= \mu^{\dagger} \cdot \left[ \frac{\mathrm{d}u}{\mathrm{d}\psi}(0) - \frac{\partial \upsilon}{\partial \psi} \right]
+ \lambda^{\dagger}(T) \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}(T)
- \lambda^{\dagger}(0) \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}(0) \\
&\quad + \int_0^T \mathrm{d}t \left[ \frac{\partial j}{\partial \psi}
+ \left( \frac{\partial j}{\partial u} \right)^{\dagger} \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}
- \left( \frac{\mathrm{d}\lambda}{\mathrm{d}t} \right)^{\dagger} \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}
- \lambda^{\dagger}(t) \cdot \frac{\partial f}{\partial \psi}
- \lambda^{\dagger}(t) \cdot \frac{\partial f}{\partial u} \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi} \right] \\
&= \left[ \mu - \lambda(0) \right]^{\dagger} \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}(0)
- \mu^{\dagger} \cdot \frac{\partial \upsilon}{\partial \psi}
+ \lambda^{\dagger}(T) \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}(T) \\
&\quad + \int_0^T \mathrm{d}t \left[ \frac{\partial j}{\partial \psi} - \lambda^{\dagger}(t) \cdot \frac{\partial f}{\partial \psi} \right]
+ \int_0^T \mathrm{d}t \left[ \frac{\partial j}{\partial u} - \frac{\mathrm{d}\lambda}{\mathrm{d}t} - \lambda(t) \cdot \frac{\partial f}{\partial u} \right]^{\dagger} \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}.
\end{aligned}$$
Now we can exploit the freedom in our Lagrange multipliers to remove all vestiges of
the sensitivities. First let’s set µ = λ(0) to remove the initial sensitivities and λ(T ) = 0 to
remove the final sensitivities. We can then remove the integral term that depends on the
intermediate sensitivities if we set
$$\frac{\partial j}{\partial u} - \frac{\mathrm{d}\lambda}{\mathrm{d}t} - \lambda(t) \cdot \frac{\partial f}{\partial u} = 0,$$
or
$$\frac{\mathrm{d}\lambda}{\mathrm{d}t} = \frac{\partial j}{\partial u} - \lambda(t) \cdot \frac{\partial f}{\partial u}.$$
In other words, provided that λ(t) satisfies the differential equation
$$\frac{\mathrm{d}\lambda}{\mathrm{d}t} = \frac{\partial j}{\partial u}(u, \psi, t) - \lambda(t) \cdot \frac{\partial f}{\partial u}(u, \psi, t)$$
with the initial condition
$$\lambda(T) = 0,$$
then the total derivative of our target functional reduces to
$$\frac{\mathrm{d}J}{\mathrm{d}\psi}(\psi) = -\lambda^{\dagger}(0) \cdot \frac{\partial \upsilon}{\partial \psi}
+ \int_0^T \mathrm{d}t \left[ \frac{\partial j}{\partial \psi}(u, \psi, t) - \lambda^{\dagger}(t) \cdot \frac{\partial f}{\partial \psi}(u, \psi, t) \right].$$
The system of differential equations for λ(t) is known as the adjoint system relative to
the original system of ordinary differential equations. If we first solve for u(t) then we
can solve for the adjoint λ(t) and compute the total derivative dJ /dψ at the same time
without having to compute any explicit sensitivities.
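The following sketch applies these formulas to the same toy system used above: a forward solve for u(t), a backward solve for the adjoint λ(t), and a one-dimensional quadrature for the total derivative accumulated during the backward pass. As before the specific f, j, and parameter values are hypothetical illustrations rather than anything prescribed by the text.

```python
from scipy.integrate import solve_ivp

# Toy system: du/dt = f(u, psi) = -psi * u, u(0) = upsilon(psi) = 1, and
# J(psi) = int_0^T u(t)^2 dt so that j(u, psi, t) = u^2.
psi, T = 0.7, 2.0

def df_du(u):    return -psi
def df_dpsi(u):  return -u
def dj_du(u):    return 2.0 * u
def dj_dpsi(u):  return 0.0

# Forward solve for the state trajectory, with dense output for interpolation.
fwd = solve_ivp(lambda t, u: [-psi * u[0]], (0.0, T), [1.0],
                dense_output=True, rtol=1e-10, atol=1e-12)

# Backward solve for the adjoint, dlambda/dt = dj/du - lambda * df/du with
# lambda(T) = 0, together with the integrand dj/dpsi - lambda * df/dpsi.
def adjoint_rhs(t, state):
    u = fwd.sol(t)[0]
    lam, _ = state
    return [dj_du(u) - lam * df_du(u),
            dj_dpsi(u) - lam * df_dpsi(u)]

bwd = solve_ivp(adjoint_rhs, (T, 0.0), [0.0, 0.0], rtol=1e-10, atol=1e-12)
lam0, quad = bwd.y[:, -1]

# dJ/dpsi = -lambda(0) * dupsilon/dpsi + int_0^T [dj/dpsi - lambda * df/dpsi] dt;
# the quadrature was accumulated from T down to 0, so flip its sign, and the
# initial condition does not depend on psi here so the boundary term vanishes.
dupsilon_dpsi = 0.0
dJ_dpsi = -lam0 * dupsilon_dpsi - quad
print(dJ_dpsi)
```

The result can be checked against the forward sensitivity sketch above, or against the closed form J(ψ) = (1 − e^{−2ψT})/(2ψ) for this toy system.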
1.2 Computational Scalings
For a single parameter the direct approach is slightly more efficient, requiring two N-dimensional integrations for the states and their sensitivities compared to the adjoint approach which requires two N-dimensional integrations, one for the states and one for the adjoint states, and the extra one-dimensional integration to solve for the total derivative.
The adjoint method, however, quickly becomes more efficient as we consider multiple parameters because the adjoint states are the same for all parameters.
When we have K parameters the forward sensitivity approach requires an N-dimensional integration for each sensitivity and the total cost scales as N + N · K. The adjoint approach, however, requires only two N-dimensional solves to set up the states and the adjoint states and then K one-dimensional solves for each gradient component, yielding a total cost scaling of 2N + K.
Comparing these two scalings we see that the adjoint method is better when
N
< K,
N −1
a condition verified for any N provided that K ≥ 2. In other words the adjoint method will generally feature the highest performance in any application with at least two parameters. As the number of parameters increases the O(NK) scaling of the forward sensitivity approach grows much faster than the O(N + K) scaling of the adjoint method, and the performance gap only becomes more substantial.
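As a concrete illustration with hypothetical sizes, take N = 10 states and K = 100 parameters:
$$N + N \cdot K = 10 + 1000 = 1010 \qquad \text{versus} \qquad 2N + K = 20 + 100 = 120,$$
an order-of-magnitude difference that only widens as K grows.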
1.3 An Application to Automatic Differentiation
A particularly useful application of the continuous adjoint method is for the reverse mode
automatic differentiation (Bücker et al., 2006; Griewank and Walther, 2008; Margossian,
2019) of functions incorporating the solutions of ordinary differential equations. In order
to propagate the needed differential information through the composite function we need
to be able to evaluate the Jacobian of the final state with respect to the parameters,
$$\frac{\mathrm{d}u}{\mathrm{d}\psi}(T),$$
contracted against a vector, δ,
$$\delta^{\dagger} \cdot \frac{\mathrm{d}u}{\mathrm{d}\psi}(T),$$
where † denotes transposition. This arises, for example, when computing the gradient of a scalar function, such as a probability density or an objective function, that implicitly depends on ψ through u.
We can recover the above contraction by defining the integrand
$$j(u, \psi, t) = \delta^{\dagger} \cdot f(u, \psi, t).$$
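To see why this integrand is convenient, consider what the adjoint system of the previous section becomes for it; the short calculation below is our own sketch of where the construction leads, with λ̃ introduced purely as a label of convenience, and is not necessarily how the full derivation proceeds. Because ∂j/∂u = (∂f/∂u)† · δ the adjoint system reads
$$\frac{\mathrm{d}\lambda}{\mathrm{d}t} = \left( \frac{\partial f}{\partial u} \right)^{\dagger} \cdot \left( \delta - \lambda(t) \right), \qquad \lambda(T) = 0,$$
so that λ̃(t) ≡ δ − λ(t) satisfies the homogeneous system
$$\frac{\mathrm{d}\tilde{\lambda}}{\mathrm{d}t} = -\left( \frac{\partial f}{\partial u} \right)^{\dagger} \cdot \tilde{\lambda}(t), \qquad \tilde{\lambda}(T) = \delta,$$
which is the familiar costate that reverse-mode automatic differentiation propagates backwards through an ordinary differential equation solve.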
As in the continuous case we can exploit the freedom in our Lagrange multipliers to remove all of the sensitivity terms. We first set
$$\mu + \frac{\partial j_0}{\partial u_0} - \lambda_0 \cdot \frac{\partial \Delta_0}{\partial u_0} - \lambda_0 = 0,$$
or
$$\mu = -\frac{\partial j_0}{\partial u_0} + \lambda_0 \cdot \frac{\partial \Delta_0}{\partial u_0} + \lambda_0,$$
and then
$$\lambda_{N-1} = 0$$
to remove all the sensitivities outside of the summations. We then eliminate the second summation by choosing the rest of the λ_n to satisfy
$$\frac{\partial j_n}{\partial u_n} - \lambda_n + \lambda_{n-1} - \lambda_n \cdot \frac{\partial \Delta_n}{\partial u_n} = 0,$$
or equivalently
$$\frac{\partial j_{n+1}}{\partial u_{n+1}} - \lambda_{n+1} + \lambda_n - \lambda_{n+1} \cdot \frac{\partial \Delta_{n+1}}{\partial u_{n+1}} = 0.$$
This defines an adjoint system governed by the backward difference equations
$$\lambda_n - \lambda_{n+1} = -\frac{\partial j_{n+1}}{\partial u_{n+1}} + \lambda_{n+1} \cdot \frac{\partial \Delta_{n+1}}{\partial u_{n+1}}.$$
Fig 1. [Graphical model with observations y_{n−1}, y_n, y_{n+1} and hidden states z_{n−1}, z_n, z_{n+1}.] The conditional dependence structure of a hidden Markov model admits efficient marginalization of the discrete hidden states into state probabilities. Derivatives of the state probabilities with respect to the model parameters also have to navigate this conditional dependence structure.
Writing the observational densities and transition probabilities as
$$\omega_{n,i} \equiv \pi(y_n \mid z_n = i)$$
$$\Gamma_{n,ij} \equiv \pi(z_{n+1} = i \mid z_n = j)$$
we can marginalize the hidden states into the forward state probabilities α_n(ψ), whose component α_{n,i}(ψ) gives the joint probability of the observations up to index n and the hidden state z_n = i.
Because of the defining conditional structure these state probabilities satisfy the recursion relation
$$\alpha_{n+1}(\psi) = \omega_{n+1}(\psi) \circ \left( \Gamma_{n+1}(\psi) \cdot \alpha_n(\psi) \right),$$
where ◦ denotes the element-wise Hadamard product, along with the initial condition
$$\alpha_0(\psi) = \omega_0(\psi) \circ \rho(\psi),$$
where ρ denotes the initial hidden state probabilities.
Forward solving the recursion relation efficiently computes each of the state probabilities, the last of which gives the desired marginal likelihood
$$\pi(y_1, \ldots, y_N, \psi) = \sum_{m=1}^{M} \alpha_{N,m}(\psi) = \mathbf{1}^{\dagger} \cdot \alpha_N(\psi).$$
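A minimal sketch of this forward pass in Python, with conventions matching the text (ω_n, Γ_n, ρ, and α_n as above); the array layout and function name are our own illustrative choices.

```python
import numpy as np

def hmm_forward(omega, Gamma, rho):
    """Forward pass of a hidden Markov model with M hidden states.

    omega : (N + 1, M) array, omega[n, i] = pi(y_n | z_n = i)
    Gamma : (N + 1, M, M) array of transition matrices; Gamma[0] is unused
            padding so that indices match the text
    rho   : (M,) array of initial hidden state probabilities
    """
    N = omega.shape[0] - 1
    alpha = np.empty_like(omega)
    alpha[0] = omega[0] * rho                          # alpha_0 = omega_0 o rho
    for n in range(N):
        # alpha_{n+1} = omega_{n+1} o (Gamma_{n+1} . alpha_n)
        alpha[n + 1] = omega[n + 1] * (Gamma[n + 1] @ alpha[n])
    marginal = alpha[N].sum()                          # 1' . alpha_N
    return alpha, marginal
```

A finite difference of the returned marginal with respect to any parameter entering ω, Γ, or ρ gives a direct check on the adjoint gradient constructed below.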
$$\begin{aligned}
\frac{\mathrm{d}}{\mathrm{d}\psi} \pi(y_1, \ldots, y_N)
&= \sum_{j=0}^{N-1} \left[ \left[ \prod_{i=N-1}^{j+2} \Gamma^{\dagger}_{i+1} \cdot \Omega^{\dagger}_{i} \right] \cdot \mathbf{1} \right]^{\dagger} \cdot \left[ \frac{\mathrm{d}\Omega_{j+1}}{\mathrm{d}\psi} \cdot \Gamma_{j+2} + \Omega_{j+1} \cdot \frac{\mathrm{d}\Gamma_{j+2}}{\mathrm{d}\psi} \right] \cdot \alpha_{j} \\
&\quad + \left[ \left[ \prod_{i=N-1}^{1} \Gamma^{\dagger}_{i+1} \cdot \Omega^{\dagger}_{i} \right] \cdot \mathbf{1} \right]^{\dagger} \cdot \left[ \frac{\mathrm{d}\Omega_{0}}{\mathrm{d}\psi} \cdot \rho + \Omega_{0} \cdot \frac{\mathrm{d}\rho}{\mathrm{d}\psi} \right] \\
&= \sum_{j=0}^{N-1} \beta^{\dagger}_{j+1} \cdot \left[ \frac{\mathrm{d}\Omega_{j+1}}{\mathrm{d}\psi} \cdot \Gamma_{j+2} + \Omega_{j+1} \cdot \frac{\mathrm{d}\Gamma_{j+2}}{\mathrm{d}\psi} \right] \cdot \alpha_{j}
+ \beta^{\dagger}_{0} \cdot \left[ \frac{\mathrm{d}\Omega_{0}}{\mathrm{d}\psi} \cdot \rho + \Omega_{0} \cdot \frac{\mathrm{d}\rho}{\mathrm{d}\psi} \right],
\end{aligned}$$
A third, novel approach to deriving the marginal likelihood gradient is to interpret the recursion as a forward difference equation and apply the discrete adjoint method. Let u_n = α_n and manipulate the defining recursion relation into a forward difference
$$\Delta_n = \omega_{n+1} \circ (\Gamma_{n+1} \cdot \alpha_n) - \alpha_n,$$
so that
$$\frac{\partial \Delta_{n,i}}{\partial \alpha_{n,j}} = \omega_{n+1,i} \, \Gamma_{n+1,ij} - \delta_{ij}$$
and hence
$$\sum_{i=1}^{K} (\lambda_{n,i} - 1) \frac{\partial \Delta_{n,i}}{\partial \alpha_{n,j}}
= \sum_{i=1}^{K} (\lambda_{n,i} - 1) \, \omega_{n+1,i} \, \Gamma_{n+1,ij} - (\lambda_{n,j} - 1),$$
or in matrix notation,
$$(\lambda_n - \mathbf{1}) \cdot \frac{\partial \Delta_n}{\partial \alpha_n}
= \Gamma^{\dagger}_{n+1} \cdot \left( \omega_{n+1} \circ (\lambda_n - \mathbf{1}) \right) - \lambda_n + \mathbf{1}.$$
The backwards updates then become
$$\begin{aligned}
\lambda_n - \lambda_{n+1} &= (\lambda_{n+1} - \mathbf{1}) \cdot \frac{\partial \Delta_{n+1}}{\partial \alpha_{n+1}} \\
\lambda_n - \lambda_{n+1} &= \Gamma^{\dagger}_{n+2} \cdot \left( \omega_{n+2} \circ (\lambda_{n+1} - \mathbf{1}) \right) - \lambda_{n+1} + \mathbf{1} \\
\lambda_n &= \Gamma^{\dagger}_{n+2} \cdot \left( \omega_{n+2} \circ (\lambda_{n+1} - \mathbf{1}) \right) + \mathbf{1},
\end{aligned}$$
or, in terms of κ_n ≡ 1 − λ_n,
$$\kappa_n = \Gamma^{\dagger}_{n+2} \cdot \left( \omega_{n+2} \circ \kappa_{n+1} \right),$$
which is just the backward states encountered above with a shifted index,
$$\kappa_n = \beta_{n-1}.$$
Lastly we work out the boundary term. Recalling that υ = ω_0 ∘ ρ, the boundary term is
$$\begin{aligned}
\left[ \mathbf{1} + \frac{\partial j_0}{\partial \alpha_0} - \lambda_0 \cdot \frac{\partial \Delta_0}{\partial \alpha_0} - \lambda_0 \right]^{\dagger} \cdot \frac{\partial (\omega_0 \circ \rho)}{\partial \psi}
&= \left[ \mathbf{1} + \mathbf{1} \cdot \frac{\partial \Delta_0}{\partial \alpha_0} - \lambda_0 \cdot \frac{\partial \Delta_0}{\partial \alpha_0} - \lambda_0 \right]^{\dagger} \cdot \frac{\partial (\omega_0 \circ \rho)}{\partial \psi} \\
&= \left[ (\mathbf{1} - \lambda_0) \cdot \frac{\partial \Delta_0}{\partial \alpha_0} + \mathbf{1} - \lambda_0 \right]^{\dagger} \cdot \frac{\partial (\omega_0 \circ \rho)}{\partial \psi} \\
&= \left[ \Gamma^{\dagger}_{1} \cdot \left( \omega_1 \circ (\mathbf{1} - \lambda_0) \right) - (\mathbf{1} - \lambda_0) + \mathbf{1} - \lambda_0 \right]^{\dagger} \cdot \frac{\partial (\omega_0 \circ \rho)}{\partial \psi} \\
&= \left[ \Gamma^{\dagger}_{1} \cdot \left( \omega_1 \circ (\mathbf{1} - \lambda_0) \right) \right]^{\dagger} \cdot \frac{\partial (\omega_0 \circ \rho)}{\partial \psi} \\
&= \left[ \Gamma^{\dagger}_{1} \cdot \left( \omega_1 \circ \kappa_0 \right) \right]^{\dagger} \cdot \frac{\partial (\omega_0 \circ \rho)}{\partial \psi} \\
&= \left[ \Gamma^{\dagger}_{1} \cdot \left( \omega_1 \circ \kappa_0 \right) \right]^{\dagger} \cdot \left[ \omega_0 \circ \frac{\partial \rho}{\partial \psi} + \frac{\partial \omega_0}{\partial \psi} \circ \rho \right].
\end{aligned}$$
Putting all of this together we can recover the derivative of the marginal likelihood by computing
$$\begin{aligned}
\frac{\mathrm{d}}{\mathrm{d}\psi} \pi(y_1, \ldots, y_N)
&= \mathbf{1}^{\dagger} \cdot \frac{\mathrm{d}\alpha_N}{\mathrm{d}\psi} \\
&= \left[ \mathbf{1} + \frac{\partial j_0}{\partial \alpha_0} - \lambda_0 \cdot \frac{\partial \Delta_0}{\partial \alpha_0} - \lambda_0 \right]^{\dagger} \cdot \frac{\partial (\omega_0 \circ \rho)}{\partial \psi}
+ \sum_{n=0}^{N-1} \left[ \frac{\partial j_n}{\partial \psi} - \lambda^{\dagger}_{n} \cdot \frac{\partial \Delta_n}{\partial \psi} \right] \\
&= \left[ \Gamma^{\dagger}_{1} \cdot \left( \omega_1 \circ \kappa_0 \right) \right]^{\dagger} \cdot \left[ \omega_0 \circ \frac{\partial \rho}{\partial \psi} + \frac{\partial \omega_0}{\partial \psi} \circ \rho \right]
+ \sum_{n=0}^{N-1} \kappa^{\dagger}_{n} \cdot \left[ \frac{\partial \omega_{n+1}}{\partial \psi} \circ \left( \Gamma_{n+1} \cdot \alpha_n \right) + \omega_{n+1} \circ \left( \frac{\partial \Gamma_{n+1}}{\partial \psi} \cdot \alpha_n \right) \right].
\end{aligned}$$
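Putting the forward and backward passes together gives a compact implementation of this gradient. The sketch below is our own rendering of the final formula for a single scalar parameter ψ; domega, dGamma, and drho are assumed to hold the elementwise derivatives of ω, Γ, and ρ with respect to ψ, and the conventions follow the hmm_forward sketch above.

```python
import numpy as np

def hmm_marginal_gradient(omega, Gamma, rho, domega, dGamma, drho):
    """Discrete adjoint gradient of the HMM marginal likelihood in psi."""
    N, M = omega.shape[0] - 1, omega.shape[1]

    # Forward pass for the state probabilities alpha_0, ..., alpha_N.
    alpha = np.empty((N + 1, M))
    alpha[0] = omega[0] * rho
    for n in range(N):
        alpha[n + 1] = omega[n + 1] * (Gamma[n + 1] @ alpha[n])

    # Backward pass for the adjoint states: kappa_{N-1} = 1 and
    # kappa_n = Gamma_{n+2}^T . (omega_{n+2} o kappa_{n+1}).
    kappa = np.empty((N, M))
    kappa[N - 1] = np.ones(M)
    for n in range(N - 2, -1, -1):
        kappa[n] = Gamma[n + 2].T @ (omega[n + 2] * kappa[n + 1])

    # Boundary term from the initial condition alpha_0 = omega_0 o rho.
    grad = (Gamma[1].T @ (omega[1] * kappa[0])) @ (omega[0] * drho + domega[0] * rho)

    # Contributions from each forward difference equation.
    for n in range(N):
        grad += kappa[n] @ (domega[n + 1] * (Gamma[n + 1] @ alpha[n])
                            + omega[n + 1] * (dGamma[n + 1] @ alpha[n]))

    return alpha[N].sum(), grad
```

Comparing the returned gradient against a finite difference of the marginal likelihood from hmm_forward provides a quick correctness check; derivatives with respect to a vector of parameters reuse the same α and κ and differ only in the ∂ω/∂ψ, ∂Γ/∂ψ, and ∂ρ/∂ψ contractions.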
4. CONCLUSION
In analogy to the continuous adjoint methods used with ordinary differential equations, the discrete adjoint method defines a procedure to efficiently evaluate the derivatives of functionals over the evolution of discrete sequences. Because this procedure is fully determined by the derivatives of the forward difference equations and the summands defining the functional, it yields an efficient sequential differentiation algorithm that mirrors the structure of the original sequence. The beneficial scaling of this procedure makes the resulting implementations especially useful in practical applications.
We can apply the method to any mathematical model that depends on the parameters through an (implicit) forward difference equation. Once we have made this equation explicit the derivation of a differentiation algorithm is completely mechanical, minimizing the burden of its implementation.
ACKNOWLEDGEMENTS
We thank Bob Carpenter for helpful discussions.
REFERENCES
Bücker, H. M., Corliss, G., Hovland, P., Naumann, U. and Norris, B. (2006). Automatic Differentiation: Applications, Theory, and Implementations. Springer.
Cappé, O., Moulines, E. and Rydén, T. (2005). Inference in Hidden Markov Models. Springer Series in Statistics. Springer, New York.
Griewank, A. and Walther, A. (2008). Evaluating Derivatives, Second ed. Society for Industrial and Applied Mathematics (SIAM), Philadelphia, PA.
Hannemann-Tamás, R., Muñoz, D. A. and Marquardt, W. (2015). Adjoint sensitivity analysis for nonsmooth differential-algebraic equation systems. SIAM J. Sci. Comput. 37 A2380–A2402.
Hindmarsh, A. and Serban, R. (2020). User Documentation for CVODES v5.1.0. Technical Report, Lawrence Livermore National Laboratory.
Margossian, C. C. (2019). A review of automatic differentiation and its efficient implementation. Wiley Interdisciplinary Reviews: Data Mining and Knowledge Discovery 9.