{ "cells": [ { "metadata": { "id": "My2AZpV_RuTs" }, "cell_type": "markdown", "source": [ "# Sampling\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/sampling.ipynb)\n", "\n", "This example shows how to load a Gemma model and run inference on it.\n", "\n", "The Gemma library has three ways to prompt a model:\n", "\n", "* `gm.text.ChatSampler`: Easiest to use; simply talk to the model and get an answer. Supports multi-turn conversations out of the box.\n", "* `gm.text.Sampler`: Lower level, but gives more control. The chat state has to be handled manually for multi-turn conversations.\n", "* `model.apply`: Calls the model directly; only predicts a single token." ] }, { "metadata": { "id": "94CVV9ZxKVDO" }, "cell_type": "code", "source": [ "!pip install -q gemma" ], "outputs": [], "execution_count": null }, { "metadata": { "executionInfo": { "elapsed": 2610, "status": "ok", "timestamp": 1741517119876, "user": { "displayName": "", "userId": "" }, "user_tz": -60 }, "id": "PXBc1hRKRuTt" }, "cell_type": "code", "source": [ "# Common imports\n", "import os\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "# Gemma imports\n", "from gemma import gm" ], "outputs": [], "execution_count": 1 }, { "metadata": { "id": "i_DEVehe3v5Y" }, "cell_type": "markdown", "source": [ "By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):" ] }, { "metadata": { "id": "AaK17GWo3v5Y" }, "cell_type": "code", "source": [ "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "XzEj3PYvX1Sm" }, "cell_type": "markdown", "source": [ "Load the model and the params. Here we load the instruction-tuned version of the model." 
] }, { "metadata": { "id": "ox1CAuffKJtj" }, "cell_type": "code", "source": [ "model = gm.nn.Gemma3_4B()\n", "\n", "params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "iRjnbNREdugN" }, "cell_type": "markdown", "source": [ "## Multi-turn conversations\n", "\n", "The easiest way to chat with Gemma is to use the `gm.text.ChatSampler`. It hides the boilerplate of the conversation cache, as well as the `<start_of_turn>` / `<end_of_turn>` tokens used to format the conversation.\n", "\n", "Here, we set `multi_turn=True` when creating `gm.text.ChatSampler` (by default, the `ChatSampler` starts a new conversation every time).\n", "\n", "In multi-turn mode, you can erase the previous conversation state by passing `sampler.chat(..., multi_turn=False)`." ] }, { "metadata": { "executionInfo": { "elapsed": 18237, "status": "ok", "timestamp": 1741517159587, "user": { "displayName": "", "userId": "" }, "user_tz": -60 }, "id": "bGSE6aTYdxqt", "outputId": "5d428ab7-f9bf-4898-a921-cd57353ff364" }, "cell_type": "code", "source": [ "sampler = gm.text.ChatSampler(\n", "    model=model,\n", "    params=params,\n", "    multi_turn=True,\n", "    print_stream=True,  # Print output as it is generated.\n", ")\n", "\n", "turn0 = sampler.chat('Share one methapore linking \"shadow\" and \"laughter\".')" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Okay, here's a metaphor linking \"shadow\" and \"laughter,\" aiming for a slightly evocative and layered feel:\n", "\n", "**\"Laughter is the fleeting shadow of joy, dancing across a face that’s often hidden in the long shadow of sorrow.\"**\n", "\n", "---\n", "\n", "**Here's a breakdown of why this works:**\n", "\n", "* **\"Shadow\"** represents sadness, pain, or a past experience that lingers. 
It’s not necessarily a dark shadow, but a persistent presence.\n", "* **\"Laughter\"** is presented as a brief, bright appearance – a momentary flash of happiness.\n", "* **\"Dancing across a face that’s often hidden\"** emphasizes that the joy isn't constant, and the underlying sadness is still there, obscuring it.\n", "\n", "---\n", "\n", "Would you like me to:\n", "\n", "* Try a different type of metaphor?\n", "* Expand on this one with a short story snippet?\n" ] } ], "execution_count": 3 }, { "metadata": { "executionInfo": { "elapsed": 7822, "status": "ok", "timestamp": 1741517170597, "user": { "displayName": "", "userId": "" }, "user_tz": -60 }, "id": "MIZqgSSRfmRS", "outputId": "575490ff-3d07-48f6-cb39-5fac3553e24a" }, "cell_type": "code", "source": [ "turn1 = sampler.chat('Expand it in a haiku.')" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Okay, here’s a haiku based on the metaphor:\n", "\n", "Shadow stretches long,\n", "Laughter’s brief, bright, dancing grace,\n", "Joy hides in the dark. \n", "\n", "---\n", "\n", "Would you like me to try another haiku, or perhaps a different poetic form?\n" ] } ], "execution_count": 4 }, { "metadata": { "id": "uOQHd6eFd67c" }, "cell_type": "markdown", "source": [ "Note: By default (`multi_turn=False`), the conversation state is reset every time, but you can still continue the previous conversation by passing `sampler.chat(..., multi_turn=True)`.\n", "\n", "By default, greedy decoding is used. 
You can pass a custom `sampling=` method as a keyword argument:\n", "\n", "* `gm.text.Greedy()`: (default) Greedy decoding\n", "* `gm.text.RandomSampling()`: Simple random sampling with temperature, for more variety" ] }, { "metadata": { "id": "AH0eWFWJaiNk" }, "cell_type": "markdown", "source": [ "## Sample a prompt\n", "\n", "For more control, we also provide a `gm.text.Sampler`, which still performs efficient sampling (with kv-caching, early stopping, ...).\n", "\n", "Prompting the sampler requires correctly formatting the prompt with the `<start_of_turn>` / `<end_of_turn>` tokens (see the custom token section of the [tokenizer](https://gemma-llm.readthedocs.io/en/latest/tokenizer.html) documentation)." ] }, { "metadata": { "executionInfo": { "elapsed": 14339, "status": "ok", "timestamp": 1741277003042, "user": { "displayName": "", "userId": "" }, "user_tz": -60 }, "id": "6R5J42EiZtkC", "outputId": "e309909f-b627-4057-94c0-524e66dd3ea2" }, "cell_type": "code", "source": [ "sampler = gm.text.Sampler(\n", "    model=model,\n", "    params=params,\n", ")\n", "\n", "# The turn markers must be written explicitly when using the low-level sampler.\n", "prompt = \"\"\"<start_of_turn>user\n", "Give me a list of inspirational quotes.<end_of_turn>\n", "<start_of_turn>model\n", "\"\"\"\n", "\n", "out = sampler.sample(prompt, max_new_tokens=1000)\n", "print(out)" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Okay, here's a list of inspirational quotes, categorized a little to give you a variety:\n", "\n", "**On Perseverance & Resilience:**\n", "\n", "* “The only way to do great work is to love what you do.” – Steve Jobs\n", "* “Fall seven times, stand up eight.” – Japanese Proverb\n", "* “The difference between ordinary and extraordinary is that little extra.” – Jimmy Johnson\n", "* “Success is not final, failure is not fatal: It is the courage to continue that counts.” – Winston Churchill\n", "* “Don’t watch the clock; do what it does. 
Keep going.” – Sam Levenson\n", "* “When the going gets tough, the tough get going.” – Theodore Roosevelt\n", "\n", "\n", "**On Self-Love & Confidence:**\n", "\n", "* “You are enough.” – Brené Brown\n", "* “Believe you can and you’re halfway there.” – Theodore Roosevelt\n", "* “You must be the change you wish to see in the world.” – Mahatma Gandhi\n", "* “The best is yet to come.” – Frank Sinatra\n", "* “Be the energy you want to attract.” – Tony Gaskins\n", "* “Don’t be defined by your past. Define your future.” – Unknown\n", "\n", "\n", "**On Dreams & Goals:**\n", "\n", "* “If you can dream it, you can do it.” – Walt Disney\n", "* “The future belongs to those who believe in the beauty of their dreams.” – Eleanor Roosevelt\n", "* “Shoot for the moon. Even if you miss, you’ll land among the stars.” – Les Brown\n", "* “Start where you are. Use what you have. Do what you can.” – Arthur Ashe\n", "* “Life begins at the end of your comfort zone.” – Unknown\n", "\n", "\n", "**On Happiness & Perspective:**\n", "\n", "* “Happiness is not something readymade. It comes from your own actions.” – Dalai Lama\n", "* “It’s not the triumph that matters, it’s the effort.” – Winston Churchill\n", "* “Don’t wait for the perfect moment, take the moment and make it perfect.” – Oscar Wilde\n", "* “Be present. Be grateful. 
Be you.” – Unknown\n", "* “The only way out is through.” – Robert Frost\n", "\n", "\n", "\n", "**Short & Powerful:**\n", "\n", "* “Be the change.” – Mahatma Gandhi\n", "* “Just breathe.”\n", "* “Keep going.”\n", "* “You got this.”\n", "* “Dream big.”\n", "\n", "---\n", "\n", "**Resources for More Quotes:**\n", "\n", "* **BrainyQuote:** [https://www.brainyquote.com/](https://www.brainyquote.com/)\n", "* **Goodreads:** [https://www.goodreads.com/quotes](https://www.goodreads.com/quotes)\n", "* **Quote Garden:** [https://quotegarden.com/](https://quotegarden.com/)\n", "\n", "To help me give you even more relevant quotes, could you tell me:\n", "\n", "* **What kind of inspiration are you looking for?** (e.g., motivation for work, overcoming challenges, self-love, etc.)\n", "* **Is there a particular theme or topic you'd like quotes about?**\n" ] } ], "execution_count": null }, { "metadata": { "id": "TPf01271RuTs" }, "cell_type": "markdown", "source": [ "## Use the model directly\n", "\n", "Here's an example of predicting a single token by calling the model directly.\n", "\n", "The model expects encoded tokens as input, so we first need to encode the prompt with our tokenizer. See our [tokenizer](https://gemma-llm.readthedocs.io/en/latest/tokenizer.html) documentation for more information on using the tokenizer." ] }, { "metadata": { "id": "mvCAQCDXZ0D3" }, "cell_type": "code", "source": [ "tokenizer = gm.text.Gemma3Tokenizer()" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "kH534DHohG67" }, "cell_type": "markdown", "source": [ "Note: When encoding the prompt, don't forget to add the beginning-of-string token with `add_bos=True`. All prompts fed to the model should start with this token." 
] }, { "metadata": { "id": "T1GC0OPRhHGc" }, "cell_type": "code", "source": [ "prompt = tokenizer.encode('One word to describe Paris: \\n\\n', add_bos=True)\n", "prompt = jnp.asarray(prompt)" ], "outputs": [], "execution_count": null }, { "metadata": { "id": "hHqyFnskZ5SC" }, "cell_type": "markdown", "source": [ "We can then call the model and get the predicted logits." ] }, { "metadata": { "colab": { "height": 35 }, "executionInfo": { "elapsed": 7194, "status": "ok", "timestamp": 1741277771722, "user": { "displayName": "", "userId": "" }, "user_tz": -60 }, "id": "G3Kbo9hgRuTt", "outputId": "3b0e468e-c55c-466a-ece9-ade04926612d" }, "cell_type": "code", "source": [ "# Run the model\n", "out = model.apply(\n", "    {'params': params},\n", "    tokens=prompt,\n", "    return_last_only=True,  # Only predict the last token\n", ")\n", "\n", "\n", "# Sample a token from the predicted logits\n", "next_token = jax.random.categorical(\n", "    jax.random.key(1),\n", "    out.logits\n", ")\n", "tokenizer.decode(next_token)" ], "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'Romantic'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "execution_count": null }, { "metadata": { "id": "TNb88k3EF5gu" }, "cell_type": "markdown", "source": [ "You can also display the probabilities of the next token." ] }, { "metadata": { "colab": { "height": 542 }, "executionInfo": { "elapsed": 104, "status": "ok", "timestamp": 1741277773577, "user": { "displayName": "", "userId": "" }, "user_tz": -60 }, "id": "mIkOdE9dF45s", "outputId": "ab610392-44f5-44cc-9370-bd881bc9136b" }, "cell_type": "code", "source": [ "tokenizer.plot_logits(out.logits)" ], "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "
\n", "
\n", "\n", "" ] }, "metadata": {}, "output_type": "display_data" } ], "execution_count": null }, { "metadata": { "id": "3yeUUP8Rh5mU" }, "cell_type": "markdown", "source": [ "## Next steps\n", "\n", "* See our [multimodal](https://gemma-llm.readthedocs.io/en/latest/multimodal.html) example to query the model with images.\n", "* See our [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example to train Gemma on your custom task.\n", "* See our [tool use](https://gemma-llm.readthedocs.io/en/latest/tool_use.html) tutorial to extend Gemma with external tools.\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "last_runtime": {}, "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }