
Google JAX

Google JAX is a machine learning framework for transforming numerical functions.[1][2][3] It is described as bringing together a modified version of autograd (https://github.com/HIPS/autograd) (automatic differentiation of a function to obtain its gradient function) and TensorFlow's XLA (https://www.tensorflow.org/xla) (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[4][5]

JAX (infobox)
Developer(s): Google
Preview release: v0.3.13 / 16 May 2022
Repository: github.com/google/jax (https://github.com/google/jax)
Written in: Python, C++
Operating system: Linux, macOS, Windows
Platform: Python, NumPy
Size: 9.0 MB
Type: Machine learning
License: Apache 2.0
Website: jax.readthedocs.io/en/latest/ (https://jax.readthedocs.io/en/latest/)

The primary functions of JAX are:[1]

1. grad: automatic differentiation
2. jit: compilation
3. vmap: auto-vectorization
4. pmap: SPMD programming
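
These four transformations can be composed with one another. The short sketch below (an illustration added here, not taken from the article) compiles the derivative of a function and then vectorizes it:

# Minimal sketch (illustrative): JAX transformations compose.
from jax import grad, jit, vmap
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

df = jit(grad(f))          # compiled derivative of f
batched_df = vmap(df)      # the same derivative, mapped over a batch

print(df(1.0))                       # derivative of f at x = 1
print(batched_df(jnp.arange(3.0)))   # derivatives at x = 0, 1, 2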
grad

The below code demonstrates the grad function's automatic differentiation.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)

The final line should output:

0.19661194
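
As a quick check (added here, not part of the original article), the closed-form derivative of the logistic function gives the same number:

\sigma(x) = \frac{e^{x}}{e^{x} + 1}, \qquad \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr), \qquad \sigma'(1) \approx 0.7311 \times 0.2689 \approx 0.19661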
jit

The below code demonstrates the jit function's optimization through fusion.

1 # imports
2 from jax import jit
3 import jax.numpy as jnp
4
5 # define the cube function
6 def cube(x):
7     return x * x * x
8
9 # generate data
10 x = jnp.ones((10000, 10000))
11
12 # create the jit version of the cube function
13 jit_cube = jit(cube)
14
15 # apply the cube and jit_cube functions to the same data for speed comparison
16 cube(x)
17 jit_cube(x)

The computation time for jit_cube (line no. 17) should be noticeably shorter than that for cube (line no. 16).
Increasing the amount of work on line no. 7 (for example, multiplying by further factors of x) will increase the difference.
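
Because JAX compiles on first call and dispatches work asynchronously, a fair timing should warm up the jitted function first and block on the results; a minimal sketch (assuming cube, jit_cube and x as defined above):

# Timing sketch (assumes cube, jit_cube and x from the example above).
# block_until_ready() waits for the asynchronous computation to finish;
# the warm-up call keeps compilation time out of the measurement.
import timeit

jit_cube(x).block_until_ready()  # warm-up: triggers compilation

t_cube = timeit.timeit(lambda: cube(x).block_until_ready(), number=10)
t_jit = timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=10)
print(t_cube, t_jit)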

vmap
The below code demonstrates the vmap function's vectorization.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp
import numpy as np

# define function (a method of a class that provides _net_grads,
# _net_params and _flatten_batch)
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

The GIF on the right of this section illustrates the notion of vectorized addition.
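
For a self-contained illustration of the same idea (a minimal sketch added here, not taken from the article), vmap can lift a function written for single values so that it adds whole arrays elementwise:

# Minimal sketch: vmap maps a per-element function over a leading batch axis.
from jax import vmap
import jax.numpy as jnp

def add(a, b):
    return a + b

xs = jnp.arange(4.0)   # [0., 1., 2., 3.]
ys = jnp.ones(4)       # [1., 1., 1., 1.]

batched_add = vmap(add)        # maps add over axis 0 of both arguments
print(batched_add(xs, ys))     # [1. 2. 3. 4.]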

pmap
The below code demonstrates the pmap function's parallelization
for matrix multiplication.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
The final line should print the values:

[1.1566595 1.1805978]
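
A practical note (added here, not part of the original article): pmap maps the computation across separate XLA devices, so the example above needs at least two devices (for example two GPUs or TPU cores). A quick check before running it:

# Sketch: verify that enough devices are available before using pmap.
import jax

print(jax.devices())  # the devices JAX can see (CPUs/GPUs/TPU cores)
assert jax.device_count() >= 2, "the example maps over 2 matrices, so it needs at least 2 devices"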

Libraries using JAX

Several Python libraries use JAX as a backend, including:

Flax, a high-level neural network library initially developed by Google Brain.[6]
Haiku, an object-oriented library for neural networks developed by DeepMind.[7]
Equinox, a library that revolves around the idea of representing parameterised functions
(including neural networks) as PyTrees. It was created by Patrick Kidger.[8]
Optax, a library for gradient processing and optimisation developed by DeepMind.[9]
RLax, a library for developing reinforcement learning agents developed by DeepMind.[10]

See also
NumPy
TensorFlow
PyTorch
CUDA
Automatic differentiation
Just-in-time compilation
Vectorization
Automatic parallelization

External links
Documentation: jax.readthedocs.io (https://jax.readthedocs.io/)
Colab (Jupyter/iPython) Quickstart Guide: colab.research.google.com/github/google/jax/blob
/main/docs/notebooks/quickstart.ipynb (https://colab.research.google.com/github/google/jax/
blob/main/docs/notebooks/quickstart.ipynb)
TensorFlow's XLA: www.tensorflow.org/xla (https://www.tensorflow.org/xla) (Accelerated
Linear Algebra)
Intro to JAX: Accelerating Machine Learning research (https://www.youtube.com/watch?v=W
dTeDXsOSj4) on YouTube
Original paper: mlsys.org/Conferences/doc/2018/146.pdf (https://mlsys.org/Conferences/do
c/2018/146.pdf)

References
1. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris;
MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne,
Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA" (https://web.archive.org/web/202
20618205214/https://github.com/google/jax), Astrophysics Source Code Library, Google,
Bibcode:2021ascl.soft11002B (https://ui.adsabs.harvard.edu/abs/2021ascl.soft11002B),
archived from the original (https://github.com/google/jax) on 2022-06-18, retrieved
2022-06-18
2. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine
learning programs via high-level tracing" (https://mlsys.org/Conferences/doc/2018/146.pdf)
(PDF). MLsys: 1–3. Archived (https://web.archive.org/web/20220621153349/https://mlsys.or
g/Conferences/doc/2018/146.pdf) (PDF) from the original on 2022-06-21.
3. "Using JAX to accelerate our research" (https://www.deepmind.com/blog/using-jax-to-accele
rate-our-research). www.deepmind.com. Archived (https://web.archive.org/web/2022061820
5746/https://www.deepmind.com/blog/using-jax-to-accelerate-our-research) from the original
on 2022-06-18. Retrieved 2022-06-18.
4. Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its
last big push for dominance got overshadowed by Meta" (https://web.archive.org/web/20220
621143905/https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-
meta-ai-2022-6). Business Insider. Archived from the original (https://www.businessinsider.c
om/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6) on 2022-06-21. Retrieved
2022-06-21.
5. "Why is Google's JAX so popular?" (https://analyticsindiamag.com/why-is-googles-jax-so-po
pular/). Analytics India Magazine. 2022-04-25. Archived (https://web.archive.org/web/202206
18210503/https://analyticsindiamag.com/why-is-googles-jax-so-popular/) from the original
on 2022-06-18. Retrieved 2022-06-18.
6. Flax: A neural network library and ecosystem for JAX designed for flexibility (https://github.co
m/google/flax), Google, 2022-07-29, retrieved 2022-07-29
7. Haiku: Sonnet for JAX (https://github.com/deepmind/dm-haiku), DeepMind, 2022-07-29,
retrieved 2022-07-29
8. Kidger, Patrick (2022-07-29), Equinox (https://github.com/patrick-kidger/equinox), retrieved
2022-07-29
9. Optax (https://github.com/deepmind/optax), DeepMind, 2022-07-28, retrieved 2022-07-29
10. RLax (https://github.com/deepmind/rlax), DeepMind, 2022-07-29, retrieved 2022-07-29

