gemma4-research/tooling/google-official/cookbook/tutorials_RAG_EmbeddingGemma.ipynb
Mortdecai eecebe7ef5 docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
2026-04-18 12:24:48 -04:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "-u7xRR3DeFXz"
},
"source": [
"##### Copyright 2026 Google LLC."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "oed1Dh9SeIlD"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A0UbyyBOeKmV"
},
"source": [
"# RAG with EmbeddingGemma\n",
"\n",
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemma/cookbook/blob/main/tutorials/RAG_with_EmbeddingGemma.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ND35JUp9ecq2"
},
"source": [
"EmbeddingGemma is a lightweight, open embedding model designed for fast, high-quality retrieval on everyday devices like mobile phones. At only 308 million parameters, it's efficient enough to run advanced AI techniques, such as Retrieval Augmented Generation (RAG), directly on your local machine with no internet connection required.\n",
"\n",
"## Setup\n",
"\n",
"Before starting this tutorial, complete the following steps:\n",
"\n",
"* Get access to EmbeddingGemma by logging in to [Hugging Face](https://huggingface.co/google/embeddinggemma-300M) and selecting **Acknowledge license** for a Gemma model.\n",
"* Select a Colab runtime with sufficient resources for the Gemma model size you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).\n",
"* Generate a Hugging Face [Access Token](https://huggingface.co/docs/hub/en/security-tokens#how-to-manage-user-access-token) and use it to log in from Colab.\n",
"\n",
"This notebook will run on an NVIDIA T4 GPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SZ8cw1nPf-NV"
},
"source": [
"### Install Python packages\n",
"\n",
"Install the libraries required for running the EmbeddingGemma model and generating embeddings. Sentence Transformers is a Python framework for text and image embeddings. For more information, see the [Sentence Transformers](https://www.sbert.net/) documentation."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "daXx6O20Q7M0"
},
"outputs": [],
"source": [
"!pip install -q -U sentence-transformers transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kYiTsNFSjGJH"
},
"source": [
"After you have accepted the license, you need a valid Hugging Face token to access the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eLagJ9aff9Ks"
},
"outputs": [],
"source": [
"# Log in to the Hugging Face Hub\n",
"from huggingface_hub import login\n",
"login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IiDcW_rmHBfx"
},
"source": [
"### Load language model\n",
"\n",
"You will use Gemma 4 E2B to generate responses."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "HX2JFDQI-vg8"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c0b54b8b91da46fdb7ba8fd3aecb5002",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4291694230e74608a2808adde451bd0f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/10.2G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb31547f287441aba370d8e7a5fc351e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/1951 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0900cc228bed472094eb986719edfde4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/208 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3d195cea1ce044f4827cf06412aed5ec",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3bdb49b389aa4abfbb382fccaceb32be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/32.2M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "93e44e5dd0fe40d49e0cda367d98aeca",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"chat_template.jinja: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
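],
"source": [
"# Load Gemma\n",
"from transformers import pipeline\n",
"\n",
"MODEL_ID = \"google/gemma-4-E2B-it\"\n",
"\n",
"# Note: this assignment rebinds the name `pipeline` from the factory function\n",
"# to the loaded text-generation pipeline; later cells call it as `pipeline(...)`.\n",
"pipeline = pipeline(\n",
"    task=\"text-generation\",\n",
"    model=MODEL_ID,\n",
"    device_map=\"auto\",\n",
"    dtype=\"auto\"\n",
")"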
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eAg-c23Wh0th"
},
"source": [
"### Load embedding model\n",
"\n",
"Use the `sentence-transformers` library to load EmbeddingGemma as a `SentenceTransformer` instance."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "6Jj1WiTSRRk-"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c5dc65f501e402fb5ec67d094d925e7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"modules.json: 0%| | 0.00/573 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "10b836de41a0410d8963be637ffa6b9d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config_sentence_transformers.json: 0%| | 0.00/997 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "68e29095344e4d24ac3898638f5a2b0e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"README.md: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "376438be53e14e4b808ce63de0d32cb2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"sentence_bert_config.json: 0%| | 0.00/58.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7f2a5a56690e4ed5950ad0c278cc20c7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "264f0c21602640bd9ddfa9d405b5613f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/1.21G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "70eb603cffa948cc895046a8238abbae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/314 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "aa608efe38f448898f8a01940a3684df",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b12b1756d9ac4145ae70595454e0e036",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/33.4M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e6e735942c07444ebfcf2702673762b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"added_tokens.json: 0%| | 0.00/35.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ff92bb744fd54211b20f04aedebaa26d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/662 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "114a2560d2124889932f1a6436c4d6ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/312 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "292f471e215d4ac8a490508ce6963b01",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/134 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dce2f7bc57134d0180f3accdec8d5556",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"2_Dense/model.safetensors: 0%| | 0.00/9.44M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c245d417dc9f4d71850853a107379b16",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/134 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18def66743ae4738b940a4b20c434545",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"3_Dense/model.safetensors: 0%| | 0.00/9.44M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device: cuda:0\n",
"SentenceTransformer(\n",
" (0): Transformer({'transformer_task': 'feature-extraction', 'modality_config': {'text': {'method': 'forward', 'method_output_name': 'last_hidden_state'}}, 'module_output_name': 'token_embeddings', 'architecture': 'Gemma3TextModel'})\n",
" (1): Pooling({'embedding_dimension': 768, 'pooling_mode': 'mean', 'include_prompt': True})\n",
" (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity', 'module_input_name': 'sentence_embedding', 'module_output_name': 'sentence_embedding'})\n",
" (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity', 'module_input_name': 'sentence_embedding', 'module_output_name': 'sentence_embedding'})\n",
" (4): Normalize({})\n",
")\n",
"Total number of parameters in the model: 307581696\n"
]
}
],
"source": [
"import torch\n",
"from sentence_transformers import SentenceTransformer\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"model_id = \"google/embeddinggemma-300M\"\n",
"model = SentenceTransformer(model_id).to(device=device)\n",
"\n",
"print(f\"Device: {model.device}\")\n",
"print(model)\n",
"print(\"Total number of parameters in the model:\", sum([p.numel() for _, p in model.named_parameters()]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8o2-nOX-aqRS"
},
"source": [
"### Using Prompts with EmbeddingGemma\n",
"\n",
"For RAG systems, use the following `prompt_name` values to create specialized embeddings for your queries and documents:\n",
"\n",
"* **For Queries:** Use `prompt_name=\"Retrieval-query\"`.<br>\n",
"    ```python\n",
"    query_embedding = model.encode(\n",
"        \"How do I use prompts with this model?\",\n",
"        prompt_name=\"Retrieval-query\"\n",
"    )\n",
"    ```\n",
"\n",
"* **For Documents:** Use `prompt_name=\"Retrieval-document\"`. To further improve document embeddings, you can also include a title by using the `prompt` argument directly:<br>\n",
"    * **With a title:**<br>\n",
"        ```python\n",
"        doc_embedding = model.encode(\n",
"            \"The document text...\",\n",
"            prompt=\"title: Using Prompts in RAG | text: \"\n",
"        )\n",
"        ```\n",
"    * **Without a title:**<br>\n",
"        ```python\n",
"        doc_embedding = model.encode(\n",
"            \"The document text...\",\n",
"            prompt=\"title: none | text: \"\n",
"        )\n",
"        ```\n",
"\n",
"### Further Reading\n",
"\n",
"* For details on all available EmbeddingGemma prompts, see the [model card](http://ai.google.dev/gemma/docs/embeddinggemma/model_card#prompt_instructions).\n",
"* For general information on prompt templates, see the [Sentence Transformer documentation](https://sbert.net/examples/sentence_transformer/applications/computing-embeddings/README.html#prompt-templates).\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "Y5hVNF3F-qZ7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Available tasks:\n",
" query: \"task: search result | query: \"\n",
" document: \"title: none | text: \"\n",
" BitextMining: \"task: search result | query: \"\n",
" Clustering: \"task: clustering | query: \"\n",
" Classification: \"task: classification | query: \"\n",
" InstructionRetrieval: \"task: code retrieval | query: \"\n",
" MultilabelClassification: \"task: classification | query: \"\n",
" PairClassification: \"task: sentence similarity | query: \"\n",
" Reranking: \"task: search result | query: \"\n",
" Retrieval: \"task: search result | query: \"\n",
" Retrieval-query: \"task: search result | query: \"\n",
" Retrieval-document: \"title: none | text: \"\n",
" STS: \"task: sentence similarity | query: \"\n",
" Summarization: \"task: summarization | query: \"\n"
]
}
],
"source": [
"print(\"Available tasks:\")\n",
"for name, prefix in model.prompts.items():\n",
"    print(f\" {name}: \\\"{prefix}\\\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eIfWZ_z3xDZq"
},
"source": [
"## Simple RAG example\n",
"\n",
"Retrieval is the task of finding the most relevant pieces of information from a large collection (a database, a set of documents, a website) based on the meaning of a query, not just keywords.\n",
"\n",
"Imagine you work for a company, and you need to find information from the internal employee handbook, which is stored as a collection of hundreds of documents."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"cellView": "form",
"id": "fbaiy-CXRAs7"
},
"outputs": [],
"source": [
"#@title Corp knowledge base\n",
"corp_knowledge_base = [\n",
" {\n",
" \"category\": \"HR & Leave Policies\",\n",
" \"documents\": [\n",
" {\n",
" \"title\": \"Procedure for Unscheduled Absence\",\n",
" \"content\": \"In the event of an illness or emergency preventing you from working, please notify both your direct manager and the HR department via email by 9:30 AM JST. The subject line should be 'Sick Leave - [Your Name]'. If the absence extends beyond two consecutive days, a doctor's certificate (診断書) will be required upon your return.\"\n",
" },\n",
" {\n",
" \"title\": \"Annual Leave Policy\",\n",
" \"content\": \"Full-time employees are granted 10 days of annual paid leave in their first year. This leave is granted six months after the date of joining and increases each year based on length of service. For example, an employee in their third year of service is entitled to 14 days per year. For a detailed breakdown, please refer to the attached 'Annual Leave Accrual Table'.\"\n",
" },\n",
" ]\n",
" },\n",
" {\n",
" \"category\": \"IT & Security\",\n",
" \"documents\": [\n",
" {\n",
" \"title\": \"Account Password Management\",\n",
" \"content\": \"If you have forgotten your password or your account is locked, please use the self-service reset portal at https://reset.ourcompany. You will be prompted to answer your pre-configured security questions. For security reasons, the IT Help Desk cannot reset passwords over the phone or email. If you have not set up your security questions, please visit the IT support desk on the 12th floor of the Shibuya office with your employee ID card.\"\n",
" },\n",
" {\n",
" \"title\": \"Software Procurement Process\",\n",
" \"content\": \"All requests for new software must be submitted through the 'IT Service Desk' portal under the 'Software Request' category. Please include a business justification for the request. All software licenses require approval from your department head before procurement can begin. Please note that standard productivity software is pre-approved and does not require this process.\"\n",
" },\n",
" ]\n",
" },\n",
" {\n",
" \"category\": \"Finance & Expenses\",\n",
" \"documents\": [\n",
" {\n",
" \"title\": \"Expense Reimbursement Policy\",\n",
" \"content\": \"To ensure timely processing, all expense claims for a given month must be submitted for approval no later than the 5th business day of the following month. For example, all expenses incurred in July must be submitted by the 5th business day of August. Submissions after this deadline may be processed in the next payment cycle.\"\n",
" },\n",
" {\n",
" \"title\": \"Business Trip Expense Guidelines\",\n",
" \"content\": \"Travel expenses for business trips will, as a rule, be reimbursed based on the actual cost of the most logical and economical route. Please submit a travel expense application in advance when using the Shinkansen or airplanes. Taxis are permitted only when public transportation is unavailable or when transporting heavy equipment. Receipts are mandatory.\"\n",
" },\n",
" ]\n",
" },\n",
" {\n",
" \"category\": \"Office & Facilities\",\n",
" \"documents\": [\n",
" {\n",
" \"title\": \"Conference Room Booking Instructions\",\n",
" \"content\": \"All conference rooms in the Shibuya office can be reserved through your Calendar App. Create a new meeting invitation, add the attendees, and then use the 'Room Finder' feature to select an available room. Please be sure to select the correct floor. For meetings with more than 10 people, please book the 'Sakura' or 'Fuji' rooms on the 14th floor.\"\n",
" },\n",
" {\n",
" \"title\": \"Mail and Delivery Policy\",\n",
" \"content\": \"The company's mail services are intended for business-related correspondence only. For security and liability reasons, employees are kindly requested to refrain from having personal parcels or mail delivered to the Shibuya office address. The front desk will not be able to accept or hold personal deliveries.\"\n",
" },\n",
" ]\n",
" },\n",
"]\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fvecfoko--hL"
},
"source": [
"Now suppose you have a question like the one below."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "wN-WHf26J89m"
},
"outputs": [],
"source": [
"question = \"How do I reset my password?\" # @param [\"How many days of annual paid leave do I get?\", \"How do I reset my password?\", \"What travel expenses can be reimbursed for a business trip?\", \"Can I receive personal packages at the office?\"] {type:\"string\", allow-input: true}\n",
"\n",
"# Define a minimum confidence threshold for a match to be considered valid\n",
"similarity_threshold = 0.4 # @param {\"type\":\"slider\",\"min\":0,\"max\":1,\"step\":0.1}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2CSeSmF7OuMB"
},
"source": [
"Search the corporate knowledge base for the most relevant document."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "NngqWUxOyrLS"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 1: Finding the best category...\n",
"['HR & Leave Policies', 'IT & Security', 'Finance & Expenses', 'Office & Facilities']\n",
"tensor([[0.5063, 0.5937, 0.5076, 0.4221]])\n",
" `-> ✅ Category Found: 'IT & Security' (Score: 0.59)\n",
"\n",
"Step 2: Finding the best document in that category...\n",
"['Account Password Management', 'Software Procurement Process']\n",
"tensor([[0.5829, 0.1531]])\n",
" `-> ✅ Document Found: 'Account Password Management' (Score: 0.58)\n"
]
}
],
"source": [
"# --- Helper Functions for Semantic Search ---\n",
"\n",
"def _calculate_best_match(similarities):\n",
"    print(similarities)\n",
"    if similarities is None or similarities.nelement() == 0:\n",
"        return None, 0.0\n",
"\n",
"    # Find the index and value of the highest score\n",
"    best_index = similarities.argmax().item()\n",
"    best_score = similarities[0, best_index].item()\n",
"\n",
"    return best_index, best_score\n",
"\n",
"def find_best_category(model, query, candidates):\n",
"    \"\"\"\n",
"    Finds the most relevant category from a list of candidates.\n",
"\n",
"    Args:\n",
"        model: The SentenceTransformer model.\n",
"        query: The user's query string.\n",
"        candidates: A list of category name strings.\n",
"\n",
"    Returns:\n",
"        A tuple containing the index of the best category and its similarity score.\n",
"    \"\"\"\n",
"    if not candidates:\n",
"        return None, 0.0\n",
"\n",
"    # Encode the query and candidate categories for classification\n",
"    query_embedding = model.encode(query, prompt_name=\"Classification\")\n",
"    candidate_embeddings = model.encode(candidates, prompt_name=\"Classification\")\n",
"\n",
"    print(candidates)\n",
"    return _calculate_best_match(model.similarity(query_embedding, candidate_embeddings))\n",
"\n",
"def find_best_doc(model, query, candidates):\n",
"    \"\"\"\n",
"    Finds the most relevant document from a list of candidates.\n",
"\n",
"    Args:\n",
"        model: The SentenceTransformer model.\n",
"        query: The user's query string.\n",
"        candidates: A list of document dictionaries, each with 'title' and 'content'.\n",
"\n",
"    Returns:\n",
"        A tuple containing the index of the best document and its similarity score.\n",
"    \"\"\"\n",
"    if not candidates:\n",
"        return None, 0.0\n",
"\n",
"    # Encode the query for retrieval\n",
"    query_embedding = model.encode(query, prompt_name=\"Retrieval-query\")\n",
"\n",
"    # Encode the document for similarity check\n",
"    doc_texts = [\n",
"        f\"title: {doc.get('title', 'none')} | text: {doc.get('content', '')}\"\n",
"        for doc in candidates\n",
"    ]\n",
"    candidate_embeddings = model.encode(doc_texts)\n",
"\n",
"    print([doc['title'] for doc in candidates])\n",
"\n",
"    # Calculate cosine similarity\n",
"    return _calculate_best_match(model.similarity(query_embedding, candidate_embeddings))\n",
"\n",
"# --- Main Search Logic ---\n",
"\n",
"# In your application, `best_document` would result from a search.\n",
"# We initialize it to None to ensure it always exists.\n",
"best_document = None\n",
"\n",
"# 1. Find the most relevant category\n",
"print(\"Step 1: Finding the best category...\")\n",
"categories = [item[\"category\"] for item in corp_knowledge_base]\n",
"best_category_index, category_score = find_best_category(\n",
"    model, question, categories\n",
")\n",
"\n",
"# Check if the category score meets the threshold\n",
"if category_score < similarity_threshold:\n",
"    print(f\" `-> 🤷 No relevant category found. The highest score was only {category_score:.2f}.\")\n",
"else:\n",
"    best_category = corp_knowledge_base[best_category_index]\n",
"    print(f\" `-> ✅ Category Found: '{best_category['category']}' (Score: {category_score:.2f})\")\n",
"\n",
"    # 2. Find the most relevant document ONLY if a good category was found\n",
"    print(\"\\nStep 2: Finding the best document in that category...\")\n",
"    best_document_index, document_score = find_best_doc(\n",
"        model, question, best_category[\"documents\"]\n",
"    )\n",
"\n",
"    # Check if the document score meets the threshold\n",
"    if document_score < similarity_threshold:\n",
"        print(f\" `-> 🤷 No relevant document found. The highest score was only {document_score:.2f}.\")\n",
"    else:\n",
"        best_document = best_category[\"documents\"][best_document_index]\n",
"        # 3. Display the final successful result\n",
"        print(f\" `-> ✅ Document Found: '{best_document['title']}' (Score: {document_score:.2f})\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zK9T5rRGAMDw"
},
"source": [
"Next, generate the answer with the retrieved context."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "FrwKySpMASpt"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question🙋‍♂️: How do I reset my password?\n",
"Using document: Account Password Management\n",
"Answer🤖: Please use the self-service reset portal at https://reset.ourcompany. You will be prompted to answer your pre-configured security questions.\n"
]
}
],
"source": [
"from transformers import GenerationConfig\n",
"MODEL_ID = \"google/gemma-4-E2B-it\"\n",
"config = GenerationConfig.from_pretrained(MODEL_ID)\n",
"config.max_new_tokens = 512\n",
"\n",
"qa_prompt_template = \"\"\"Answer the following QUESTION based only on the CONTEXT provided. If the answer cannot be found in the CONTEXT, write \"I don't know.\"\n",
"\n",
"---\n",
"CONTEXT:\n",
"{context}\n",
"---\n",
"QUESTION:\n",
"{question}\n",
"\"\"\"\n",
"\n",
"# First, check if a valid document was found before proceeding.\n",
"if best_document and \"content\" in best_document:\n",
"    # If the document exists and has a \"content\" key, generate the answer.\n",
"    context = best_document[\"content\"]\n",
"\n",
"    prompt = qa_prompt_template.format(context=context, question=question)\n",
"\n",
"    messages = [\n",
"        {\n",
"            \"role\": \"user\",\n",
"            \"content\": [{\"type\": \"text\", \"text\": prompt}],\n",
"        },\n",
"    ]\n",
"\n",
"    print(\"Question🙋‍♂️: \" + question)\n",
"    # This part assumes your pipeline and response parsing logic are correct\n",
"    answer = pipeline(messages, generation_config=config)[0][\"generated_text\"][1][\"content\"]\n",
"    print(\"Using document: \" + best_document[\"title\"])\n",
"    print(\"Answer🤖: \" + answer)\n",
"\n",
"else:\n",
"    # If best_document is None or doesn't have content, give a direct response.\n",
"    print(\"Question🙋‍♂️: \" + question)\n",
"    print(\"Answer🤖: I'm sorry, I could not find a relevant document to answer that question.\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h4J4pFA3IK1d"
},
"source": [
"## Summary and next steps\n",
"\n",
"You have now learned how to build a practical RAG system with EmbeddingGemma.\n",
"\n",
"Explore what more you can do with EmbeddingGemma:\n",
"\n",
"* [Generate embeddings with Sentence Transformers](https://ai.google.dev/gemma/docs/embeddinggemma/inference-embeddinggemma-with-sentence-transformers)\n",
"* [Fine-tune EmbeddingGemma](https://ai.google.dev/gemma/docs/embeddinggemma/fine-tuning-embeddinggemma-with-sentence-transformers)\n",
"* [Mood Palette Generator](https://huggingface.co/spaces/google/mood-palette), an interactive application using EmbeddingGemma"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "RAG_with_EmbeddingGemma.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}