This tutorial demonstrates how to perform basic sampling/inference with the RecurrentGemma 2B Instruct model using Google DeepMind's recurrentgemma library, which is written with JAX (a high-performance numerical computing library), Flax (the JAX-based neural network library), Orbax (a JAX-based library for training utilities like checkpointing), and SentencePiece (a tokenizer/detokenizer library). Although Flax is not used directly in this notebook, Flax was used to create Gemma and RecurrentGemma (the Griffin model).
This notebook can run on Google Colab with the T4 GPU (go to Edit > Notebook settings, then select T4 GPU under Hardware accelerator).
Setup
The following sections explain the steps for preparing a notebook to use a RecurrentGemma model, including getting model access, generating an API key, and configuring the notebook runtime.
Set up Kaggle access for Gemma
To complete this tutorial, you first need to follow the setup instructions similar to Gemma setup with a few exceptions:
- Get access to RecurrentGemma (instead of Gemma) on kaggle.com.
- Select a Colab runtime with sufficient resources to run the RecurrentGemma model.
- Generate and configure a Kaggle username and API key.
After you've completed the RecurrentGemma setup, move on to the next section, where you'll set environment variables for your Colab environment.
Set environment variables
Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY. When prompted with the "Grant access?" messages, agree to provide secret access.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Install the recurrentgemma library
This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click on Edit > Notebook settings > Select T4 GPU > Save.
Next, you need to install the Google DeepMind recurrentgemma library from github.com/google-deepmind/recurrentgemma. If you get an error about "pip's dependency resolver", you can usually ignore it.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
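After the install completes, you can optionally confirm that JAX sees the Colab GPU (this quick check is not part of the original tutorial):

import jax
# Expect a GPU device (for example cuda:0) when the T4 runtime is selected; otherwise only the CPU is listed.
print(jax.devices())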
Load and prepare the RecurrentGemma model
- Load the RecurrentGemma model with kagglehub.model_download, which takes three arguments:
  - handle: The model handle from Kaggle
  - path: (Optional string) The local path
  - force_download: (Optional boolean) Forces a re-download of the model
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
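If you ever need to re-fetch the model (for example, after an interrupted download), kagglehub.model_download also accepts the optional arguments listed above. A minimal sketch, assuming the same model handle; force_download simply bypasses the cached copy:

# Optional: force a fresh download of the same variant instead of reusing the cache.
RECURRENTGEMMA_PATH = kagglehub.model_download(
    f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}',
    force_download=True,
)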
- Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer directory will be in the main directory where you downloaded the model, while the model weights will be in a sub-directory. For example:
  - The tokenizer.model file will be in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1.
  - The model checkpoint will be in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
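As an optional sanity check (not shown in the original steps), you can verify that the checkpoint directory and tokenizer file exist at those paths:

# Optional: confirm the expected layout before loading anything.
print(os.listdir(RECURRENTGEMMA_PATH))  # should include the '2b-it' directory and 'tokenizer.model'
print(os.path.isdir(CKPT_PATH), os.path.isfile(TOKENIZER_PATH))  # expect: True True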
Perform sampling/inference
- Load the RecurrentGemma model checkpoint with the recurrentgemma.jax.load_parameters method. Setting the sharding argument to "single_device" loads all model parameters on a single device.
# Use the JAX implementation of the recurrentgemma library.
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
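The checkpoint is returned as a JAX pytree of arrays. If you want a quick look at what was loaded, the following optional snippet (not part of the original tutorial) counts the parameters:

import jax
# Sum the sizes of all leaf arrays in the parameter pytree; expect a count in the billions for the 2B variants.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"Total parameters: {num_params:,}")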
- Load the RecurrentGemma model tokenizer, constructed using sentencepiece.SentencePieceProcessor:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
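To see the tokenizer in action, you can try a quick encode/decode round trip using the standard SentencePiece methods (an optional illustration):

# Encode a string to token IDs and decode it back; the round trip should reproduce the original text.
ids = vocab.EncodeAsIds("Hello, RecurrentGemma!")
print(ids)
print(vocab.DecodeIds(ids))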
- To automatically load the correct configuration from the RecurrentGemma model checkpoint, use recurrentgemma.GriffinConfig.from_flax_params_or_variables. Then, instantiate the Griffin model with recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
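If you are curious about the hyperparameters that were inferred from the checkpoint, printing the config object shows them (an optional peek; the exact fields depend on your recurrentgemma version):

# Inspect the configuration recovered from the checkpoint, e.g. vocabulary size and layer block types.
print(model_config)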
- Create a sampler with recurrentgemma.jax.Sampler on top of the RecurrentGemma model checkpoint/weights and the tokenizer:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Write a prompt in prompt and perform inference. You can tweak total_generation_steps (the number of steps performed when generating a response; this example uses 50 to preserve host memory).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:
# 5+9=?
Output:
# Answer: 14
# Explanation: 5 + 9 = 14.
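The sampler also accepts multiple prompts in a single call, and each entry of reply.text lines up with the prompt at the same index. A brief sketch reusing the same sampler and settings:

# Batch several prompts in one sampling call; outputs are returned in the same order as the inputs.
prompts = [
    "\n# What is the capital of France?",
    "\n# 12*7=?",
]
replies = sampler(input_strings=prompts, total_generation_steps=50)
for input_string, out_string in zip(prompts, replies.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}\n")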
Learn more
- You can learn more about the Google DeepMind recurrentgemma library on GitHub, which contains docstrings of methods and modules you used in this tutorial, such as recurrentgemma.jax.load_parameters, recurrentgemma.jax.Griffin, and recurrentgemma.jax.Sampler.
- The following libraries have their own documentation sites: core JAX, Flax, and Orbax.
- For sentencepiece tokenizer/detokenizer documentation, check out Google's sentencepiece GitHub repo.
- For kagglehub documentation, check out README.md on Kaggle's kagglehub GitHub repo.
- Learn how to use Gemma models with Google Cloud Vertex AI.
- Check out the RecurrentGemma: Moving Past Transformers for Efficient Open Language Models paper by Google DeepMind.
- Read the Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models paper by Google DeepMind to learn more about the model architecture used by RecurrentGemma.