Repository: github.com/google/jax (https://github.com/google/jax)
Written in: Python, C++
Operating system: Linux, macOS, Windows
Platform: Python, NumPy
Size: 9.0 MB
Type: Machine learning
License: Apache 2.0

grad
The code below demonstrates the grad function's automatic differentiation.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
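As a check, the analytic derivative of the logistic function is logistic(x) · (1 − logistic(x)); at x = 1 this is approximately 0.7311 × 0.2689 ≈ 0.1966, which is the value the printed gradient should approximate.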
jit
The code below demonstrates the jit function's optimization through fusion.
1 # imports
2 from jax import jit
3 import jax.numpy as jnp
4
5 # define the cube function
6 def cube(x):
7     return x * x * x
8
9 # generate data
10 x = jnp.ones((10000, 10000))
11
12 # create the jit version of the cube function
13 jit_cube = jit(cube)
14
15 # apply the cube and jit_cube functions to the same data for speed comparison
16 cube(x)
17 jit_cube(x)
The computation time for jit_cube (line no. 17) should be noticeably shorter than that for cube (line no. 16). Increasing the amount of work done on line no. 7, for example by multiplying by x additional times, will increase the difference.
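To make the comparison concrete, both functions can be timed explicitly. The sketch below is an illustrative addition rather than part of the article's listing: it reuses the cube example, calls block_until_ready() so that JAX's asynchronous dispatch does not stop the timer before the computation finishes, and runs jit_cube once beforehand so that compilation time is excluded from the measurement.

# imports
import timeit
from jax import jit
import jax.numpy as jnp

# the same cube function and data as in the listing above
def cube(x):
    return x * x * x

x = jnp.ones((10000, 10000))
jit_cube = jit(cube)

# warm-up call: triggers compilation so it is not counted in the timing
jit_cube(x).block_until_ready()

# block_until_ready() waits for the asynchronously dispatched result
t_cube = timeit.timeit(lambda: cube(x).block_until_ready(), number=10)
t_jit = timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=10)
print(t_cube, t_jit)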
vmap
The code below demonstrates the vmap function's vectorization.
# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# define a method (of a model class) that computes per-example gradients
# for a batch of inputs; self._net_grads, self._net_params and
# self._flatten_batch are attributes of that class
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = jnp.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
An animation accompanying this section illustrates the notion of vectorized addition.
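Because the method above belongs to a larger class and is not runnable on its own, the following self-contained sketch illustrates the same idea of vectorized addition; the function name add and the sample arrays are illustrative assumptions.

# imports
from jax import vmap
import jax.numpy as jnp

# scalar addition of two numbers
def add(a, b):
    return a + b

# vmap maps add over the leading axis of both arguments
batched_add = vmap(add)

a = jnp.arange(5.0)       # [0. 1. 2. 3. 4.]
b = jnp.ones(5)           # [1. 1. 1. 1. 1.]
print(batched_add(a, b))  # [1. 2. 3. 4. 5.]

By default vmap maps over axis 0 of every argument; its in_axes argument selects other axes or marks arguments that should be broadcast rather than mapped.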
pmap
The code below demonstrates the pmap function's parallelization for matrix multiplication.
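A minimal sketch of such a computation, assuming two accelerator devices are available (the matrix shapes, random seed, and variable names are illustrative assumptions rather than the article's original listing), is:

# imports
from jax import pmap, random
import jax.numpy as jnp

# generate one random matrix per device
keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# multiply each matrix by its transpose on its own device, in parallel
products = pmap(lambda m: jnp.dot(m, m.T))(matrices)

# compute the mean of each product, one value per device
means = pmap(jnp.mean)(products)
print(means)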
Output: [1.1566595 1.1805978]
Libraries using JAX
Flax, a high-level neural network library initially developed by Google Brain.[6]
Haiku, an object-oriented library for neural networks developed by DeepMind.[7]
Equinox, a library that revolves around the idea of representing parameterised functions
(including neural networks) as PyTrees. It was created by Patrick Kidger.[8]
Optax, a library for gradient processing and optimisation developed by DeepMind.[9]
RLax, a library for developing reinforcement learning agents developed by DeepMind.[10]
See also
NumPy
TensorFlow
PyTorch
CUDA
Automatic differentiation
Just-in-time compilation
Vectorization
Automatic parallelization
External links
Documentation: jax.readthedocs.io (https://jax.readthedocs.io/)
Colab (Jupyter/iPython) Quickstart Guide: colab.research.google.com/github/google/jax/blob
/main/docs/notebooks/quickstart.ipynb (https://colab.research.google.com/github/google/jax/
blob/main/docs/notebooks/quickstart.ipynb)
TensorFlow's XLA: www.tensorflow.org/xla (https://www.tensorflow.org/xla) (Accelerated
Linear Algebra)
Intro to JAX: Accelerating Machine Learning research (https://www.youtube.com/watch?v=W
dTeDXsOSj4) on YouTube
Original paper: mlsys.org/Conferences/doc/2018/146.pdf (https://mlsys.org/Conferences/do
c/2018/146.pdf)
References
1. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris;
MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne,
Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA" (https://web.archive.org/web/202
20618205214/https://github.com/google/jax), Astrophysics Source Code Library, Google,
Bibcode:2021ascl.soft11002B (https://ui.adsabs.harvard.edu/abs/2021ascl.soft11002B),
archived from the original (https://github.com/google/jax) on 2022-06-18, retrieved
2022-06-18
2. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine
learning programs via high-level tracing" (https://mlsys.org/Conferences/doc/2018/146.pdf)
(PDF). MLsys: 1–3. Archived (https://web.archive.org/web/20220621153349/https://mlsys.or
g/Conferences/doc/2018/146.pdf) (PDF) from the original on 2022-06-21.
3. "Using JAX to accelerate our research" (https://www.deepmind.com/blog/using-jax-to-accele
rate-our-research). www.deepmind.com. Archived (https://web.archive.org/web/2022061820
5746/https://www.deepmind.com/blog/using-jax-to-accelerate-our-research) from the original
on 2022-06-18. Retrieved 2022-06-18.
4. Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its
last big push for dominance got overshadowed by Meta" (https://web.archive.org/web/20220
621143905/https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-
meta-ai-2022-6). Business Insider. Archived from the original (https://www.businessinsider.c
om/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6) on 2022-06-21. Retrieved
2022-06-21.
5. "Why is Google's JAX so popular?" (https://analyticsindiamag.com/why-is-googles-jax-so-po
pular/). Analytics India Magazine. 2022-04-25. Archived (https://web.archive.org/web/202206
18210503/https://analyticsindiamag.com/why-is-googles-jax-so-popular/) from the original
on 2022-06-18. Retrieved 2022-06-18.
6. Flax: A neural network library and ecosystem for JAX designed for flexibility (https://github.co
m/google/flax), Google, 2022-07-29, retrieved 2022-07-29
7. Haiku: Sonnet for JAX (https://github.com/deepmind/dm-haiku), DeepMind, 2022-07-29,
retrieved 2022-07-29
8. Kidger, Patrick (2022-07-29), Equinox (https://github.com/patrick-kidger/equinox), retrieved
2022-07-29
9. Optax (https://github.com/deepmind/optax), DeepMind, 2022-07-28, retrieved 2022-07-29
10. RLax (https://github.com/deepmind/rlax), DeepMind, 2022-07-29, retrieved 2022-07-29