Google JAX Cookbook: Perform machine learning and numerical computing with combined capabilities of TensorFlow and NumPy
Ebook, 280 pages, 2 hours


About this ebook

This is the practical, solution-oriented book every data scientist, machine learning engineer, and AI engineer needs to make the most of Google JAX for efficient and advanced machine learning. It covers essential tasks, troubleshooting scenarios, and optimization techniques to address common challenges encountered while working with JAX.

Language: English
Publisher: GitforGits
Release date: Oct 30, 2024
ISBN: 9788197950490


    Book preview


    Google JAX Cookbook

    Perform machine learning and numerical computing with combined capabilities of TensorFlow and NumPy

    Zephyr Quent

    Preface

    This is the practical, solution-oriented book data scientists, machine learning engineers, and AI engineers need to make the most of Google JAX for efficient and advanced machine learning. It covers essential tasks, troubleshooting scenarios, and optimization techniques to address common challenges encountered while working with JAX across machine learning and numerical computing projects.

    The book starts with the move from NumPy to JAX. It introduces the best ways to speed up computations, handle data types, generate random numbers, and perform in-place operations. It then shows you how to use profiling techniques to monitor computation time and device memory, helping you to optimize training and performance. The debugging section provides clear and effective strategies for resolving common runtime issues, including shape mismatches, NaNs, and control flow errors.
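
    As a quick taste of that transition, here is a minimal sketch (the function and data are invented for illustration, assuming a recent JAX release): swap NumPy for jax.numpy, JIT-compile a pure function, thread an explicit PRNG key, and use a functional .at[].set() update in place of in-place assignment.

        import jax
        import jax.numpy as jnp

        # JIT-compile a pure function; JAX traces it once per input shape/dtype.
        @jax.jit
        def normalize(x):
            return (x - x.mean()) / (x.std() + 1e-8)

        # Random numbers need an explicit PRNG key rather than global state.
        key = jax.random.PRNGKey(0)
        x = jax.random.normal(key, (1000,))
        y = normalize(x)

        # JAX arrays are immutable, so x[0] = 1.0 would fail;
        # the functional update returns a new array instead.
        x_updated = x.at[0].set(1.0)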

    The book goes on to show you how to master Pytrees for data manipulation, integrate external functions through the Foreign Function Interface (FFI), and utilize advanced serialization and type promotion techniques for stable computations. If you want to optimize training processes, this book has you covered. It includes recipes for efficient data loading, building custom neural networks, implementing mixed precision, and tracking experiments with Penzai. You'll learn how to visualize model performance and monitor metrics to assess training progress effectively. The recipes in this book tackle real-world scenarios and give users the power to fix issues and fine-tune models quickly.
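
    On the Pytree side, a sketch like the following (parameter names are illustrative) captures the core idea: a nested dict of arrays is already a Pytree, and a single jax.tree_util.tree_map call applies an update rule to every leaf, however deeply nested.

        import jax
        import jax.numpy as jnp

        # A nested dict of arrays is already a valid Pytree.
        params = {
            "dense1": {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)},
            "dense2": {"w": jnp.ones((3, 1)), "b": jnp.zeros(1)},
        }

        # Stand-in gradients with the same nested structure.
        grads = jax.tree_util.tree_map(lambda p: 0.1 * p, params)

        # One tree_map applies an SGD-style update to every leaf at once.
        lr = 0.01
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - lr * g, params, grads)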

    In this book you will learn how to:

    Get your calculations done faster by moving from NumPy to JAX's optimized framework.

    Make your training pipelines more efficient by profiling how long things take and how much memory they use.

    Use debugging techniques to fix runtime issues like shape mismatches and numerical instability.

    Get to grips with Pytrees for managing complex, nested data structures across various machine learning tasks.

    Use JAX's Foreign Function Interface (FFI) to bring in external functions and give your computational capabilities a boost.

    Take advantage of mixed-precision training to speed up neural network computations without sacrificing model accuracy.

    Keep your experiments on track with Penzai, letting you reproduce results and monitor key metrics.

    Use advanced visualization techniques, like confusion matrices and learning curves, to make model evaluation more effective.

    Create your own neural networks and optimizers directly in JAX so you have full control of the architecture (see the sketch after this list).

    Use serialization techniques to save, load, and transfer models and training checkpoints efficiently.
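
    To ground the last of these points, here is a minimal sketch (the model and data are invented for illustration) of the pattern the custom-network recipes build on: a hand-rolled linear model whose parameters live in a plain dict, differentiated with jax.grad and updated with a plain SGD step.

        import jax
        import jax.numpy as jnp

        def predict(params, x):
            # A single linear layer; parameters live in a plain dict (a Pytree).
            return x @ params["w"] + params["b"]

        def loss_fn(params, x, y):
            # Mean squared error between predictions and targets.
            return jnp.mean((predict(params, x) - y) ** 2)

        # Toy data and zero-initialized parameters, purely for illustration.
        key = jax.random.PRNGKey(0)
        x = jax.random.normal(key, (32, 4))
        y = jnp.sum(x, axis=1, keepdims=True)
        params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros(1)}

        # jax.grad differentiates the scalar loss with respect to params.
        grad_fn = jax.jit(jax.grad(loss_fn))
        for _ in range(100):
            grads = grad_fn(params, x, y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - 0.1 * g, params, grads)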

    Prologue

    As a machine learning engineer, I often found myself facing the limitations of NumPy. It was great for numerical computing, but when it came to scaling up models or training complex neural networks, it just couldn't keep up. The issue wasn't just about speed, though that was certainly a factor. I also had trouble with automatic differentiation and GPU acceleration, which made me look for alternatives that could handle the demands of modern machine learning workflows.

    That's when I discovered Google JAX. JAX offered the potential for faster execution through just-in-time (JIT) compilation, and there were more benefits to come. It had built-in support for automatic differentiation and seamless integration with hardware acceleration, which made it an attractive option. As with any tool, there was a learning curve and a few hurdles to overcome. Moving from the familiar world of NumPy meant a change in mindset and code structure. A lot of people have trouble with this change, and I was no exception.

    I was inspired to write the Google JAX Cookbook because I wanted to share practical solutions based on my own experiences with these challenges. I wanted to create something that was more than just a theoretical reference—I wanted to put together a hands-on, actionable book filled with practical recipes. Each recipe in this book tackles a specific issue, offers a clear solution, and provides enough context to make it useful without overwhelming you. I've put together the book in a way that takes you through different situations you might come across, from speeding up basic numerical operations to troubleshooting training issues.

    In the book, I cover some of the challenges of profiling computation and memory, debugging runtime errors, and optimizing neural networks for different hardware setups. You'll also get tips on how to manage data structures with Pytrees, use JAX's Foreign Function Interface (FFI) to integrate external libraries, and even set up experiment tracking with Penzai to keep your projects organized. Each chapter builds on the last, gradually increasing in complexity while making sure you don't get lost along the way. I've kept a consistent focus on solving real-world problems because, at the end of the day, that's what matters.

    One of the main things I focus on is showing you how to use JAX to your advantage. JAX's strength is its flexibility. Once you start to understand its nuances, you'll see it opens up a world of possibilities. You'll find tips on everything from switching to mixed precision to speed up training, to managing memory more effectively to prevent out-of-memory errors, all techniques you can put to use right away.

    I've written this book with data scientists, machine learning engineers, and AI practitioners in mind. If you're looking for ways to make your workflows faster, more efficient, and less prone to errors, this book is a great resource to have on hand. Together, we'll figure out how to use JAX, fix any problems that come up, and see what's possible with advanced machine learning.

    Copyright © 2024 by GitforGits

    All rights reserved. This book is protected under copyright law, and no part of it may be reproduced or transmitted in any form or by any means, electronic or mechanical, including photocopying, recording, or by any information storage and retrieval system, without the prior written permission of the publisher. Any unauthorized reproduction, distribution, or transmission of this work may result in civil and criminal penalties and will be dealt with in the respective jurisdiction within India, in accordance with applicable copyright laws.

    Published by: GitforGits

    Publisher: Sonal Dhandre

    www.gitforgits.com

    [email protected]

    Printed in India

    First Printing: October 2024

    Cover Design by: Kitten Publishing

    For permission to use material from this book, please contact GitforGits at [email protected].

    Contents

    Preface

    GitforGits

    Acknowledgement

    Chapter 1: Transition NumPy to JAX

    Overview

    Accelerating NumPy Code with JAX

    Setting up Environment

    Loading Fashion-MNIST Dataset

    Preprocessing Data

    Building a Simple Neural Network with NumPy

    Defining Forward and Backward Passes

    Training Model

    Converting NumPy Code to JAX

    Handling Unsupported NumPy Functions in JAX

    Identifying Unsupported Functions

    Replacing ‘np.random.shuffle’

    Using JAX's ‘random.permutation’

    Replacing Unsupported Linear Algebra Functions

    Handling In-Place Operations

    Dealing with Advanced Indexing and Assignment

    Handling Random Number Generation Functions

    Implementing Custom Functions

    Using Third-Party Libraries Compatible with JAX

    Refactoring Code to Align with JAX's Paradigm

    Managing Random Number Generation

    Understanding JAX's PRNG Keys

    Initializing PRNG Key

    Generating Random Numbers

    Shuffling Data without In-Place Operations

    Ensuring Reproducibility in Model Training

    Using PRNG Keys in Data Augmentation

    Managing PRNG Keys in Loops

    Incorporating PRNG Keys with JIT Compilation

    Storing and Restoring PRNG Keys

    Using ‘jax.random.fold_in’ for Multi-Process Environments

    Dealing with In-Place Operations

    Understanding In-Place Operations and JAX's Paradigm

    Refactored Code using Functional Updates

    Handling In-Place Array Modifications

    Refactoring Conditional In-Place Updates

    Avoiding In-Place Accumulations

    Using ‘jax.ops.index_update’ for Indexing Updates

    Refactoring Loops with In-Place Modifications

    Ensuring Numerical Stability During Transition

    Understanding JAX's Type Promotion Semantics

    Inspecting Data Types

    Converting Data Types Explicitly

    Initializing Model Parameters with Consistent Types

    Ensuring Consistent Data Types in Computations

    Handling Division and Logarithmic Operations

    Using JAX's Type Promotion Rules and Monitoring NaNs

    Implementing Gradient Clipping

    Summary

    Chapter 2: Profiling Computation and Device Memory

    Overview

    Measuring Execution Time of JAX Functions

    Identifying Functions to Profile

    Using JAX's Built-in Timing Function

    Profiling with ‘timeit’ and ‘cProfile’

    Using ‘line_profiler’ for Line-by-Line Analysis

    Profiling JIT-compiled Functions

    Monitoring GPU Memory Usage

    Profiling GPU Memory Usage

    Preventing JAX from Preallocating GPU Memory

    Reducing Data Precision

    Adjusting Batch Size

    Optimizing Data Loading

    Utilizing JAX's ‘jit’ Compilation

    Implementing Gradient Checkpointing

    Monitoring Memory Allocation with JAX Profiling Tools

    Minimizing Data Transfers between Host and Device

    Simplifying Model Architecture

    Utilizing Mixed Precision Training

    Visualizing Computation Graphs

    Prepare Sample Data and Parameters

    Convert JAX Functions to TensorFlow Functions

    Setup TensorBoard Logging

    Trace Computation Graph

    Launch TensorBoard

    Enhance Graph Readability with Named Scopes

    Optimizing Batch Sizes for Performance

    Understanding Impact of Batch Size

    Setting up Experiment

    Modifying Data Loader

    Implementing Timing and Memory Profiling

    Running Training Experiments with Different Batch Sizes

    Monitoring GPU Memory Usage

    Recording and Analyzing Results

    [Results table: Batch Size, Training Time (s), Peak GPU Memory Usage (MiB)]

    Adjusting Learning Rate based on Batch Size

    Implementing Gradient Accumulation

    Reducing Memory Footprint with Gradient Checkpointing

    Understanding Forward Pass Structure

    Applying ‘jax.checkpoint’ to Forward Pass

    Adjusting Loss Function and Gradient Computation

    Training Model with Checkpointing

    Checkpointing Groups of Layers

    Using ‘jax.remat’ with Custom Policies

    Summary

    Chapter 3: Debugging Runtime Values and Errors

    Overview

    Handling Concretization Errors

    Understanding Concretization Errors

    Replace Python Control Flow with JAX Control Flow Primitives

    Use Element-wise Operations for Arrays

    Handle Loops with ‘jax.lax.scan’ or ‘jax.lax.fori_loop’

    Move Runtime-dependent Logic outside JIT Functions

    Use Static Arguments with ‘static_argnums’

    Inspecting Intermediate Values in JIT-Compiled Functions

    Understanding Why ‘print’ Statements Fail

    Using ‘jax.debug.print’ for Debugging

    Applying ‘jax.debug.print’ in Model

    Conditional Debugging

    Using Host Callbacks with ‘jax.experimental.host_callback’

    Dealing with Shape Mismatch Errors

    Understanding Tensor Shapes in Model

    Verifying Parameter Shapes

    Implementing Forward Pass with Shape Checks

    Utilizing ‘jax.debug.print’ for Shape Inspection

    Checking Input Data Shapes before Training

    Handling Broadcasting Issues

    Using Explicit Reshaping and Transposing

    Testing with Small Batches

    Resolving Issues with NaNs in Computations

    Identifying Source of NaNs

    Checking for NaNs in Activations and Gradients

    Handling Numerical Instability in Softmax Function

    Applying Gradient Clipping

    Adding Regularization

    Proper Weight Initialization

    Using Alternative Activation Functions

    Implementing Batch Normalization

    Summary

    Chapter 4: Mastering Pytrees for Data Structures

    Overview

    Manipulating Nested Data with Pytrees

    Understanding Pytrees

    Using ‘jax.tree_map’ to apply Functions over Pytrees

    Working with Multiple Pytrees

    Flattening and Unflattening Pytrees

    Applying JAX Transformations with Pytrees

    Updating Parameters using Gradients

    Filtering Leaves based on Conditions

    Define Custom PyTree Nodes

    Custom Pytrees for User-defined Classes

    Define Custom Class

    Use JAX Functions

    Register Custom Class as a PyTree Node

    Integrate with JAX Transformations

    Compute Gradients

    Combine Multiple Layers into a Model

    Use ‘vmap’ and other Transformations

    Serializing and Deserializing Pytrees

    Understanding Model Parameters as Pytrees

    Serializing PyTree to Bytes

    Loading and Deserializing PyTree

    Handling Custom Classes in Pytrees

    Serializing and Deserializing PyTree with Custom Classes

    Saving and Loading Checkpoints during Training

    Using ‘flax.training.checkpoints’ Module

    Filtering Pytrees for Specific Parameters

    Defining Filter Function

    Traversing Pytree with ‘jax.tree_map’ and Filtering

    Applying Updates selectively based on Parameter Names

    Integrating with Gradient Computation

    Generalizing Filter Function

    Using Regular Expressions for Complex Filters

    Using ‘optax’ for Parameter Masks

    Combining Multiple Masks

    Summary

    Chapter 5: Exporting and Serialization

    Overview

    Saving Trained Models for Deployment

    Define and Train Model

    Save Model Parameters

    Save and Load Model for Inference

    Deploy Model

    Testing Deployment

    Checkpointing Long-Running Jobs

    Creating Training State

    Implementing Training Step

    Implementing Checkpointing

    Handling Unexpected Interruptions

    Customizing Checkpointing Strategies

    Testing Checkpoint Restoration

    Converting JAX Models to Other Frameworks

    Convert JAX Model to TensorFlow

    Build TensorFlow Function for Exporting

    Convert TensorFlow Model to ONNX

    Run Inference with ONNX Model

    Compare JAX and ONNX Model Results

    Serializing Custom Objects in JAX

    Define Custom Objects

    Initialize and Train Model

    Serialize the Model

    Implement Custom Serialization Functions

    Load Model Parameters

    Handle Additional Custom Objects

    Serialize Training State

    Summary

    Chapter 6: Type Promotion Semantics and Mixed Precision

    Overview

    Controlling Type Promotion in Arithmetic Operations

    Understanding JAX's Type Promotion Rules

    Explicitly Cast Data Types before Operations

    Consistent Data Types in Model Parameters

    Casting Input Data and Enforcing Data Types
