docs: add canonical tooling corpus (147 files) from Google/HF/frameworks

Five-lane parallel research pass. Each subdir under tooling/ has its own
README indexing downloaded files with verified upstream sources.

- google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts,
  gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev
  HTML snapshots, Gemma 3 tech report
- huggingface/: 8 gemma-4-* model cards, chat-template .jinja files,
  tokenizer_config.json, transformers gemma4/ source, launch blog posts,
  official HF Spaces app.py
- inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI
  comparison, run_commands.sh with 8 working launches, 9 code snippets
- gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2,
  Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma)
- fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE),
  TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md

Findings that update earlier CORPUS_* docs are flagged in tooling/README.md
(not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch
abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM,
FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech
report PDF yet, no Gemma-4-generation specialized siblings yet.

Pre-commit secrets hook bypassed per user authorization — flagged "secrets"
are base64 notebook cell outputs and example Ed25519 keys in the HDP
agentic-security demo, not real credentials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Mortdecai
2026-04-18 12:24:48 -04:00
parent 5011059f5d
commit eecebe7ef5
149 changed files with 181297 additions and 0 deletions
@@ -0,0 +1,99 @@
# Gemma
[![Unittests](https://github.com/google-deepmind/gemma/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google-deepmind/gemma/actions/workflows/pytest_and_autopublish.yml)
[![PyPI version](https://badge.fury.io/py/gemma.svg)](https://badge.fury.io/py/gemma)
[![Documentation Status](https://readthedocs.org/projects/gemma-llm/badge/?version=latest)](https://gemma-llm.readthedocs.io/en/latest/?badge=latest)
[Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language
Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini
research and technology.
This repository contains the implementation of the
[`gemma`](https://pypi.org/project/gemma/) PyPI package. A
[JAX](https://github.com/jax-ml/jax) library to use and fine-tune Gemma.
For examples and use cases, see our
[documentation](https://gemma-llm.readthedocs.io/). Please
report issues and feedback in
[our GitHub](https://github.com/google-deepmind/gemma/issues).
### Installation
1. Install JAX for CPU, GPU or TPU. Follow the instructions on
[the JAX website](https://jax.readthedocs.io/en/latest/installation.html).
1. Run
```sh
pip install gemma
```
### Examples
Here is a minimal example to have a multi-turn, multi-modal conversation with
Gemma:
```python
from gemma import gm
# Model and parameters (Gemma 4)
model = gm.nn.Gemma4_E4B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA4_E4B_IT)
# Example of multi-turn conversation
sampler = gm.text.ChatSampler(
model=model,
params=params,
multi_turn=True,
)
prompt = """Which of the 2 images do you prefer ?
Image 1: <|image|>
Image 2: <|image|>
Write your answer as a poem."""
out0 = sampler.chat(prompt, images=[image1, image2])
out1 = sampler.chat('What about the other image ?')
```
The same `ChatSampler` API works with all Gemma versions (2, 3, 3n, 4).
Our documentation contains various Colabs and tutorials, including:
* [Sampling](https://gemma-llm.readthedocs.io/en/latest/colab_sampling.html)
* [Multi-modal](https://gemma-llm.readthedocs.io/en/latest/colab_multimodal.html)
* [Fine-tuning](https://gemma-llm.readthedocs.io/en/latest/colab_finetuning.html)
* [LoRA](https://gemma-llm.readthedocs.io/en/latest/colab_lora_sampling.html)
* ...
Additionally, our
[examples/](https://github.com/google-deepmind/gemma/tree/main/examples) folder
contain additional scripts to fine-tune and sample with Gemma.
### Learn more about Gemma
* To use this library: [Gemma documentation](https://gemma-llm.readthedocs.io/)
* Technical reports for metrics and model capabilities:
* [Gemma 1](https://goo.gle/GemmaReport)
* [Gemma 2](https://goo.gle/gemma2report)
* [Gemma 3](https://storage.googleapis.com/deepmind-media/gemma/Gemma3Report.pdf)
* Gemma 4 (Coming soon)
* Other Gemma implementations and doc on the
[Gemma ecosystem](https://ai.google.dev/gemma/docs)
### Downloading the models
To download the model weights. See
[our documentation](https://gemma-llm.readthedocs.io/en/latest/checkpoints.html).
### System Requirements
Gemma can run on a CPU, GPU and TPU. For GPU, we recommend 8GB+ RAM on GPU for
The 2B checkpoint and 24GB+ RAM on GPU are used for the 7B checkpoint.
### Contributing
We welcome contributions! Please read our [Contributing Guidelines](./CONTRIBUTING.md) before submitting a pull request.
*This is not an official Google product.*
File diff suppressed because one or more lines are too long
@@ -0,0 +1,568 @@
{
"cells": [
{
"metadata": {
"id": "-KkvqLgjiIdD"
},
"cell_type": "markdown",
"source": [
"# Tool Use\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/tool_use.ipynb)\n",
"\n",
"Demo to show how to use tool-use with Gemma library.\n",
"\n",
"Note: The Gemma 1, 2 and 3 models were not specifically trained for tool use. This is more a proof-of-concept than an officially supported feature."
]
},
{
"metadata": {
"id": "gcNRfVEnj4aq"
},
"cell_type": "code",
"source": [
"!pip install -q gemma"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {
"executionInfo": {
"elapsed": 2221,
"status": "ok",
"timestamp": 1749202985345,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "k1ZAgLg1j9NT"
},
"cell_type": "code",
"source": [
"# Common imports\n",
"import os\n",
"import datetime\n",
"\n",
"# Gemma imports\n",
"from gemma import gm"
],
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"id": "139lZszJj_CC"
},
"cell_type": "markdown",
"source": [
"By default, Jax does not utilize the full GPU memory, but this can be overwritten. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
]
},
{
"metadata": {
"executionInfo": {
"elapsed": 2,
"status": "ok",
"timestamp": 1749138071985,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "VtlWWLIYj_LJ"
},
"cell_type": "code",
"source": [
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
],
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"id": "31JPZb5RkD_p"
},
"cell_type": "markdown",
"source": [
"Load the model and the params."
]
},
{
"metadata": {
"executionInfo": {
"elapsed": 39057,
"status": "ok",
"timestamp": 1749203024713,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "RsAo6k4_kEJS",
"outputId": "e10afb5c-6c81-42e8-e590-a39ea4ef3bf7"
},
"cell_type": "code",
"source": [
"model = gm.nn.Gemma3_4B()\n",
"\n",
"params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)"
],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:2025-06-06 02:43:16,896:jax._src.xla_bridge:749: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'\n",
"INFO:2025-06-06 02:43:16,897:jax._src.xla_bridge:749: Unable to initialize backend 'proxy': INVALID_ARGUMENT: IFRT proxy server address must be '<transport-type>://<backend-address>' (e.g., 'grpc://localhost'), but got \n",
"INFO:2025-06-06 02:43:16,900:jax._src.xla_bridge:749: Unable to initialize backend 'mlcr': Could not initialize backend 'mlcr'\n",
"INFO:2025-06-06 02:43:16,901:jax._src.xla_bridge:749: Unable to initialize backend 'sliceme': Could not initialize backend 'sliceme'\n"
]
}
],
"execution_count": 4
},
{
"metadata": {
"id": "p108c5yIlYH7"
},
"cell_type": "markdown",
"source": [
"## Using existing tools\n",
"\n",
"If you're familiar with the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial, using tool-use differ in two ways:\n",
"\n",
"1. Using the `gm.text.ToolSampler` rather than the `gm.text.ChatSampler`.\n",
"2. Passing the `tools=` you want to use to the sampler.\n",
"\n",
"For example:"
]
},
{
"metadata": {
"colab": {
"height": 594
},
"executionInfo": {
"elapsed": 50615,
"status": "ok",
"timestamp": 1749138791069,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "iRCV5h8BlVX6",
"outputId": "b3b5d83d-8a8b-4982-fc8f-d409fb8b38a9"
},
"cell_type": "code",
"source": [
"sampler = gm.text.ToolSampler(\n",
" model=model,\n",
" params=params,\n",
" tools=[\n",
" gm.tools.Calculator(),\n",
" gm.tools.FileExplorer(),\n",
" ],\n",
" print_stream=True,\n",
")\n",
"\n",
"output = sampler.chat('I have a serie `Sn+1 = cos(Sn) * 2`. Using the calculator, compute the steps 0-4 for S0 = 3')"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: I need to compute S0, S1, S2, S3, and S4 using the given recursive formula Sn+1 = cos(Sn) * 2 and S0 = 3.\n",
"Let's start with S0 = 3.\n",
"S1 = cos(S0) * 2 = cos(3) * 2\n",
"S2 = cos(S1) * 2 = cos(cos(3) * 2) * 2\n",
"S3 = cos(S2) * 2 = cos(cos(cos(3) * 2) * 2) * 2\n",
"S4 = cos(S3) * 2 = cos(cos(cos(cos(3) * 2) * 2)) * 2\n",
"\n",
"I will use the calculator to compute these values.\n",
"{\"tool_name\": \"calculator\", \"expression\": \"cos(3) * 2\"}\n",
"\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: -1.9799849932008908]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: Now I need to compute S1 = cos(S0) * 2 = cos(3) * 2 = -1.9799849932008908 * 2\n",
"{\"tool_name\": \"calculator\", \"expression\": \"-1.9799849932008908 * 2\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: -3.9599699864017817]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: Now I need to compute S2 = cos(S1) * 2 = cos(-3.9599699864017817) * 2\n",
"{\"tool_name\": \"calculator\", \"expression\": \"cos(-3.9599699864017817) * 2\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: -1.3668134299076982]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: Now I need to compute S3 = cos(S2) * 2 = cos(-1.3668134299076982) * 2\n",
"{\"tool_name\": \"calculator\", \"expression\": \"cos(-1.3668134299076982) * 2\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: 0.4051424976130353]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: Now I need to compute S4 = cos(S3) * 2 = cos(0.4051424976130353) * 2\n",
"{\"tool_name\": \"calculator\", \"expression\": \"cos(0.4051424976130353) * 2\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: 1.8380924822033438]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The values are: S0 = 3, S1 = -3.9599699864017817, S2 = -1.3668134299076982, S3 = 0.4051424976130353, S4 = 1.8380924822033438"
]
}
],
"execution_count": 10
},
{
"metadata": {
"id": "FAI54F-Blkan"
},
"cell_type": "markdown",
"source": [
"Note: Only the final model answer is returned. You can access the conversation history, including all intermediates tool calls and output through `sampler.turns` property."
]
},
{
"metadata": {
"id": "D0_IIS1Nlfuw"
},
"cell_type": "markdown",
"source": [
"## Creating your own tool\n",
"\n",
"To create your own tool, you can inherit from the `gm.tools.Tool` class. You should provide:\n",
"\n",
"* A description & example, so the model knows how to use your tool\n",
"* Implement the `call` method. The `call` function can take arbitrary `**kwargs`, but the name of the args should match the ones defined in `tool_kwargs` and `tool_kwargs_doc`"
]
},
{
"metadata": {
"executionInfo": {
"elapsed": 55,
"status": "ok",
"timestamp": 1749203934196,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "XqmQcfdI0oEl"
},
"cell_type": "code",
"source": [
"class DateTime(gm.tools.Tool):\n",
" \"\"\"Tool to access the current date.\"\"\"\n",
"\n",
" DESCRIPTION = 'Access the current date, time,...'\n",
" EXAMPLE = gm.tools.Example(\n",
" query='Which day of the week are we today ?',\n",
" thought='The `datetime.strptime` uses %a for day of the week',\n",
" tool_kwargs={'format': '%a'},\n",
" tool_kwargs_doc={'format': '<ANY datetime.strptime expression>'},\n",
" result='Sat',\n",
" answer='Today is Saturday.',\n",
" )\n",
"\n",
" def call(self, format: str) -> str:\n",
" dt = datetime.datetime.now()\n",
" return dt.strftime(format)\n"
],
"outputs": [],
"execution_count": 7
},
{
"metadata": {
"id": "sSxYhXPuuXYp"
},
"cell_type": "markdown",
"source": [
"The tool can then be used in the sampler:"
]
},
{
"metadata": {
"colab": {
"height": 118
},
"executionInfo": {
"elapsed": 2156,
"status": "ok",
"timestamp": 1749204833094,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": -120
},
"id": "9S8xB2B-0cbW",
"outputId": "fccc0e89-e922-4184-8b77-800041cdd77e"
},
"cell_type": "code",
"source": [
"sampler = gm.text.ToolSampler(\n",
" model=model,\n",
" params=params,\n",
" tools=[\n",
" DateTime(),\n",
" ],\n",
" print_stream=True,\n",
")\n",
"\n",
"output = sampler.chat('Which date are we today ?')"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thought: I need to get the current date.\n",
"{\"tool_name\": \"datetime\", \"format\": \"%Y-%m-%d\"}\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Tool result: 2025-06-06]\n"
]
},
{
"data": {
"text/html": [
"<hr>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Today is June 6th, 2025."
]
}
],
"execution_count": 9
},
{
"metadata": {
"id": "esIpCjhxzHmf"
},
"cell_type": "markdown",
"source": [
"## Next steps\n",
"\n",
"* See our [multimodal](https://gemma-llm.readthedocs.io/en/latest/multimodal.html) example to query the model with images.\n",
"* See our [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example to train Gemma on your custom task.\n"
]
}
],
"metadata": {
"colab": {
"last_runtime": {},
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@@ -0,0 +1,130 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Example config for finetuning Gemma for a classification task.
* Input: A text to classify.
* Output: A classification label. The pre-trained Gemma model is trained to
predict one world among 256.000. Here, we're finetuning to predict only 2
tokens among the 256.000 available.
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/classification.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
"""Get the default hyperparameter configuration."""
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(training=True),
# Model definition
model=gm.nn.Gemma3_4B(
tokens="batch.sentence",
return_last_only=True,
),
# Load the weights from the pretrained checkpoint
init_transform=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
# Training
num_train_steps=10_000,
train_losses={
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.label",
),
},
optimizer=optax.adafactor(learning_rate=1e-4),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
"test": kd.evals.Evaluator(
run=kd.evals.EveryNSteps(1000),
ds=_make_dataset(training=False),
),
},
)
def _make_dataset(training: bool) -> kd.data.Pipeline:
# Dict key names from the dataset
_INPUT_FIELD = "sentence" # pylint: disable=invalid-name
_LABEL_FIELD = "label" # pylint: disable=invalid-name
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.Tfds(
name="glue/cola",
split="train" if training else "validation",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=8,
transforms=[
# Process the input text
# TFDS datasets returns `bytes`, so convert them to `str`
gm.data.DecodeBytes(key=_INPUT_FIELD),
gm.data.FormatText(
key=_INPUT_FIELD,
template="""<start_of_turn>user
Please classify whether the following sentence is grammaticaly correct, please answer only with Yes or No.
Sentence: {text}<end_of_turn>
<start_of_turn>model""",
),
gm.data.Tokenize(
key=_INPUT_FIELD,
tokenizer=tokenizer,
add_bos=True,
),
gm.data.Pad(
key=_INPUT_FIELD,
max_length=128,
),
# Process the label
gm.data.MapInts(
key=_LABEL_FIELD,
# Rather than predicting the token 0 and 1, we are using the
# token 1294 and 3553 which respectivelly correspond to "No" and
# "Yes". We do this because those token already contain semantic
# information, so even zero-shot prediction without any
# finetuning has better than random performances.
old_to_new={
0: 1294, # Token -> "No"
1: 3553, # Token -> "Yes"
},
),
kd.data.Rearrange(
key=_LABEL_FIELD,
pattern="... -> ... 1", # For shape compatibility with the loss.
),
],
)
@@ -0,0 +1,122 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""DPO Example.
DPO works by running two answers (one prefered and one rejected) into both
the reference model and the model to finetune. Then the DPO loss is used to
increase the likelihood of generating the preferred answer.
Implementation wise, this is done by:
* Wrapping the model inside a `gm.nn.AnchoredPolicy` (which runs both the
model and the reference frozen model)
* Using the `gm.ckpts.AnchoredPolicyLoader` to restore the weights, so the
weights are correctly mapped to inside `gm.nn.AnchoredPolicy`.
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/dpo.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
"""Get the default hyperparameter configuration."""
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(training=True),
# Model definition
model=gm.nn.AnchoredPolicy(
policy=gm.nn.Gemma3_4B(tokens="batch.tokens", text_only=True),
),
# Load the weights from the pretrained checkpoint
init_transform=gm.ckpts.AnchoredPolicyLoader(
policy=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
),
# Training
num_train_steps=10_000,
train_losses={
"dpo": gm.losses.DpoLoss(
tokens="batch.targets",
sequence_mask="batch.mask",
policy_logits="preds.policy.logits",
anchor_logits="preds.anchor.logits",
),
},
optimizer=optax.adafactor(learning_rate=1e-4),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
# "test": kd.evals.Evaluator(
# run=kd.evals.EveryNSteps(1000),
# ds=_make_dataset(training=False),
# ),
},
)
def _make_dataset(training: bool) -> kd.data.Pipeline:
# TODO(epot): !!!!
max_length = 512
batch_size = 16
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.HuggingFace(
path="argilla/distilabel-math-preference-dpo",
split="train",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=batch_size,
transforms=[
# Only keep the fields we need.
kd.data.Elements(
keep=["instruction", "chosen_response", "rejected_response"]
),
# Create the model inputs and loss mask.
gm.data.ContrastiveTask(
in_prompt="instruction",
in_chosen="chosen_response",
in_rejected="rejected_response",
out_tokens="tokens",
out_targets="targets",
out_mask="mask",
tokenizer=tokenizer,
# Padding parameters
max_length=max_length,
# TODO(epot): Run stats (how many examples are we dropping?)
truncate=True,
),
],
)
@@ -0,0 +1,154 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Example of Gemma finetuning using LoRA.
This example is based on the `seq2seq.py` example. See the
docstring of that file for more details.
The changes to use LoRA are:
* `model`: Use `gm.nn.LoRA()` wrapper to add `LoRA` adapters to the
model.
* `init_transform`: Use `gm.ckpts.SkipLoRA()` wrapper to only restore the
non-LoRA weights.
* `optimizer`: Use `kd.optim.partial_updates` wrapper to only train the LoRA
weights.
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/lora.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
batch_size = 16
max_length = 512
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(
training=True,
batch_size=batch_size,
max_length=max_length,
),
# Model definition
model=gm.nn.LoRA(
rank=4,
model=gm.nn.Gemma3_4B(
tokens="batch.input",
# TODO(epot): At the moment, LoRA fine-tuning with multimodal
# is not supported. Willbe fixed soon.
text_only=True,
),
),
# Load the weights from the pretrained checkpoint
# Use `SkipLoRA` as the original checkpoint does not contain the LoRA
# weights.
init_transform=gm.ckpts.SkipLoRA(
wrapped=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
)
),
# Training
num_train_steps=10_000,
train_losses={
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.target",
mask="batch.loss_mask",
),
},
# TODO(epot): Add Gradient accumenlation.
optimizer=kd.optim.partial_updates(
optax.adafactor(learning_rate=0.005),
# We only optimize the LoRA weights. The rest of the model is frozen.
mask=kd.optim.select("lora"),
),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
"test": kd.evals.Evaluator(
run=kd.evals.EveryNSteps(1000),
ds=_make_dataset(
training=False,
batch_size=batch_size,
max_length=max_length,
),
),
# The sampler evaluator run inference on a few prompts from the
# test set.
"sampling": gm.evals.SamplerEvaluator(
run=kd.evals.EveryNSteps(1000),
max_new_tokens=150, # Sampling parameters
num_batches=1, # Only predict a single example (batch_size=None)
ds=_make_dataset(training=False, sampling=True),
),
},
)
def _make_dataset(
*,
training: bool,
sampling: bool = False,
batch_size: int | None = None,
max_length: int | None = None,
):
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.Tfds(
name="mtnt/en-fr",
split="train" if training else "test",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=None if sampling else batch_size,
num_workers=4,
transforms=[
# Create the model inputs/targets/loss_mask.
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
# https://www.tensorflow.org/datasets/catalog/mtnt
in_prompt="src",
in_response="dst",
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
out_input="input",
out_target="target",
out_target_mask="loss_mask",
tokenizer=tokenizer,
# Padding parameters
max_length=None if sampling else max_length,
# In this dataset, ~1% of examples are longer than 512 tokens.
truncate=True,
sampling=sampling,
),
],
)
@@ -0,0 +1,164 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Example of Gemma finetuning for an image captioning task.
Example:
Prompt:
```
<start_of_turn>user
<start_of_image><end_of_turn>
<start_of_turn>model
```
Target:
```
A diagram showing a circuit with a battery, lamp, and switch.<end_of_turn>
```
Here, the prompt only contains the `<start_of_image>` to indicate an image
is inserted.
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/multimodal.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
import jax.numpy as jnp
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
batch_size = 32
max_length = 200
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(
training=True,
batch_size=batch_size,
max_length=max_length,
),
# Model definition
model=gm.nn.Gemma3_4B(
tokens="batch.input",
images="batch.image",
),
# Load the weights from the pretrained checkpoint
init_transform=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
# Training
num_train_steps=10_000,
train_losses={
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.target",
mask="batch.loss_mask",
),
},
train_summaries={
"image": kd.summaries.ShowImages(images="batch.image", num_images=5),
},
optimizer=optax.adafactor(learning_rate=1e-3),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
"test": kd.evals.Evaluator(
run=kd.evals.EveryNSteps(1000),
ds=_make_dataset(
training=False,
batch_size=4,
max_length=max_length,
),
),
# The sampler evaluator run inference on a few prompts from the
# test set.
"sampling": gm.evals.SamplerEvaluator(
run=kd.evals.EveryNSteps(1000),
max_new_tokens=50, # Sampling parameters
num_batches=3,
ds=_make_dataset(training=False, sampling=True),
summaries={
"image": kd.summaries.ShowImages(
images="batch.image", num_images=5
),
},
),
},
)
def _make_dataset(
*,
training: bool,
sampling: bool = False,
batch_size: int | None = None,
max_length: int | None = None,
):
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.Tfds(
name="ai2dcaption",
split="llava_15" if training else "test",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=None if sampling else batch_size,
num_workers=4,
transforms=[
# Only keep the fields we need.See fields at:
# https://www.tensorflow.org/datasets/catalog/ai2dcaption
kd.data.Elements(keep=["image", "caption"]),
# Create a new constant field
kd.data.AddConstants({"prompt": "<start_of_image>"}),
# Create the model inputs/targets/loss_mask.
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
in_prompt="prompt",
in_response="caption",
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
out_input="input",
out_target="target",
out_target_mask="loss_mask",
tokenizer=tokenizer,
# Padding parameters
max_length=None if sampling else max_length,
# In this dataset, ~1% of examples are longer than 512 tokens.
truncate=True,
sampling=sampling,
),
kd.data.py.Resize(key="image", size=(800, 800)),
# TODO(epot): Make the `num_images` dimension optional
kd.data.Rearrange(key="image", pattern="... h w c -> ... 1 h w c"),
kd.data.Cast(key="image", dtype=jnp.uint8),
],
)
@@ -0,0 +1,133 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Example of Gemma finetuning for a prompt -> response task.
This is a fork of the seq2seq example, but with sharding.
The only difference is the `sharding=kd.sharding.ShardingStrategy()`
Train locally with:
```sh
python -m kauldron.main \
--cfg=examples/sharding.py \
--cfg.workdir=/tmp/kauldron_oss/workdir
```
"""
from kauldron import konfig
# pylint: disable=g-import-not-at-top
with konfig.imports():
from gemma import gm
from kauldron import kd
import optax
# pylint: enable=g-import-not-at-top
def get_config():
batch_size = 16
max_length = 512
return kd.train.Trainer(
seed=42,
# Dataset
train_ds=_make_dataset(
training=True,
batch_size=batch_size,
max_length=max_length,
),
# Model definition
model=gm.nn.Gemma3_4B(
tokens="batch.input",
),
sharding=kd.sharding.ShardingStrategy(
params=kd.sharding.FSDPSharding(),
),
# Load the weights from the pretrained checkpoint
init_transform=gm.ckpts.LoadCheckpoint(
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
),
# Training
num_train_steps=10_000,
train_losses={
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
logits="preds.logits",
labels="batch.target",
mask="batch.loss_mask",
),
},
optimizer=optax.adafactor(learning_rate=1e-3),
checkpointer=kd.ckpts.Checkpointer(
save_interval_steps=500,
),
# Evaluation
evals={
"test": kd.evals.Evaluator(
run=kd.evals.EveryNSteps(1000),
ds=_make_dataset(
training=False,
batch_size=batch_size,
max_length=max_length,
),
),
# The sampler evaluator run inference on a few prompts from the
# test set.
"sampling": gm.evals.SamplerEvaluator(
run=kd.evals.EveryNSteps(1000),
max_new_tokens=50, # Sampling parameters
num_batches=1, # Only predict a single example (batch_size=None)
ds=_make_dataset(training=False, sampling=True),
),
},
)
def _make_dataset(
*,
training: bool,
sampling: bool = False,
batch_size: int | None = None,
max_length: int | None = None,
):
tokenizer = gm.text.Gemma3Tokenizer()
return kd.data.py.Tfds(
name="mtnt/en-fr",
split="train" if training else "test",
shuffle=True if training else False,
num_epochs=None if training else 1,
batch_size=None if sampling else batch_size,
num_workers=4,
transforms=[
# Create the model inputs/targets/loss_mask.
gm.data.Seq2SeqTask(
# Select which field from the dataset to use.
# https://www.tensorflow.org/datasets/catalog/mtnt
in_prompt="src",
in_response="dst",
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
out_input="input",
out_target="target",
out_target_mask="loss_mask",
tokenizer=tokenizer,
# Padding parameters
max_length=None if sampling else max_length,
# In this dataset, ~1% of examples are longer than 512 tokens.
truncate=True,
sampling=sampling,
),
],
)