
Solving Optimization Problems with JAX - The Startup - Medium

Mazeyar Moeini

Joseph-Louis Lagrange & Isaac Newton, JAX logo by Google

1 Introduction

What is JAX? As described by the main JAX webpage, JAX is Autograd and XLA, brought together for high-performance machine learning
research. JAX essentially augments the numpy library to create a novel library with automatic differentiation (Autograd), vectorized mapping (vmap), and Just-In-Time compilation (JIT), all compiled with Accelerated Linear Algebra (XLA), Tensor Processing Unit (TPU) support, and much more. With all of
these features, problems that depend on linear algebra and matrix methods can be solved more efficiently. The purpose of this article is to show
that these features can indeed be used to solve a range of simple to complex optimization problems with matrix methods, and to provide an
intuitive understanding of the mathematics and implementation behind the code.

First, we need to import the JAX libraries along with nanargmin/nanargmax from numpy, as they are not implemented in JAX yet. If you
are using Google Colab, no installation of JAX is required, as JAX is open sourced and maintained by Google.
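A minimal set of imports, assuming the snippets below use jacfwd for Jacobians and numpy only for nanargmin/nanargmax:

```python
import jax.numpy as jnp
from jax import grad, jacfwd, vmap
from numpy import nanargmin, nanargmax
```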
2 Grad, Jacobians and Vmap

grad is best used for taking the automatic derivative of a function: it creates a new function that evaluates the gradient of the given function. If we call grad(grad(f)), we get the second derivative.
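A minimal sketch, assuming a simple cubic just for illustration:

```python
def f(x):
    return x ** 3 + 2.0 * x

df = grad(f)         # function that evaluates f'(x) = 3x^2 + 2
d2f = grad(grad(f))  # nesting grad gives the second derivative, 6x

print(df(2.0))   # 14.0
print(d2f(2.0))  # 12.0
```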

The Jacobian is best used for taking the automatic derivative of a function with a vector input. We can see that it returns the expected vector for a
circle function.
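A sketch of this, assuming the circle function x0² + x1²:

```python
def circle(X):
    return X[0] ** 2 + X[1] ** 2

J = jacfwd(circle)               # jacfwd returns a function
print(J(jnp.array([1.0, 2.0])))  # expected gradient vector: [2., 4.]
```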

Even more interesting is how we can compute the Hessian of a function by computing the Jacobian twice; this is what makes JAX so powerful!
We see that the function hessian takes in a function and returns a function as well.
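A minimal sketch of such a hessian helper, built by composing the Jacobian with itself:

```python
def hessian(fun):
    # the Jacobian of the Jacobian: a function that returns the Hessian matrix
    return jacfwd(jacfwd(fun))

H = hessian(circle)
print(H(jnp.array([1.0, 2.0])))  # [[2., 0.], [0., 2.]]
```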

It should be noted that the gradients are computed with automatic differentiation, which is much more accurate and efficient compared to finite
differences.

3 Single Variable Optimization

3.1 Gradient Descent

Let’s imagine that we have the following optimization problem from UC Davis: a rectangular piece of paper is 12 inches long and six inches
wide. The lower right-hand corner is folded over so as to reach the leftmost edge of the paper. Find the minimum length L of the resulting crease.
Image by University of California Davis

After doing some trigonometry, we can find the length of the crease as a function of the fold distance x to be L(x) = sqrt(2x³ / (2x − 6)).

To find the minimum we would have to check all the critical points where L’ = 0. Although this is a relatively simple optimization
problem, it would still lead to a messy derivative requiring the chain rule and quotient rule, and since such problems only become more
complex, it is wise to use numerical methods to solve them. Jumping over to JAX, we can define the function in Python.
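A direct transcription of the crease-length formula above (only valid for 3 < x ≤ 6):

```python
def L(x):
    # crease length as a function of the fold distance x
    return jnp.sqrt(2.0 * x ** 3 / (2.0 * x - 6.0))
```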

Then, using grad(L) we can find the derivative of L and minimize this using stepwise gradient descent.
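A minimal sketch of this loop, assuming a learning rate of 0.01, 500 epochs, and starting points spread over the feasible interval (3, 6]:

```python
dL = grad(L)

def minGD(x, lr=0.01):
    # one gradient-descent step on L
    return x - lr * dL(x)

domain = jnp.linspace(3.2, 6.0, num=50)   # many starting points

for epoch in range(500):
    domain = vmap(minGD)(domain)          # update every point in parallel

values = vmap(L)(domain)                  # evaluate the objective over the domain
idx = nanargmin(values)
print(domain[idx], values[idx])           # argmin ~ 4.5, minimum ~ 7.794
```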

We can see how simple things become with JAX; the actual optimization happens in just a handful of lines of code! Notice how vmap is first used in each
epoch to map the minGD update over the whole domain, and then used again to map the objective function L over the domain to find the minimum
value and the argmin.

The numeric answer has a 0.001851% error relative to the exact answer, 9*sqrt(3)/2; the error is acceptable given that the true value is an
irrational number to begin with.

3.2 Newton’s Method

The same problem can be solved using Newton’s Method. Usually, Newton’s Method is used for finding a root of a function, i.e., solving something like
f(x) = x² − 2 = 0, using the iteration x_{n+1} = x_n − f(x_n)/f’(x_n).
This can easily be used for optimization, since we instead search for f’(x) = 0, giving the update x_{n+1} = x_n − L’(x_n)/L’’(x_n).

Newton’s Method for optimization can easily be implemented with JAX.
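A minimal sketch, assuming an illustrative starting point of x = 5.0 and 20 iterations:

```python
dL = grad(L)
d2L = grad(grad(L))   # L'' via nested automatic differentiation

x = 5.0
for _ in range(20):
    x = x - dL(x) / d2L(x)   # Newton update: x_{n+1} = x_n - L'(x_n)/L''(x_n)

print(x, L(x))   # x ~ 4.5, L(x) ~ 7.794
```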

Notice how easily L’’ is calculated by simply nesting grad.

Newton’s Method has the added advantage of quadratic convergence: the error is roughly squared at each step.

4 Multivariable Optimization

4.1 The Jacobian

In multivariable problems we define functions of a vector, f(X) with X = [x0, x1, x2, …, xn]. When the number of variables increases, we can no longer use the ordinary single-variable derivative; we need the Jacobian, also written ∇f.
The Jacobian is the derivative of a multivariable function, so it captures how each variable affects the function. Since these are the first
derivatives, we can again use them to optimize a multivariable function.

Now, implementing this with JAX is just as simple as the single-variable case: we pick a multivariable objective function f(X) and optimize it.

Notice again how easily JAX allows us to calculate the Jacobian.

Similar to last time, once we have the optimization function we can run it through a loop.
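A minimal sketch, assuming a simple quadratic objective (with its minimum at (1, −2)) purely for illustration:

```python
def f(X):
    # stand-in multivariable objective; substitute your own function here
    return (X[0] - 1.0) ** 2 + (X[1] + 2.0) ** 2

Jf = jacfwd(f)   # the Jacobian (gradient) of f, computed automatically

def stepGD(X, lr=0.1):
    # one gradient-descent step using the Jacobian
    return X - lr * Jf(X)

X = jnp.array([5.0, 5.0])
for epoch in range(200):
    X = stepGD(X)
```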

Then we check for the results.
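For example, checking the minimizer, the minimum value, and that the gradient has vanished:

```python
print(X)       # ~ [1., -2.]
print(f(X))    # ~ 0.0
print(Jf(X))   # gradient should be ~ [0., 0.] at the minimum
```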

4.2 The Hessian

In multivariable problems we define functions f(X) with X = [x0, x1, x2, …, xn]. Previously we defined the Jacobian (∇f). The Hessian is just
∇(∇f), also written ∇²f: it requires differentiating each component of the Jacobian with respect to every variable, thus increasing the dimension.
Using the Hessian in optimization is really similar to Newton’s Method; in fact, it is the multivariable analogue, with the update X_{n+1} = X_n − H(X_n)^(-1) ∇f(X_n).

We can see that where it’s not possible to divide by a matrix, we instead multiply by its inverse. There is a mathematical justification for this using the
quadratic term of a Taylor expansion, but it is too lengthy to cover here. Again, using JAX’s autodiff it is incredibly easy to calculate the
Hessian.
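A minimal sketch with the same stand-in objective as above (for this quadratic, a single Newton step lands exactly on the minimum):

```python
Hf = jacfwd(jacfwd(f))   # the Hessian of f

def newton_step(X):
    # multiply by the inverse Hessian instead of "dividing" by a matrix
    return X - jnp.linalg.inv(Hf(X)) @ Jf(X)

X = jnp.array([5.0, 5.0])
for _ in range(5):
    X = newton_step(X)

print(X, f(X))   # ~ [1., -2.], ~ 0.0
```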

5 Multivariable Constrained Optimization

Multivariable constrained optimization uses the same techniques as before but requires a different way of framing the problem. For constrained
optimization we will use Lagrange multipliers. The classic Lagrange condition requires solving ∇f = λ∇g. However, computers have no way
of solving this symbolically, so we instead rewrite the condition as ∇f − λ∇g = 0, which turns the problem into an unconstrained one with the Lagrangian L(X, λ) = f(X) − λ g(X).
Just like the other optimization problems, we have a function whose gradient needs to be driven to zero, ∇L = 0. Note that solving ∇L = 0 is no different from
solving a system of nonlinear equations, so our final iterative equation looks similar to Newton’s Method: l_{n+1} = l_n − H_L(l_n)^(-1) ∇L(l_n).

The reason the Hessian is involved again is that minimizing L and solving ∇L = 0 are the same statement. Also, when using
Lagrange multipliers we introduce a new variable λ; in the code, L(X) takes in X where X = [x0, x1, λ].

Let's say we have the objective function f(X) and the constraint g(X); since X = [x0, x1, λ], in the code λ is l[2].
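A minimal sketch, assuming for illustration f(X) = 4*x0²*x1 and g(X) = x0² + x1² − 3 (a choice consistent with the results quoted below) and an illustrative starting guess:

```python
def f(l):
    return 4.0 * l[0] ** 2 * l[1]

def g(l):
    return l[0] ** 2 + l[1] ** 2 - 3.0

def lagrange(l):
    # l = [x0, x1, lambda]; unconstrained Lagrangian f - lambda*g
    return f(l) - l[2] * g(l)

JL = jacfwd(lagrange)
HL = jacfwd(jacfwd(lagrange))

l = jnp.array([1.5, -1.2, -3.0])
for _ in range(20):
    l = l - jnp.linalg.inv(HL(l)) @ JL(l)   # Newton step on grad(L) = 0

print(l, f(l))   # ~ [1.414, -1.0, -4.0] and f ~ -8.0
```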

The correct minimum is −8, the argmin should be (sqrt(2), −1), and since we included λ in our calculation we also find the Lagrange multiplier,
−4.0.

6 Three Variable Multivariable Constrained Optimization

Problems in real life usually have more than two variables to optimize, and the optimization hyperparameters need to be fine-tuned. As the
complexity of optimization problems increases, other methods should be considered, but for now we can use the model from the previous section
and simply increase the number of variables. Luckily, JAX adjusts for this automatically; we just need to adjust the Lagrangian function in the code.

Let’s attempt to solve a problem with real-life applications, found in Paul’s Online Notes: find the dimensions of the box with the largest
volume if the total surface area is 64 cm². Our objective function is f(x) = x0*x1*x2 and the constraint is g(x) = 2*x0*x1 + 2*x1*x2 + 2*x0*x2 − 64.
First we have to define the functions; then the only thing we have to change is the index of the list feeding into Lagrange.
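A minimal sketch of the definitions, following the same pattern as before:

```python
def f(l):
    # volume of the box
    return l[0] * l[1] * l[2]

def g(l):
    # surface-area constraint: 2*x0*x1 + 2*x1*x2 + 2*x0*x2 = 64
    return 2.0 * l[0] * l[1] + 2.0 * l[1] * l[2] + 2.0 * l[0] * l[2] - 64.0

def lagrange(l):
    # l = [x0, x1, x2, lambda]; the multiplier is now l[3]
    return f(l) - l[3] * g(l)

JL = jacfwd(lagrange)
HL = jacfwd(jacfwd(lagrange))
```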

This part of the code stays exactly the same, except we add a learning rate of 0.1 to gain greater accuracy; we might also have to increase the total
number of epochs.
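A minimal sketch of the damped loop, assuming an illustrative starting guess and epoch count:

```python
lr = 0.1
l = jnp.array([2.0, 2.0, 2.0, 0.5])   # initial guess for [x0, x1, x2, lambda]

for epoch in range(1000):
    l = l - lr * jnp.linalg.inv(HL(l)) @ JL(l)   # damped Newton step

print(l)      # each side ~ 3.265985
print(f(l))   # volume ~ 34.837187
```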

The exact answer is sqrt(32/3)³ ≈ 34.837187, with each side of length sqrt(32/3) ≈ 3.265985; the calculation is almost perfect, as the
errors are negligible in real life. It’s important to note that without the learning rate the optimization is unlikely to converge, and accuracy was increased by
doubling the number of epochs. Hopefully it is now obvious how more variables can be included in the optimization model.

7 Multivariable Multi-Constrained Optimization

In the final part of this tutorial we will look at one of the most advanced types of optimization problems: multivariable, multi-constrained
optimization problems. Some of the problems at the beginning are admittedly better solved by hand, but as complexity increases, other
numerical methods might be needed. Gradient descent, no matter how many epochs and hyperparameters, can never 100% guarantee the best
result, but it is always better than a random guess.

Let’s start by trying to maximize the objective function f(x0, x1) with the constraints g(x0, x1) and h(x0, x1).

More problems like this can be found at Duke University. The general form of the Lagrangian condition can be written such that the Jacobian of
the objective minus each constraint’s Jacobian multiplied by its respective lambda is equal to zero: ∇f − λ1∇g − λ2∇h = 0.
We see now that we have to define all three functions. Note that l[2] = λ1 and l[3] = λ2.
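A minimal sketch of the mechanics, assuming a simple quadratic objective with two linear constraints purely for illustration (for this choice the solution is x = (2, 2) with λ1 = 3 and λ2 = 1):

```python
def f(l):
    # stand-in objective
    return l[0] ** 2 + l[1]

def g(l):
    # first stand-in constraint: x0 + x1 = 4
    return l[0] + l[1] - 4.0

def h(l):
    # second stand-in constraint: x0 - 2*x1 + 2 = 0
    return l[0] - 2.0 * l[1] + 2.0

def lagrange(l):
    # l = [x0, x1, lambda1, lambda2]
    return f(l) - l[2] * g(l) - l[3] * h(l)

JL = jacfwd(lagrange)
HL = jacfwd(jacfwd(lagrange))

l = jnp.array([0.0, 0.0, 1.0, 1.0])
for _ in range(5):
    l = l - jnp.linalg.inv(HL(l)) @ JL(l)   # Newton step on grad(L) = 0

print(l)   # ~ [2., 2., 3., 1.]
```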

Once again these are the expected values.

Hopefully by now you have found a new interest in optimization problems and, more importantly, seen how the features JAX offers make solving
such problems easier. Furthermore, machine learning libraries offer fast and reliable tools for problem-solving that can be used outside the
machine learning domain.
