docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,99 @@
|
||||
# Gemma
|
||||
|
||||
[](https://github.com/google-deepmind/gemma/actions/workflows/pytest_and_autopublish.yml)
|
||||
[](https://badge.fury.io/py/gemma)
|
||||
[](https://gemma-llm.readthedocs.io/en/latest/?badge=latest)
|
||||
|
||||
[Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language
|
||||
Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini
|
||||
research and technology.
|
||||
|
||||
This repository contains the implementation of the
|
||||
[`gemma`](https://pypi.org/project/gemma/) PyPI package. A
|
||||
[JAX](https://github.com/jax-ml/jax) library to use and fine-tune Gemma.
|
||||
|
||||
For examples and use cases, see our
|
||||
[documentation](https://gemma-llm.readthedocs.io/). Please
|
||||
report issues and feedback in
|
||||
[our GitHub](https://github.com/google-deepmind/gemma/issues).
|
||||
|
||||
### Installation
|
||||
|
||||
1. Install JAX for CPU, GPU or TPU. Follow the instructions on
|
||||
[the JAX website](https://jax.readthedocs.io/en/latest/installation.html).
|
||||
1. Run
|
||||
|
||||
```sh
|
||||
pip install gemma
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
Here is a minimal example to have a multi-turn, multi-modal conversation with
|
||||
Gemma:
|
||||
|
||||
```python
|
||||
from gemma import gm
|
||||
|
||||
# Model and parameters (Gemma 4)
|
||||
model = gm.nn.Gemma4_E4B()
|
||||
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA4_E4B_IT)
|
||||
|
||||
# Example of multi-turn conversation
|
||||
sampler = gm.text.ChatSampler(
|
||||
model=model,
|
||||
params=params,
|
||||
multi_turn=True,
|
||||
)
|
||||
|
||||
prompt = """Which of the 2 images do you prefer ?
|
||||
|
||||
Image 1: <|image|>
|
||||
Image 2: <|image|>
|
||||
|
||||
Write your answer as a poem."""
|
||||
out0 = sampler.chat(prompt, images=[image1, image2])
|
||||
|
||||
out1 = sampler.chat('What about the other image ?')
|
||||
```
|
||||
|
||||
The same `ChatSampler` API works with all Gemma versions (2, 3, 3n, 4).
|
||||
|
||||
Our documentation contains various Colabs and tutorials, including:
|
||||
|
||||
* [Sampling](https://gemma-llm.readthedocs.io/en/latest/colab_sampling.html)
|
||||
* [Multi-modal](https://gemma-llm.readthedocs.io/en/latest/colab_multimodal.html)
|
||||
* [Fine-tuning](https://gemma-llm.readthedocs.io/en/latest/colab_finetuning.html)
|
||||
* [LoRA](https://gemma-llm.readthedocs.io/en/latest/colab_lora_sampling.html)
|
||||
* ...
|
||||
|
||||
Additionally, our
|
||||
[examples/](https://github.com/google-deepmind/gemma/tree/main/examples) folder
|
||||
contain additional scripts to fine-tune and sample with Gemma.
|
||||
|
||||
### Learn more about Gemma
|
||||
|
||||
* To use this library: [Gemma documentation](https://gemma-llm.readthedocs.io/)
|
||||
* Technical reports for metrics and model capabilities:
|
||||
* [Gemma 1](https://goo.gle/GemmaReport)
|
||||
* [Gemma 2](https://goo.gle/gemma2report)
|
||||
* [Gemma 3](https://storage.googleapis.com/deepmind-media/gemma/Gemma3Report.pdf)
|
||||
* Gemma 4 (Coming soon)
|
||||
* Other Gemma implementations and doc on the
|
||||
[Gemma ecosystem](https://ai.google.dev/gemma/docs)
|
||||
|
||||
### Downloading the models
|
||||
|
||||
To download the model weights. See
|
||||
[our documentation](https://gemma-llm.readthedocs.io/en/latest/checkpoints.html).
|
||||
|
||||
### System Requirements
|
||||
|
||||
Gemma can run on a CPU, GPU and TPU. For GPU, we recommend 8GB+ RAM on GPU for
|
||||
The 2B checkpoint and 24GB+ RAM on GPU are used for the 7B checkpoint.
|
||||
|
||||
### Contributing
|
||||
|
||||
We welcome contributions! Please read our [Contributing Guidelines](./CONTRIBUTING.md) before submitting a pull request.
|
||||
|
||||
*This is not an official Google product.*
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,568 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"metadata": {
|
||||
"id": "-KkvqLgjiIdD"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Tool Use\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/tool_use.ipynb)\n",
|
||||
"\n",
|
||||
"Demo to show how to use tool-use with Gemma library.\n",
|
||||
"\n",
|
||||
"Note: The Gemma 1, 2 and 3 models were not specifically trained for tool use. This is more a proof-of-concept than an officially supported feature."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "gcNRfVEnj4aq"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install -q gemma"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"executionInfo": {
|
||||
"elapsed": 2221,
|
||||
"status": "ok",
|
||||
"timestamp": 1749202985345,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "k1ZAgLg1j9NT"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Common imports\n",
|
||||
"import os\n",
|
||||
"import datetime\n",
|
||||
"\n",
|
||||
"# Gemma imports\n",
|
||||
"from gemma import gm"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 3
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "139lZszJj_CC"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"By default, Jax does not utilize the full GPU memory, but this can be overwritten. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"executionInfo": {
|
||||
"elapsed": 2,
|
||||
"status": "ok",
|
||||
"timestamp": 1749138071985,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "VtlWWLIYj_LJ"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 2
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "31JPZb5RkD_p"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Load the model and the params."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"executionInfo": {
|
||||
"elapsed": 39057,
|
||||
"status": "ok",
|
||||
"timestamp": 1749203024713,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "RsAo6k4_kEJS",
|
||||
"outputId": "e10afb5c-6c81-42e8-e590-a39ea4ef3bf7"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"model = gm.nn.Gemma3_4B()\n",
|
||||
"\n",
|
||||
"params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:2025-06-06 02:43:16,896:jax._src.xla_bridge:749: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'\n",
|
||||
"INFO:2025-06-06 02:43:16,897:jax._src.xla_bridge:749: Unable to initialize backend 'proxy': INVALID_ARGUMENT: IFRT proxy server address must be '<transport-type>://<backend-address>' (e.g., 'grpc://localhost'), but got \n",
|
||||
"INFO:2025-06-06 02:43:16,900:jax._src.xla_bridge:749: Unable to initialize backend 'mlcr': Could not initialize backend 'mlcr'\n",
|
||||
"INFO:2025-06-06 02:43:16,901:jax._src.xla_bridge:749: Unable to initialize backend 'sliceme': Could not initialize backend 'sliceme'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 4
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "p108c5yIlYH7"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Using existing tools\n",
|
||||
"\n",
|
||||
"If you're familiar with the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial, using tool-use differ in two ways:\n",
|
||||
"\n",
|
||||
"1. Using the `gm.text.ToolSampler` rather than the `gm.text.ChatSampler`.\n",
|
||||
"2. Passing the `tools=` you want to use to the sampler.\n",
|
||||
"\n",
|
||||
"For example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 594
|
||||
},
|
||||
"executionInfo": {
|
||||
"elapsed": 50615,
|
||||
"status": "ok",
|
||||
"timestamp": 1749138791069,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "iRCV5h8BlVX6",
|
||||
"outputId": "b3b5d83d-8a8b-4982-fc8f-d409fb8b38a9"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"sampler = gm.text.ToolSampler(\n",
|
||||
" model=model,\n",
|
||||
" params=params,\n",
|
||||
" tools=[\n",
|
||||
" gm.tools.Calculator(),\n",
|
||||
" gm.tools.FileExplorer(),\n",
|
||||
" ],\n",
|
||||
" print_stream=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"output = sampler.chat('I have a serie `Sn+1 = cos(Sn) * 2`. Using the calculator, compute the steps 0-4 for S0 = 3')"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: I need to compute S0, S1, S2, S3, and S4 using the given recursive formula Sn+1 = cos(Sn) * 2 and S0 = 3.\n",
|
||||
"Let's start with S0 = 3.\n",
|
||||
"S1 = cos(S0) * 2 = cos(3) * 2\n",
|
||||
"S2 = cos(S1) * 2 = cos(cos(3) * 2) * 2\n",
|
||||
"S3 = cos(S2) * 2 = cos(cos(cos(3) * 2) * 2) * 2\n",
|
||||
"S4 = cos(S3) * 2 = cos(cos(cos(cos(3) * 2) * 2)) * 2\n",
|
||||
"\n",
|
||||
"I will use the calculator to compute these values.\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"cos(3) * 2\"}\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: -1.9799849932008908]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: Now I need to compute S1 = cos(S0) * 2 = cos(3) * 2 = -1.9799849932008908 * 2\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"-1.9799849932008908 * 2\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: -3.9599699864017817]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: Now I need to compute S2 = cos(S1) * 2 = cos(-3.9599699864017817) * 2\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"cos(-3.9599699864017817) * 2\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: -1.3668134299076982]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: Now I need to compute S3 = cos(S2) * 2 = cos(-1.3668134299076982) * 2\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"cos(-1.3668134299076982) * 2\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: 0.4051424976130353]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: Now I need to compute S4 = cos(S3) * 2 = cos(0.4051424976130353) * 2\n",
|
||||
"{\"tool_name\": \"calculator\", \"expression\": \"cos(0.4051424976130353) * 2\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: 1.8380924822033438]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The values are: S0 = 3, S1 = -3.9599699864017817, S2 = -1.3668134299076982, S3 = 0.4051424976130353, S4 = 1.8380924822033438"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 10
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "FAI54F-Blkan"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Note: Only the final model answer is returned. You can access the conversation history, including all intermediates tool calls and output through `sampler.turns` property."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "D0_IIS1Nlfuw"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Creating your own tool\n",
|
||||
"\n",
|
||||
"To create your own tool, you can inherit from the `gm.tools.Tool` class. You should provide:\n",
|
||||
"\n",
|
||||
"* A description & example, so the model knows how to use your tool\n",
|
||||
"* Implement the `call` method. The `call` function can take arbitrary `**kwargs`, but the name of the args should match the ones defined in `tool_kwargs` and `tool_kwargs_doc`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"executionInfo": {
|
||||
"elapsed": 55,
|
||||
"status": "ok",
|
||||
"timestamp": 1749203934196,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "XqmQcfdI0oEl"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"class DateTime(gm.tools.Tool):\n",
|
||||
" \"\"\"Tool to access the current date.\"\"\"\n",
|
||||
"\n",
|
||||
" DESCRIPTION = 'Access the current date, time,...'\n",
|
||||
" EXAMPLE = gm.tools.Example(\n",
|
||||
" query='Which day of the week are we today ?',\n",
|
||||
" thought='The `datetime.strptime` uses %a for day of the week',\n",
|
||||
" tool_kwargs={'format': '%a'},\n",
|
||||
" tool_kwargs_doc={'format': '<ANY datetime.strptime expression>'},\n",
|
||||
" result='Sat',\n",
|
||||
" answer='Today is Saturday.',\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def call(self, format: str) -> str:\n",
|
||||
" dt = datetime.datetime.now()\n",
|
||||
" return dt.strftime(format)\n"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 7
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "sSxYhXPuuXYp"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"The tool can then be used in the sampler:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 118
|
||||
},
|
||||
"executionInfo": {
|
||||
"elapsed": 2156,
|
||||
"status": "ok",
|
||||
"timestamp": 1749204833094,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": -120
|
||||
},
|
||||
"id": "9S8xB2B-0cbW",
|
||||
"outputId": "fccc0e89-e922-4184-8b77-800041cdd77e"
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"sampler = gm.text.ToolSampler(\n",
|
||||
" model=model,\n",
|
||||
" params=params,\n",
|
||||
" tools=[\n",
|
||||
" DateTime(),\n",
|
||||
" ],\n",
|
||||
" print_stream=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"output = sampler.chat('Which date are we today ?')"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thought: I need to get the current date.\n",
|
||||
"{\"tool_name\": \"datetime\", \"format\": \"%Y-%m-%d\"}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Tool result: 2025-06-06]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<hr>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Today is June 6th, 2025."
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 9
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "esIpCjhxzHmf"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Next steps\n",
|
||||
"\n",
|
||||
"* See our [multimodal](https://gemma-llm.readthedocs.io/en/latest/multimodal.html) example to query the model with images.\n",
|
||||
"* See our [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example to train Gemma on your custom task.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"last_runtime": {},
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Example config for finetuning Gemma for a classification task.
|
||||
|
||||
* Input: A text to classify.
|
||||
* Output: A classification label. The pre-trained Gemma model is trained to
|
||||
predict one world among 256.000. Here, we're finetuning to predict only 2
|
||||
tokens among the 256.000 available.
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/classification.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Get the default hyperparameter configuration."""
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(training=True),
|
||||
# Model definition
|
||||
model=gm.nn.Gemma3_4B(
|
||||
tokens="batch.sentence",
|
||||
return_last_only=True,
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
init_transform=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
|
||||
logits="preds.logits",
|
||||
labels="batch.label",
|
||||
),
|
||||
},
|
||||
optimizer=optax.adafactor(learning_rate=1e-4),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
"test": kd.evals.Evaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
ds=_make_dataset(training=False),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(training: bool) -> kd.data.Pipeline:
|
||||
# Dict key names from the dataset
|
||||
_INPUT_FIELD = "sentence" # pylint: disable=invalid-name
|
||||
_LABEL_FIELD = "label" # pylint: disable=invalid-name
|
||||
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.Tfds(
|
||||
name="glue/cola",
|
||||
split="train" if training else "validation",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=8,
|
||||
transforms=[
|
||||
# Process the input text
|
||||
# TFDS datasets returns `bytes`, so convert them to `str`
|
||||
gm.data.DecodeBytes(key=_INPUT_FIELD),
|
||||
gm.data.FormatText(
|
||||
key=_INPUT_FIELD,
|
||||
template="""<start_of_turn>user
|
||||
Please classify whether the following sentence is grammaticaly correct, please answer only with Yes or No.
|
||||
Sentence: {text}<end_of_turn>
|
||||
<start_of_turn>model""",
|
||||
),
|
||||
gm.data.Tokenize(
|
||||
key=_INPUT_FIELD,
|
||||
tokenizer=tokenizer,
|
||||
add_bos=True,
|
||||
),
|
||||
gm.data.Pad(
|
||||
key=_INPUT_FIELD,
|
||||
max_length=128,
|
||||
),
|
||||
# Process the label
|
||||
gm.data.MapInts(
|
||||
key=_LABEL_FIELD,
|
||||
# Rather than predicting the token 0 and 1, we are using the
|
||||
# token 1294 and 3553 which respectivelly correspond to "No" and
|
||||
# "Yes". We do this because those token already contain semantic
|
||||
# information, so even zero-shot prediction without any
|
||||
# finetuning has better than random performances.
|
||||
old_to_new={
|
||||
0: 1294, # Token -> "No"
|
||||
1: 3553, # Token -> "Yes"
|
||||
},
|
||||
),
|
||||
kd.data.Rearrange(
|
||||
key=_LABEL_FIELD,
|
||||
pattern="... -> ... 1", # For shape compatibility with the loss.
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,122 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""DPO Example.
|
||||
|
||||
DPO works by running two answers (one prefered and one rejected) into both
|
||||
the reference model and the model to finetune. Then the DPO loss is used to
|
||||
increase the likelihood of generating the preferred answer.
|
||||
|
||||
Implementation wise, this is done by:
|
||||
|
||||
* Wrapping the model inside a `gm.nn.AnchoredPolicy` (which runs both the
|
||||
model and the reference frozen model)
|
||||
* Using the `gm.ckpts.AnchoredPolicyLoader` to restore the weights, so the
|
||||
weights are correctly mapped to inside `gm.nn.AnchoredPolicy`.
|
||||
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/dpo.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Get the default hyperparameter configuration."""
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(training=True),
|
||||
# Model definition
|
||||
model=gm.nn.AnchoredPolicy(
|
||||
policy=gm.nn.Gemma3_4B(tokens="batch.tokens", text_only=True),
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
init_transform=gm.ckpts.AnchoredPolicyLoader(
|
||||
policy=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
),
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"dpo": gm.losses.DpoLoss(
|
||||
tokens="batch.targets",
|
||||
sequence_mask="batch.mask",
|
||||
policy_logits="preds.policy.logits",
|
||||
anchor_logits="preds.anchor.logits",
|
||||
),
|
||||
},
|
||||
optimizer=optax.adafactor(learning_rate=1e-4),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
# "test": kd.evals.Evaluator(
|
||||
# run=kd.evals.EveryNSteps(1000),
|
||||
# ds=_make_dataset(training=False),
|
||||
# ),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(training: bool) -> kd.data.Pipeline:
|
||||
# TODO(epot): !!!!
|
||||
max_length = 512
|
||||
batch_size = 16
|
||||
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.HuggingFace(
|
||||
path="argilla/distilabel-math-preference-dpo",
|
||||
split="train",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=batch_size,
|
||||
transforms=[
|
||||
# Only keep the fields we need.
|
||||
kd.data.Elements(
|
||||
keep=["instruction", "chosen_response", "rejected_response"]
|
||||
),
|
||||
# Create the model inputs and loss mask.
|
||||
gm.data.ContrastiveTask(
|
||||
in_prompt="instruction",
|
||||
in_chosen="chosen_response",
|
||||
in_rejected="rejected_response",
|
||||
out_tokens="tokens",
|
||||
out_targets="targets",
|
||||
out_mask="mask",
|
||||
tokenizer=tokenizer,
|
||||
# Padding parameters
|
||||
max_length=max_length,
|
||||
# TODO(epot): Run stats (how many examples are we dropping?)
|
||||
truncate=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,154 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Example of Gemma finetuning using LoRA.
|
||||
|
||||
This example is based on the `seq2seq.py` example. See the
|
||||
docstring of that file for more details.
|
||||
|
||||
The changes to use LoRA are:
|
||||
|
||||
* `model`: Use `gm.nn.LoRA()` wrapper to add `LoRA` adapters to the
|
||||
model.
|
||||
* `init_transform`: Use `gm.ckpts.SkipLoRA()` wrapper to only restore the
|
||||
non-LoRA weights.
|
||||
* `optimizer`: Use `kd.optim.partial_updates` wrapper to only train the LoRA
|
||||
weights.
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/lora.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
batch_size = 16
|
||||
max_length = 512
|
||||
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(
|
||||
training=True,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
# Model definition
|
||||
model=gm.nn.LoRA(
|
||||
rank=4,
|
||||
model=gm.nn.Gemma3_4B(
|
||||
tokens="batch.input",
|
||||
# TODO(epot): At the moment, LoRA fine-tuning with multimodal
|
||||
# is not supported. Willbe fixed soon.
|
||||
text_only=True,
|
||||
),
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
# Use `SkipLoRA` as the original checkpoint does not contain the LoRA
|
||||
# weights.
|
||||
init_transform=gm.ckpts.SkipLoRA(
|
||||
wrapped=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
)
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
|
||||
logits="preds.logits",
|
||||
labels="batch.target",
|
||||
mask="batch.loss_mask",
|
||||
),
|
||||
},
|
||||
# TODO(epot): Add Gradient accumenlation.
|
||||
optimizer=kd.optim.partial_updates(
|
||||
optax.adafactor(learning_rate=0.005),
|
||||
# We only optimize the LoRA weights. The rest of the model is frozen.
|
||||
mask=kd.optim.select("lora"),
|
||||
),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
"test": kd.evals.Evaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
ds=_make_dataset(
|
||||
training=False,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
),
|
||||
# The sampler evaluator run inference on a few prompts from the
|
||||
# test set.
|
||||
"sampling": gm.evals.SamplerEvaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
max_new_tokens=150, # Sampling parameters
|
||||
num_batches=1, # Only predict a single example (batch_size=None)
|
||||
ds=_make_dataset(training=False, sampling=True),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
*,
|
||||
training: bool,
|
||||
sampling: bool = False,
|
||||
batch_size: int | None = None,
|
||||
max_length: int | None = None,
|
||||
):
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.Tfds(
|
||||
name="mtnt/en-fr",
|
||||
split="train" if training else "test",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=None if sampling else batch_size,
|
||||
num_workers=4,
|
||||
transforms=[
|
||||
# Create the model inputs/targets/loss_mask.
|
||||
gm.data.Seq2SeqTask(
|
||||
# Select which field from the dataset to use.
|
||||
# https://www.tensorflow.org/datasets/catalog/mtnt
|
||||
in_prompt="src",
|
||||
in_response="dst",
|
||||
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
|
||||
out_input="input",
|
||||
out_target="target",
|
||||
out_target_mask="loss_mask",
|
||||
tokenizer=tokenizer,
|
||||
# Padding parameters
|
||||
max_length=None if sampling else max_length,
|
||||
# In this dataset, ~1% of examples are longer than 512 tokens.
|
||||
truncate=True,
|
||||
sampling=sampling,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,164 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Example of Gemma finetuning for an image captioning task.
|
||||
|
||||
Example:
|
||||
|
||||
Prompt:
|
||||
|
||||
```
|
||||
<start_of_turn>user
|
||||
<start_of_image><end_of_turn>
|
||||
<start_of_turn>model
|
||||
```
|
||||
|
||||
Target:
|
||||
|
||||
```
|
||||
A diagram showing a circuit with a battery, lamp, and switch.<end_of_turn>
|
||||
```
|
||||
|
||||
Here, the prompt only contains the `<start_of_image>` to indicate an image
|
||||
is inserted.
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/multimodal.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
import jax.numpy as jnp
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
batch_size = 32
|
||||
max_length = 200
|
||||
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(
|
||||
training=True,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
# Model definition
|
||||
model=gm.nn.Gemma3_4B(
|
||||
tokens="batch.input",
|
||||
images="batch.image",
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
init_transform=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
|
||||
logits="preds.logits",
|
||||
labels="batch.target",
|
||||
mask="batch.loss_mask",
|
||||
),
|
||||
},
|
||||
train_summaries={
|
||||
"image": kd.summaries.ShowImages(images="batch.image", num_images=5),
|
||||
},
|
||||
optimizer=optax.adafactor(learning_rate=1e-3),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
"test": kd.evals.Evaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
ds=_make_dataset(
|
||||
training=False,
|
||||
batch_size=4,
|
||||
max_length=max_length,
|
||||
),
|
||||
),
|
||||
# The sampler evaluator run inference on a few prompts from the
|
||||
# test set.
|
||||
"sampling": gm.evals.SamplerEvaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
max_new_tokens=50, # Sampling parameters
|
||||
num_batches=3,
|
||||
ds=_make_dataset(training=False, sampling=True),
|
||||
summaries={
|
||||
"image": kd.summaries.ShowImages(
|
||||
images="batch.image", num_images=5
|
||||
),
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
*,
|
||||
training: bool,
|
||||
sampling: bool = False,
|
||||
batch_size: int | None = None,
|
||||
max_length: int | None = None,
|
||||
):
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.Tfds(
|
||||
name="ai2dcaption",
|
||||
split="llava_15" if training else "test",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=None if sampling else batch_size,
|
||||
num_workers=4,
|
||||
transforms=[
|
||||
# Only keep the fields we need.See fields at:
|
||||
# https://www.tensorflow.org/datasets/catalog/ai2dcaption
|
||||
kd.data.Elements(keep=["image", "caption"]),
|
||||
# Create a new constant field
|
||||
kd.data.AddConstants({"prompt": "<start_of_image>"}),
|
||||
# Create the model inputs/targets/loss_mask.
|
||||
gm.data.Seq2SeqTask(
|
||||
# Select which field from the dataset to use.
|
||||
in_prompt="prompt",
|
||||
in_response="caption",
|
||||
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
|
||||
out_input="input",
|
||||
out_target="target",
|
||||
out_target_mask="loss_mask",
|
||||
tokenizer=tokenizer,
|
||||
# Padding parameters
|
||||
max_length=None if sampling else max_length,
|
||||
# In this dataset, ~1% of examples are longer than 512 tokens.
|
||||
truncate=True,
|
||||
sampling=sampling,
|
||||
),
|
||||
kd.data.py.Resize(key="image", size=(800, 800)),
|
||||
# TODO(epot): Make the `num_images` dimension optional
|
||||
kd.data.Rearrange(key="image", pattern="... h w c -> ... 1 h w c"),
|
||||
kd.data.Cast(key="image", dtype=jnp.uint8),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,133 @@
|
||||
# Copyright 2026 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Example of Gemma finetuning for a prompt -> response task.
|
||||
|
||||
This is a fork of the seq2seq example, but with sharding.
|
||||
The only difference is the `sharding=kd.sharding.ShardingStrategy()`
|
||||
|
||||
Train locally with:
|
||||
|
||||
```sh
|
||||
python -m kauldron.main \
|
||||
--cfg=examples/sharding.py \
|
||||
--cfg.workdir=/tmp/kauldron_oss/workdir
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from kauldron import konfig
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
with konfig.imports():
|
||||
from gemma import gm
|
||||
from kauldron import kd
|
||||
import optax
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def get_config():
|
||||
batch_size = 16
|
||||
max_length = 512
|
||||
|
||||
return kd.train.Trainer(
|
||||
seed=42,
|
||||
# Dataset
|
||||
train_ds=_make_dataset(
|
||||
training=True,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
# Model definition
|
||||
model=gm.nn.Gemma3_4B(
|
||||
tokens="batch.input",
|
||||
),
|
||||
sharding=kd.sharding.ShardingStrategy(
|
||||
params=kd.sharding.FSDPSharding(),
|
||||
),
|
||||
# Load the weights from the pretrained checkpoint
|
||||
init_transform=gm.ckpts.LoadCheckpoint(
|
||||
path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
|
||||
),
|
||||
# Training
|
||||
num_train_steps=10_000,
|
||||
train_losses={
|
||||
"xentropy": kd.losses.SoftmaxCrossEntropyWithIntLabels(
|
||||
logits="preds.logits",
|
||||
labels="batch.target",
|
||||
mask="batch.loss_mask",
|
||||
),
|
||||
},
|
||||
optimizer=optax.adafactor(learning_rate=1e-3),
|
||||
checkpointer=kd.ckpts.Checkpointer(
|
||||
save_interval_steps=500,
|
||||
),
|
||||
# Evaluation
|
||||
evals={
|
||||
"test": kd.evals.Evaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
ds=_make_dataset(
|
||||
training=False,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
),
|
||||
),
|
||||
# The sampler evaluator run inference on a few prompts from the
|
||||
# test set.
|
||||
"sampling": gm.evals.SamplerEvaluator(
|
||||
run=kd.evals.EveryNSteps(1000),
|
||||
max_new_tokens=50, # Sampling parameters
|
||||
num_batches=1, # Only predict a single example (batch_size=None)
|
||||
ds=_make_dataset(training=False, sampling=True),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
*,
|
||||
training: bool,
|
||||
sampling: bool = False,
|
||||
batch_size: int | None = None,
|
||||
max_length: int | None = None,
|
||||
):
|
||||
tokenizer = gm.text.Gemma3Tokenizer()
|
||||
|
||||
return kd.data.py.Tfds(
|
||||
name="mtnt/en-fr",
|
||||
split="train" if training else "test",
|
||||
shuffle=True if training else False,
|
||||
num_epochs=None if training else 1,
|
||||
batch_size=None if sampling else batch_size,
|
||||
num_workers=4,
|
||||
transforms=[
|
||||
# Create the model inputs/targets/loss_mask.
|
||||
gm.data.Seq2SeqTask(
|
||||
# Select which field from the dataset to use.
|
||||
# https://www.tensorflow.org/datasets/catalog/mtnt
|
||||
in_prompt="src",
|
||||
in_response="dst",
|
||||
# Output batch is {"input": ..., "target": ..., "loss_mask": ...}
|
||||
out_input="input",
|
||||
out_target="target",
|
||||
out_target_mask="loss_mask",
|
||||
tokenizer=tokenizer,
|
||||
# Padding parameters
|
||||
max_length=None if sampling else max_length,
|
||||
# In this dataset, ~1% of examples are longer than 512 tokens.
|
||||
truncate=True,
|
||||
sampling=sampling,
|
||||
),
|
||||
],
|
||||
)
|
||||
Reference in New Issue
Block a user