
Automatic Differentiation.

Seminar

Optimization for ML. Faculty of Computer Science. HSE University

Forward mode

Figure 1: Illustration of the forward chain rule used to calculate the derivative of the function v_i with respect to w_k.

• Uses the forward chain rule

• Has complexity d × O(T) operations: computing the gradient with respect to all d inputs requires d forward passes, each of cost O(T)
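A minimal sketch of forward mode in JAX, assuming a toy scalar function f and an arbitrary input x chosen only for illustration: jax.jvp propagates one tangent vector per call, so recovering the full gradient of a function with d inputs takes d calls of cost O(T) each.

import jax
import jax.numpy as jnp

def f(x):
    # toy scalar-valued function, chosen only for illustration
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.arange(1.0, 4.0)                   # d = 3 inputs
basis = jnp.eye(len(x))

# one forward-mode pass (jvp) per input coordinate: d passes in total
grad_forward = jnp.stack([jax.jvp(f, (x,), (e,))[1] for e in basis])
print(grad_forward)                        # coincides with jax.grad(f)(x)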
Reverse mode

Figure 2: Illustration of the reverse chain rule used to calculate the derivative of the function L with respect to the node v_i.

• Uses the backward chain rule

• Stores the information from the forward pass

• Has complexity O(T) operations for the full gradient of a scalar-valued function, regardless of the number of inputs
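A matching reverse-mode sketch with the same toy function (again, f and x are arbitrary illustration choices): jax.vjp performs one forward pass that stores the intermediate values, and a single backward pass then returns the whole gradient at cost O(T), independently of the number of inputs.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * x ** 2)    # same toy function as above

x = jnp.arange(1.0, 4.0)

# forward pass (stores intermediates), then one backward pass for the full gradient
y, f_vjp = jax.vjp(f, x)
(grad_reverse,) = f_vjp(jnp.ones_like(y))  # pull back the scalar cotangent 1.0
print(jnp.allclose(grad_reverse, jax.grad(f)(x)))   # True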
Toy example

Example

f(x_1, x_2) = x_1 x_2 + sin x_1

Let's calculate the derivatives ∂f/∂x_i using forward and reverse modes.

Figure 3: Illustration of the computation graph of f(x_1, x_2).
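For reference, both modes should arrive at ∂f/∂x_1 = x_2 + cos x_1 and ∂f/∂x_2 = x_1. A quick sketch of the check in JAX (the point (x_1, x_2) = (2, 3) is an arbitrary choice):

import jax
import jax.numpy as jnp

def f(x1, x2):
    return x1 * x2 + jnp.sin(x1)

x1, x2 = 2.0, 3.0
df_dx1, df_dx2 = jax.grad(f, argnums=(0, 1))(x1, x2)
print(df_dx1, x2 + jnp.cos(x1))   # both equal x2 + cos(x1)
print(df_dx2, x1)                 # both equal x1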


Automatic Differentiation with JAX

Example 1:

f(X) = tr(A X^{-1} B)

∇f = -X^{-T} A^T B^T X^{-T}

Example 2:

g(x) = 1/3 · ||x||_2^3

∇²g = ||x||_2^{-1} x x^T + ||x||_2 I_n

Let's calculate the gradients and Hessians of f and g in Python.
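One possible sketch of this computation (the matrix sizes, the diagonal shift that keeps X well-conditioned, and the random seeds are arbitrary choices made for this example):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
kA, kB, kX, kx = jax.random.split(key, 4)
n = 4
A = jax.random.normal(kA, (n, n))
B = jax.random.normal(kB, (n, n))
X = jax.random.normal(kX, (n, n)) + n * jnp.eye(n)   # keep X safely invertible

# Example 1: f(X) = tr(A X^{-1} B)
f = lambda X: jnp.trace(A @ jnp.linalg.inv(X) @ B)
Xinv_T = jnp.linalg.inv(X).T
grad_manual = -Xinv_T @ A.T @ B.T @ Xinv_T
print(jnp.allclose(jax.grad(f)(X), grad_manual, atol=1e-4))

# Example 2: g(x) = 1/3 * ||x||_2^3
g = lambda x: jnp.linalg.norm(x) ** 3 / 3
x = jax.random.normal(kx, (n,))
nx = jnp.linalg.norm(x)
hess_manual = jnp.outer(x, x) / nx + nx * jnp.eye(n)
print(jnp.allclose(jax.hessian(g)(x), hess_manual, atol=1e-4))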


Problem 1

Question

Which of the AD modes (forward or reverse) would you choose for the following computational graph of primitive arithmetic operations?

Figure 4: Which mode would you choose for calculating gradients here?


Problem 2

Suppose we have an invertible matrix A and a vector b, and the vector x is the solution of the linear system Ax = b; namely, one can write down the analytical solution x = A^{-1} b. Let L = L(x) be some scalar function (e.g. a loss) computed downstream from x.

Question

Find the derivatives ∂L/∂A and ∂L/∂b.

Figure 5: x could be found as a solution of the linear system.


Gradient propagation through the linear least squares

Suppose we have an invertible matrix A and a vector b, and the vector x is the solution of the linear system Ax = b; namely, one can write down the analytical solution x = A^{-1} b. In this example we will show that computing all the derivatives ∂L/∂A, ∂L/∂b, ∂L/∂x, i.e. the backward pass, costs approximately the same as the forward pass.

It is known that the differential of a function does not depend on the parametrization:

dL = ⟨∂L/∂x, dx⟩ = ⟨∂L/∂A, dA⟩ + ⟨∂L/∂b, db⟩

Given the linear system, we have:

Ax = b
dA x + A dx = db  →  dx = A^{-1}(db - dA x)

Figure 6: x could be found as a solution of the linear system.


Gradient propagation through the linear least squares

The straightforward substitution gives us:

⟨∂L/∂x, A^{-1}(db - dA x)⟩ = ⟨∂L/∂A, dA⟩ + ⟨∂L/∂b, db⟩

⟨-A^{-T} (∂L/∂x) x^T, dA⟩ + ⟨A^{-T} (∂L/∂x), db⟩ = ⟨∂L/∂A, dA⟩ + ⟨∂L/∂b, db⟩

Therefore:

∂L/∂A = -A^{-T} (∂L/∂x) x^T        ∂L/∂b = A^{-T} (∂L/∂x)

It is interesting that the most computationally intensive part here is the matrix inverse (or, equivalently, solving a linear system), which is the same as for the forward pass. Sometimes it is even possible to store and reuse the result of the forward solve, which makes the backward pass even cheaper. A numerical check of these formulas is sketched below.

Figure 7: x could be found as a solution of the linear system.
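As a sketch of the numerical check (the downstream loss L(x) = ||x||², the problem size, and the random seed are arbitrary assumptions made for this example):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
kA, kb = jax.random.split(key)
n = 5
A = jax.random.normal(kA, (n, n)) + n * jnp.eye(n)   # shift keeps A safely invertible here
b = jax.random.normal(kb, (n,))

def L(A, b):
    x = jnp.linalg.solve(A, b)        # forward pass: x = A^{-1} b
    return jnp.sum(x ** 2)            # arbitrary downstream scalar loss

dL_dA_auto, dL_db_auto = jax.grad(L, argnums=(0, 1))(A, b)

# analytic backward pass from the derivation above
x = jnp.linalg.solve(A, b)
dL_dx = 2.0 * x                                       # gradient of the chosen loss w.r.t. x
dL_db_manual = jnp.linalg.solve(A.T, dL_dx)           # A^{-T} dL/dx
dL_dA_manual = -jnp.outer(dL_db_manual, x)            # -A^{-T} (dL/dx) x^T
print(jnp.allclose(dL_dA_auto, dL_dA_manual, atol=1e-4),
      jnp.allclose(dL_db_auto, dL_db_manual, atol=1e-4))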


Problem 3

Suppose we have a rectangular matrix W ∈ R^{m×n}, which has a singular value decomposition:

W = U Σ V^T,   U^T U = I,   V^T V = I,   Σ = diag(σ_1, …, σ_{min(m,n)})

The regularizer R(W) = tr(Σ) in any loss function encourages low-rank solutions.

Figure 8: Computation graph for the singular value regularizer.

Question

Find the derivative ∂R/∂W.


Gradient propagation through the SVD

Suppose we have a rectangular matrix W ∈ R^{m×n}, which has a singular value decomposition:

W = U Σ V^T,   U^T U = I,   V^T V = I,   Σ = diag(σ_1, …, σ_{min(m,n)})

1. Similarly to the previous example:

W = U Σ V^T
dW = dU Σ V^T + U dΣ V^T + U Σ dV^T
U^T dW V = U^T dU Σ V^T V + U^T U dΣ V^T V + U^T U Σ dV^T V
U^T dW V = U^T dU Σ + dΣ + Σ dV^T V


Gradient propagation through the SVD

2. Note that U^T U = I → dU^T U + U^T dU = 0. But also dU^T U = (U^T dU)^T, which implies that the matrix U^T dU is antisymmetric:

(U^T dU)^T + U^T dU = 0  →  diag(U^T dU) = (0, …, 0)

The same logic can be applied to the matrix V, and

diag(dV^T V) = (0, …, 0)

3. At the same time, the matrix dΣ is diagonal, which means (see step 1) that

diag(U^T dW V) = dΣ

Here on both sides we have diagonal matrices.


Gradient propagation through the SVD

4. Now we can decompose the differential of the loss function as a function of Σ (such problems arise in ML, where we need to restrict the matrix rank):

dL = ⟨∂L/∂Σ, dΣ⟩
   = ⟨∂L/∂Σ, diag(U^T dW V)⟩
   = tr((∂L/∂Σ)^T diag(U^T dW V))


Gradient propagation through the SVD

5. Since we have diagonal matrices inside the product, the trace of the diagonal part of the matrix is equal to the trace of the whole matrix:

dL = tr((∂L/∂Σ)^T diag(U^T dW V))
   = tr((∂L/∂Σ)^T U^T dW V)
   = ⟨∂L/∂Σ, U^T dW V⟩
   = ⟨U (∂L/∂Σ) V^T, dW⟩


Gradient propagation through the SVD

6. Finally, using another parametrization of the differential:

⟨U (∂L/∂Σ) V^T, dW⟩ = ⟨∂L/∂W, dW⟩

∂L/∂W = U (∂L/∂Σ) V^T

This nice result allows us to connect the gradients ∂L/∂W and ∂L/∂Σ. In particular, for the regularizer R(W) = tr(Σ) from Problem 3 we have ∂R/∂Σ = I, so ∂R/∂W = U V^T.


Computation experiment with JAX

Let's make sure numerically that we have correctly calculated the derivatives in Problems 2 and 3.
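The check for Problem 2 is sketched above; a possible sketch for Problem 3 follows (the matrix shape and seed are arbitrary; with R(W) = tr(Σ) the derived formula gives ∂R/∂W = U V^T):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(2)
m, n = 6, 4
W = jax.random.normal(key, (m, n))

# R(W) = tr(Sigma), i.e. the sum of singular values
R = lambda W: jnp.sum(jnp.linalg.svd(W, compute_uv=False))

# derived result: dR/dW = U (dR/dSigma) V^T with dR/dSigma = I, hence U V^T
U, S, Vt = jnp.linalg.svd(W, full_matrices=False)
print(jnp.allclose(jax.grad(R)(W), U @ Vt, atol=1e-4))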


Feedforward Architecture

Forward pass

Backward pass

Figure 9: Computation graph for obtaining gradients for a simple feed-forward neural network with n layers. The activations are marked with f; the gradients of the loss with respect to the activations and parameters are marked with b.

Important

The results obtained for the f nodes are needed to compute the b nodes.
Vanilla backpropagation

Figure 10: Computation graph for obtaining gradients for a simple feed-forward neural network with n layers. The purple color indicates nodes that are stored in memory.

• All activations f are kept in memory after the forward pass.

• Optimal in terms of computation: each node is computed only once.

• High memory usage: the memory usage grows linearly with the number of layers in the neural network.
Memory poor backpropagation

Figure 11: Computation graph for obtaining gradients for a simple feed-forward neural network with n layers. The purple color indicates nodes that are stored in memory.

• Each activation f is recalculated as needed.

• Optimal in terms of memory: there is no need to store all activations in memory.

• Computationally inefficient. The number of node evaluations scales as n², whereas vanilla backpropagation scales as n: each of the n nodes is recomputed on the order of n times.
Checkpointed backpropagation

Figure 12: Computation graph for obtaining gradients for a simple feed-forward neural network with n layers. The purple color indicates nodes that are stored in memory.

• Trade-off between the vanilla and memory-poor approaches. The strategy is to mark a subset of the neural net activations as checkpoint nodes, which will be stored in memory.

• Faster recalculation of activations f. When computing a b node during backprop, we only need to recompute the nodes between that b node and the last checkpoint preceding it.

• Memory consumption depends on the number of checkpoints; this is more memory-efficient than the vanilla approach. A minimal JAX sketch of checkpointing is given after this list.
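A minimal JAX sketch of this idea, assuming a toy stack of dense layers (width, depth, and initialization below are arbitrary): wrapping a layer in jax.checkpoint (also exposed as jax.remat) tells reverse mode not to keep that layer's internal activations and to recompute them from the checkpointed input during the backward pass.

import jax
import jax.numpy as jnp

def layer(W, x):
    # one toy feed-forward layer
    return jnp.tanh(W @ x)

def net(Ws, x):
    for W in Ws:
        # internals of each layer are recomputed during the backward pass
        x = jax.checkpoint(layer)(W, x)
    return jnp.sum(x ** 2)

key = jax.random.PRNGKey(3)
width, depth = 16, 8
Ws = [jax.random.normal(k, (width, width)) / jnp.sqrt(width)
      for k in jax.random.split(key, depth)]
x = jnp.ones(width)

grads = jax.grad(net)(Ws, x)   # same gradients as without checkpointing, lower peak memory
print(len(grads), grads[0].shape)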
Gradient checkpointing visualization

The animated visualization of the above approaches (external link)

An example of using gradient checkpointing (external link)
Hutchinson Trace Estimation [1]

This example illustrates the estimation of the Hessian trace of a neural network using Hutchinson's method, which is an algorithm to obtain such an estimate from matrix-vector products.

Let X ∈ R^{d×d} and let v ∈ R^d be a random vector such that E[v v^T] = I. Then

Tr(X) = E[v^T X v] ≈ (1/V) Σ_{i=1}^{V} v_i^T X v_i
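A possible sketch in JAX, assuming a toy quadratic loss whose Hessian (and hence its trace) is known exactly; in practice, loss would be a neural network loss. Standard normal probe vectors satisfy E[v v^T] = I, and Hessian-vector products are obtained by forward-over-reverse differentiation, so the Hessian itself is never formed.

import jax
import jax.numpy as jnp

d = 50
key = jax.random.PRNGKey(4)
M = jax.random.normal(key, (d, d)) / jnp.sqrt(d)
H_true = M @ M.T                               # ground-truth Hessian of the toy loss

def loss(w):
    # toy quadratic loss: its Hessian is exactly H_true
    return 0.5 * w @ H_true @ w

def hvp(w, v):
    # Hessian-vector product without materializing the Hessian
    return jax.jvp(jax.grad(loss), (w,), (v,))[1]

def hutchinson_trace(w, key, num_samples=200):
    # average of v^T H v over standard normal probes (E[v v^T] = I)
    vs = jax.random.normal(key, (num_samples, d))
    return jnp.mean(jax.vmap(lambda v: v @ hvp(w, v))(vs))

w = jnp.zeros(d)
print(hutchinson_trace(w, jax.random.PRNGKey(5)), jnp.trace(H_true))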

An example of using Hutchinson Trace Estimation (external link)

Figure 13: Multiple runs of the Hutchinson trace estimate, initialized at different random seeds.

[1] M.F. Hutchinson. A stochastic estimator of the trace of the influence matrix for Laplacian smoothing splines. 1990.
