Google JAX Cookbook: Perform machine learning and numerical computing with combined capabilities of TensorFlow and NumPy
By Zephyr Quent
About this ebook
This is the practical, solution-oriented book for data scientists, machine learning engineers, and AI engineers who want to make the most of Google JAX for efficient and advanced machine learning. It covers essential tasks, troubleshooting scenarios, and optimization techniques to address common challenges encountered while working with JAX.
Google JAX Cookbook
Perform machine learning and numerical computing with combined capabilities of TensorFlow and NumPy
Zephyr Quent
Preface
This is the practical, solution-oriented book data scientists, machine learning engineers, and AI engineers need to make the most of Google JAX for efficient and advanced machine learning. It covers essential tasks, troubleshooting scenarios, and optimization techniques to address common challenges encountered while working with JAX across machine learning and numerical computing projects.
The book starts with the move from NumPy to JAX. It introduces the best ways to speed up computations, handle data types, generate random numbers, and perform in-place operations. It then shows you how to use profiling techniques to monitor computation time and device memory, helping you to optimize training and performance. The debugging section provides clear and effective strategies for resolving common runtime issues, including shape mismatches, NaNs, and control flow errors.
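The core moves described above can be sketched in a few lines. This is a minimal illustration, assuming JAX is installed; it is not code taken from the book:

```python
import jax
import jax.numpy as jnp

# NumPy-style in-place mutation is not allowed on JAX arrays;
# the functional equivalent is the .at[...] update syntax.
x = jnp.zeros(5)
x = x.at[0].set(1.0)          # returns a new array; nothing is mutated

# JIT-compile a pure function so XLA can fuse and accelerate it.
@jax.jit
def normalize(v):
    return (v - v.mean()) / (v.std() + 1e-8)

# Random numbers are explicit: a PRNG key replaces NumPy's global state.
key = jax.random.PRNGKey(0)
noise = jax.random.normal(key, (5,))

result = normalize(x + noise)
```

Because the key is passed explicitly, rerunning this script reproduces the same random draws, which is the basis of the reproducibility recipes later in the book.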
The book goes on to show you how to master Pytrees for data manipulation, integrate external functions through the Foreign Function Interface (FFI), and utilize advanced serialization and type promotion techniques for stable computations. If you want to optimize training processes, this book has you covered. It includes recipes for efficient data loading, building custom neural networks, implementing mixed precision, and tracking experiments with Penzai. You'll learn how to visualize model performance and monitor metrics to assess training progress effectively. The recipes in this book tackle real-world scenarios and give users the power to fix issues and fine-tune models quickly.
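As a taste of the Pytree and mixed-precision material, the sketch below walks a nested parameter dictionary with `jax.tree_util`. The parameter names are hypothetical, chosen only for illustration:

```python
import jax
import jax.numpy as jnp

# Model parameters are typically nested dicts -- a "pytree".
params = {
    "dense1": {"w": jnp.ones((2, 4)), "b": jnp.zeros(4)},
    "dense2": {"w": jnp.ones((4, 1)), "b": jnp.zeros(1)},
}

# tree_map visits every array leaf regardless of nesting depth;
# here it casts to bfloat16 as a mixed-precision sketch.
half = jax.tree_util.tree_map(lambda a: a.astype(jnp.bfloat16), params)

# Flatten to a list of leaves plus a treedef that can rebuild the
# structure; this pairing underlies simple serialization schemes.
leaves, treedef = jax.tree_util.tree_flatten(params)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```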
In this book you will learn how to:
Get your calculations done faster by moving from NumPy to JAX's optimized framework.
Make your training pipelines more efficient by profiling how long things take and how much memory they use.
Use debugging techniques to fix runtime issues like shape mismatches and numerical instability.
Get to grips with Pytrees for managing complex, nested data structures across various machine learning tasks.
Use JAX's Foreign Function Interface (FFI) to bring in external functions and give your computational capabilities a boost.
Take advantage of mixed-precision training to speed up neural network computations without sacrificing model accuracy.
Keep your experiments on track with Penzai, which lets you reproduce results and monitor key metrics.
Use advanced visualization techniques, like confusion matrices and learning curves, to make model evaluation more effective.
Create your own neural networks and optimizers directly in JAX so you have full control of the architecture.
Use serialization techniques to save, load, and transfer models and training checkpoints efficiently.
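To give a feel for the "build your own networks and optimizers" goal above, here is a hedged sketch of a one-layer model with a hand-rolled SGD step written directly against `jax.grad`; the function names and synthetic data are illustrative, not the book's own code:

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return x @ w + b

def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, x, y)
    # Parameters and gradients share one pytree structure, so a single
    # tree_map applies the update rule to every leaf.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (16, 3))
y = x @ jnp.ones((3, 1))               # synthetic linear targets
params = (jnp.zeros((3, 1)), jnp.zeros(1))
for _ in range(100):
    params = sgd_step(params, x, y)
```

Because the step function is jitted once and reused, each iteration runs the compiled version, which is the pattern the profiling chapters build on.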
Prologue
As a machine learning engineer, I often found myself facing the limitations of NumPy. It was great for numerical computing, but when it came to scaling up models or training complex neural networks, it just couldn't keep up. The issue wasn't just about speed, though that was certainly a factor. I also had trouble with automatic differentiation and GPU acceleration, which made me look for alternatives that could handle the demands of modern machine learning workflows.
That's when I discovered Google JAX. JAX offered the potential for faster execution through just-in-time (JIT) compilation, and there were more benefits to come. It had built-in support for automatic differentiation and seamless integration with hardware acceleration, which made it an attractive option. As with any tool, there was a learning curve and a few hurdles to overcome. Moving from the familiar world of NumPy meant a change in mindset and code structure. A lot of people have trouble with this change, and I was no exception.
I was inspired to write the Google JAX Cookbook because I wanted to share practical solutions based on my own experiences with these challenges. I wanted to create something that was more than just a theoretical reference—I wanted to put together a hands-on, actionable book filled with practical recipes. Each recipe in this book tackles a specific issue, offers a clear solution, and provides enough context to make it useful without overwhelming you. I've put together the book in a way that takes you through different situations you might come across, from speeding up basic numerical operations to troubleshooting training issues.
In the book, I cover some of the challenges of profiling computation and memory, debugging runtime errors, and optimizing neural networks for different hardware setups. You'll also get tips on how to manage data structures with Pytrees, use JAX's Foreign Function Interface (FFI) to integrate external libraries, and even set up experiment tracking with Penzai to keep your projects organized. Each chapter builds on the last, gradually increasing in complexity while making sure you don't get lost along the way. I've kept a consistent focus on solving real-world problems because, at the end of the day, that's what matters.
One of the main things I focus on is showing you how to use JAX to your advantage. JAX's strength is its flexibility. Once you start to understand its nuances, you'll see it opens up a world of possibilities. You'll find tips on everything from switching from high precision to mixed precision to speed up training, to managing memory more effectively to prevent out-of-memory errors. These are all techniques you can start using right away.
I've written this book with data scientists, machine learning engineers, and AI practitioners in mind. If you're looking for ways to make your workflows faster, more efficient, and less prone to errors, this book is a great resource to have on hand. Together, we'll figure out how to use JAX, fix any problems that come up, and see what's possible with advanced machine learning.
Copyright © 2024 by GitforGits
All rights reserved. This book is protected under copyright laws and no part of it may be reproduced or transmitted in any form or by any means, electronic or mechanical, including photocopying, recording, or by any information storage and retrieval system, without the prior written permission of the publisher. Any unauthorized reproduction, distribution, or transmission of this work may result in civil and criminal penalties and will be dealt with in the respective jurisdiction within India, in accordance with the applicable copyright laws.
Published by: GitforGits
Publisher: Sonal Dhandre
www.gitforgits.com
Printed in India
First Printing: October 2024
Cover Design by: Kitten Publishing
For permission to use material from this book, please contact GitforGits at [email protected].
Content
Preface
GitforGits
Acknowledgement
Chapter 1: Transition NumPy to JAX
Overview
Accelerating NumPy Code with JAX
Setting up Environment
Loading Fashion-MNIST Dataset
Preprocessing Data
Building a Simple Neural Network with NumPy
Defining Forward and Backward Passes
Training Model
Converting NumPy Code to JAX
Handling Unsupported NumPy Functions in JAX
Identifying Unsupported Functions
Replacing ‘np.random.shuffle’
Using JAX's ‘random.permutation’
Replacing unsupported Linear Algebra Functions
Handling In-Place Operations
Dealing with Advanced Indexing and Assignment
Handling Random Number Generation Functions
Implementing Custom Functions
Using Third-Party Libraries Compatible with JAX
Refactoring Code to Align with JAX's Paradigm
Managing Random Number Generation
Understanding JAX's PRNG Keys
Initializing PRNG Key
Generating Random Numbers
Shuffling Data without In-Place Operations
Ensuring Reproducibility in Model Training
Using PRNG Keys in Data Augmentation
Managing PRNG Keys in Loops
Incorporating PRNG Keys with JIT Compilation
Storing and Restoring PRNG Keys
Using ‘jax.random.fold_in’ for Multi-Process Environments
Dealing with In-Place Operations
Understanding In-Place Operations and JAX's Paradigm
Refactored Code using Functional Updates
Handling In-Place Array Modifications
Refactoring Conditional In-Place Updates
Avoiding In-Place Accumulations
Using ‘jax.ops.index_update’ for Indexing Updates
Refactoring Loops with In-Place Modifications
Ensuring Numerical Stability During Transition
Understanding JAX's Type Promotion Semantics
Inspecting Data Types
Converting Data Types Explicitly
Initializing Model Parameters with Consistent Types
Ensuring Consistent Data Types in Computations
Handling Division and Logarithmic Operations
Using JAX's Type Promotion Rules and Monitoring NaNs
Implementing Gradient Clipping
Summary
Chapter 2: Profiling Computation and Device Memory
Overview
Measuring Execution Time of JAX Functions
Identifying Functions to Profile
Using JAX's Built-in Timing Function
Profiling with ‘timeit’ and ‘cProfile’
Using ‘line_profiler’ for Line-by-Line Analysis
Profiling JIT-compiled Functions
Monitoring GPU Memory Usage
Profiling GPU Memory Usage
Preventing JAX from Preallocating GPU Memory
Reducing Data Precision
Adjusting Batch Size
Optimizing Data Loading
Utilizing JAX's ‘jit’ Compilation
Implementing Gradient Checkpointing
Monitoring Memory Allocation with JAX Profiling Tools
Minimizing Data Transfers between Host and Device
Simplifying Model Architecture
Utilizing Mixed Precision Training
Visualizing Computation Graphs
Prepare Sample Data and Parameters
Convert JAX Functions to TensorFlow Functions
Setup TensorBoard Logging
Trace Computation Graph
Launch TensorBoard
Enhance Graph Readability with Named Scopes
Optimizing Batch Sizes for Performance
Understanding Impact of Batch Size
Setting up Experiment
Modifying Data Loader
Implementing Timing and Memory Profiling
Running Training Experiments with Different Batch Sizes
Monitoring GPU Memory Usage
Recording and Analyzing Results
Adjusting Learning Rate based on Batch Size
Implementing Gradient Accumulation
Reducing Memory Footprint with Gradient Checkpointing
Understanding Forward Pass Structure
Applying ‘jax.checkpoint’ to Forward Pass
Adjusting Loss Function and Gradient Computation
Training Model with Checkpointing
Checkpointing Groups of Layers
Using ‘jax.remat’ with Custom Policies
Summary
Chapter 3: Debugging Runtime Values and Errors
Overview
Handling Concretization Errors
Understanding Concretization Errors
Replace Python Control Flow with JAX Control Flow Primitives
Use Element-wise Operations for Arrays
Handle Loops with ‘jax.lax.scan’ or ‘jax.lax.fori_loop’
Move Runtime-dependent Logic outside JIT Functions
Use Static Arguments with ‘static_argnums’
Inspecting Intermediate Values in JIT-Compiled Functions
Understanding Why ‘print’ Statements Fail
Using ‘jax.debug.print’ for Debugging
Applying ‘jax.debug.print’ in Model
Conditional Debugging
Using Host Callbacks with ‘jax.experimental.host_callback’
Dealing with Shape Mismatch Errors
Understanding Tensor Shapes in Model
Verifying Parameter Shapes
Implementing Forward Pass with Shape Checks
Utilizing ‘jax.debug.print’ for Shape Inspection
Checking Input Data Shapes before Training
Handling Broadcasting Issues
Using Explicit Reshaping and Transposing
Testing with Small Batches
Resolving Issues with NaNs in Computations
Identifying Source of NaNs
Checking for NaNs in Activations and Gradients
Handling Numerical Instability in Softmax Function
Applying Gradient Clipping
Adding Regularization
Proper Weight Initialization
Using Alternative Activation Functions
Implementing Batch Normalization
Summary
Chapter 4: Mastering Pytrees for Data Structures
Overview
Manipulating Nested Data with Pytrees
Understanding Pytrees
Using ‘jax.tree_map’ to apply Functions over Pytrees
Working with Multiple Pytrees
Flattening and Unflattening Pytrees
Applying JAX Transformations with Pytrees
Updating Parameters using Gradients
Filtering Leaves based on Conditions
Define Custom PyTree Nodes
Custom Pytrees for User-defined Classes
Define Custom Class
Use JAX Functions
Register Custom Class as a PyTree Node
Integrate with JAX Transformations
Compute Gradients
Combine Multiple Layers into a Model
Use ‘vmap’ and other Transformations
Serializing and Deserializing Pytrees
Understanding Model Parameters as Pytrees
Serializing PyTree to Bytes
Loading and Deserializing PyTree
Handling Custom Classes in Pytrees
Serializing and Deserializing PyTree with Custom Classes
Saving and Loading Checkpoints during Training
Using ‘flax.training.checkpoints’ Module
Filtering Pytrees for Specific Parameters
Defining Filter Function
Traversing Pytree with ‘jax.tree_map’ and Filtering
Applying Updates selectively based on Parameter Names
Integrating with Gradient Computation
Generalizing Filter Function
Using Regular Expressions for Complex Filters
Using ‘optax’ for Parameter Masks
Combining Multiple Masks
Summary
Chapter 5: Exporting and Serialization
Overview
Saving Trained Models for Deployment
Define and Train Model
Save Model Parameters
Save and Load Model for Inference
Deploy Model
Testing Deployment
Checkpointing Long-Running Jobs
Creating Training State
Implementing Training Step
Implementing Checkpointing
Handling Unexpected Interruptions
Customizing Checkpointing Strategies
Testing Checkpoint Restoration
Converting JAX Models to Other Frameworks
Convert JAX Model to TensorFlow
Build TensorFlow Function for Exporting
Convert TensorFlow Model to ONNX
Run Inference with ONNX Model
Compare JAX and ONNX Model Results
Serializing Custom Objects in JAX
Define Custom Objects
Initialize and Train Model
Serialize the Model
Implement Custom Serialization Functions
Load Model Parameters
Handle additional Custom Objects
Serialize Training State
Summary
Chapter 6: Type Promotion Semantics and Mixed Precision
Overview
Controlling Type Promotion in Arithmetic Operations
Understanding JAX's Type Promotion Rules
Explicitly Cast Data Types before Operations
Consistent Data Types in Model Parameters
Casting Input Data and Enforcing Data Types