{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "pn1797sn9Jb_" }, "source": [ "##### Copyright 2025 Google LLC." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "uivh5PY69ISg" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "O83CmJ2j9L3n" }, "source": [ "# Fine-Tune Gemma for Vision Tasks using Hugging Face Transformers and QLoRA" ] }, { "cell_type": "markdown", "metadata": { "id": "f9673bd6" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "e624ec07" }, "source": [ "This guide walks you through fine-tuning Gemma on a custom image and text dataset for a vision task (generating product descriptions) using Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) and [TRL](https://huggingface.co/docs/trl/index). You will learn:\n", "\n", "- What is Quantized Low-Rank Adaptation (QLoRA)\n", "- Setup development environment\n", "- Create and prepare the fine-tuning dataset\n", "- Fine-tune Gemma using TRL and the SFTTrainer\n", "- Test Model Inference and generate product descriptions from images and text.\n", "\n", "Note: This guide requires a GPU that supports the bfloat16 data type, such as NVIDIA L4 or NVIDIA A100, and more than 16GB of memory.\n", "\n", "## What is Quantized Low-Rank Adaptation (QLoRA)\n", "\n", "This guide demonstrates the use of [Quantized Low-Rank Adaptation (QLoRA)](https://arxiv.org/abs/2305.14314), which has emerged as a popular method to efficiently fine-tune LLMs because it reduces computational resource requirements while maintaining high performance. In QLoRA, the pretrained model is quantized to 4-bit and its weights are frozen. Trainable adapter layers (LoRA) are then attached, and only the adapter layers are trained. Afterwards, the adapter weights can be merged with the base model or kept as a separate adapter.\n", "\n", "## Setup development environment\n", "\n", "The first step is to install the Hugging Face libraries, including TRL and Datasets, to fine-tune the open model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ba51aa79" }, "outputs": [], "source": [ "# Install PyTorch & other libraries\n", "%pip install torch tensorboard torchvision\n", "\n", "# Install Transformers\n", "%pip install transformers\n", "\n", "# Install Hugging Face libraries\n", "%pip install datasets accelerate evaluate bitsandbytes trl peft protobuf pillow sentencepiece\n", "\n", "# Uncomment the next line if you are running on a GPU that supports the BF16 data type and FlashAttention, such as NVIDIA L4 or NVIDIA A100\n", "#%pip install flash-attn" ] }, { "cell_type": "markdown", "metadata": { "id": "7ef3d54b" }, "source": [ "_Note: If you are using a GPU with the Ampere architecture (such as NVIDIA A100) or newer (such as NVIDIA L4), you can use FlashAttention. FlashAttention significantly speeds up attention computation and reduces its memory usage from quadratic to linear in sequence length, which can accelerate training by up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main)._\n", "\n", "You need a valid Hugging Face token to publish your model. If you are running inside Google Colab, you can securely provide your Hugging Face token through Colab secrets; otherwise, you can pass the token directly to the `login` method. Make sure your token has write access, because the model is pushed to the Hub during training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "b6d79c93" }, "outputs": [], "source": [ "# Log in to the Hugging Face Hub\n", "from huggingface_hub import login\n", "login()" ] }, { "cell_type": "markdown", "metadata": { "id": "42c60525" }, "source": [ "## Create and prepare the fine-tuning dataset\n", "\n", "When fine-tuning LLMs, it is important to know your use case and the task you want to solve. This helps you create a dataset to fine-tune your model. 
If you haven't defined your use case yet, you might want to go back to the drawing board.\n", "\n", "As an example, this guide focuses on the following use case:\n", "\n", "- Fine-tuning a Gemma model to generate concise, SEO-optimized product descriptions for an ecommerce platform, specifically tailored for mobile search.\n", "\n", "This guide uses the [philschmid/amazon-product-descriptions-vlm](https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm) dataset, which contains Amazon product descriptions along with product images and categories.\n", "\n", "Hugging Face TRL supports multimodal conversations. The important piece is the \"image\" content type, which tells the processing class to load the image. The structure should follow:\n", "\n", "```json\n", "{\"messages\": [{\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\":\"You are...\"}]}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}, {\"type\": \"image\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}]}\n", "{\"messages\": [{\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\":\"You are...\"}]}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}, {\"type\": \"image\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}]}\n", "{\"messages\": [{\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\":\"You are...\"}]}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}, {\"type\": \"image\"}]}, {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": \"...\"}]}]}\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "c4ecf6db" }, "source": [ "You can now use the Hugging Face Datasets library to load the dataset and create a prompt template that combines the image, product name, and category, and adds a system message. The dataset includes images as `PIL.Image` objects." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "40c3a2cf" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8d1259be3dfa4b1e899c97026276ee41", "version_major": 2, "version_minor": 0 }, "text/plain": [ "README.md: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a5554c0595144c949b578eb1cbdfd0fd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "data/train-00000-of-00001.parquet: 0%| | 0.00/47.6M [00:00 and and image.\\nOnly return description. The description should be SEO optimized and for a better mobile search experience.\\n\\n\\nRazor Agitator BMX/Freestyle Bike, 20-Inch\\n\\n\\n\\nSports & Outdoors | Outdoor Recreation | Cycling | Kids' Bikes & Accessories | Kids' Bikes\\n\\n\"}, {'type': 'image', 'image': }]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'Conquer the streets with the Razor Agitator BMX Bike! This 20-inch freestyle bike is built for young riders ready to take on any challenge. Durable frame, responsive handling – perfect for tricks and cruising. Get yours today!'}]}]\n" ] } ], "source": [ "from datasets import load_dataset\n", "from PIL import Image\n", "\n", "# System message for the assistant\n", "system_message = \"You are an expert product description writer for Amazon.\"\n", "\n", "# User prompt that combines the user query and the schema\n", "user_prompt = \"\"\"Create a Short Product description based on the provided and and image.\n", "Only return description. 
The description should be SEO optimized and for a better mobile search experience.\n", "\n", "\n", "{product}\n", "\n", "\n", "\n", "{category}\n", "\n", "\"\"\"\n", "\n", "# Convert dataset to OAI messages\n", "def format_data(sample):\n", "    return {\n", "        \"messages\": [\n", "            {\n", "                \"role\": \"system\",\n", "                #\"content\": [{\"type\": \"text\", \"text\": system_message}],\n", "                \"content\": system_message,\n", "            },\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": [\n", "                    {\n", "                        \"type\": \"text\",\n", "                        \"text\": user_prompt.format(\n", "                            product=sample[\"Product Name\"],\n", "                            category=sample[\"Category\"],\n", "                        ),\n", "                    },\n", "                    {\n", "                        \"type\": \"image\",\n", "                        \"image\": sample[\"image\"],\n", "                    },\n", "                ],\n", "            },\n", "            {\n", "                \"role\": \"assistant\",\n", "                \"content\": [{\"type\": \"text\", \"text\": sample[\"description\"]}],\n", "            },\n", "        ],\n", "    }\n", "\n", "def process_vision_info(messages: list[dict]) -> list[Image.Image]:\n", "    image_inputs = []\n", "    # Iterate through each message in the conversation\n", "    for msg in messages:\n", "        # Get content (ensure it's a list)\n", "        content = msg.get(\"content\", [])\n", "        if not isinstance(content, list):\n", "            content = [content]\n", "\n", "        # Check each content element for images\n", "        for element in content:\n", "            # Collect only elements that carry an actual image object;\n", "            # an image placeholder without an \"image\" key has nothing to convert\n", "            if isinstance(element, dict) and element.get(\"image\") is not None:\n", "                # Convert the image to RGB\n", "                image_inputs.append(element[\"image\"].convert(\"RGB\"))\n", "    return image_inputs\n", "\n", "# Load dataset from the hub\n", "dataset = load_dataset(\"philschmid/amazon-product-descriptions-vlm\", split=\"train\")\n", "dataset = dataset.train_test_split(test_size=0.1)\n", "\n", "# Convert dataset to OAI messages\n", "# Use a list comprehension to keep the PIL.Image type; .map() converts images to bytes\n", "dataset_train = [format_data(sample) for sample in 
dataset[\"train\"]]\n", "dataset_test = [format_data(sample) for sample in dataset[\"test\"]]\n", "\n", "print(dataset_train[345][\"messages\"])" ] }, { "cell_type": "markdown", "metadata": { "id": "c0eb2e06" }, "source": [ "## Fine-tune Gemma using TRL and the SFTTrainer\n", "\n", "You are now ready to fine-tune your model. Hugging Face TRL [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) makes it straightforward to run supervised fine-tuning on open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, and adds quality-of-life features, including:\n", "\n", "* Dataset formatting, including conversational and instruction formats\n", "* Training on completions only, ignoring prompts\n", "* Packing datasets for more efficient training\n", "* Parameter-efficient fine-tuning (PEFT) support including QLoRA\n", "* Preparing the model and tokenizer for conversational fine-tuning (such as adding special tokens)\n", "\n", "The following code loads the Gemma model and tokenizer from Hugging Face and initializes the quantization configuration.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "18069ed2" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "42e58727637d4495ad8c5f753c5bcd06", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b11ec04ab48043b9937cfa3822b4fa42", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/10.2G [00:00\n", " \n", " \n", " [456/456 11:20, Epoch 3/3]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Epoch | Training Loss | Validation Loss
1 | 1.326710 | 1.441816
2 | 1.042711 | 1.320613
3 | 0.739179 | 1.458798

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Start training, the model will be automatically saved to the Hub and the output directory\n", "trainer.train()\n", "\n", "# Save the final model again to the Hugging Face Hub\n", "trainer.save_model()" ] }, { "cell_type": "markdown", "metadata": { "id": "b47b9733" }, "source": [ "Before you can test your model, make sure to free the memory." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "40a32ed7" }, "outputs": [], "source": [ "# Free the GPU memory held by the training run\n", "del model\n", "del trainer\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": { "id": "862e9728" }, "source": [ "When using QLoRA, you only train adapters and not the full model. This means that checkpoints saved during training contain only the adapter weights, not the full model. If you want to save the full model, which makes it easier to use with serving stacks like vLLM or TGI, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This produces a standalone model that can be used for inference without loading the adapter separately.\n", "\n", "Note: Merging the adapter into the model requires more than 30GB of CPU memory. You can skip this step and continue with Test Model Inference.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "761e324b" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "20d63c526a854f2a880882c246ac3b3d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading weights: 0%| | 0/2011 [00:00<|turn>system\n", "You are an expert product description writer for Amazon.\n", "<|turn>user\n", "\n", "\n", "<|image|>\n", "\n", "Create a Short Product description based on the provided and and image.\n", "Only return description. 
The description should be SEO optimized and for a better mobile search experience.\n", "\n", "\n", "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur\n", "\n", "\n", "\n", "Toys & Games | Toy Figures & Playsets | Action Figures\n", "\n", "<|turn>model\n", "\n", "MODEL OUTPUT>> \n", "\n", "Enhance your collection with the Marvel Avengers - Avengers Assemble Ultron-Comforter Set! This soft and cuddly blanket and pillowcase feature everyone's favorite Avengers, Iron Man, and his loyal companion War Machine. Officially licensed by Marvel. Bring home the heroic team!\n" ] } ], "source": [ "import requests\n", "from PIL import Image\n", "\n", "# Test sample with Product Name, Category and Image\n", "sample = {\n", "    \"product_name\": \"Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur\",\n", "    \"category\": \"Toys & Games | Toy Figures & Playsets | Action Figures\",\n", "    \"image\": Image.open(requests.get(\"https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg\", stream=True).raw).convert(\"RGB\")\n", "}\n", "\n", "def generate_description(sample, model, processor):\n", "    # Convert sample into messages and then apply the chat template\n", "    messages = [\n", "        {\"role\": \"system\", \"content\": system_message},\n", "        {\"role\": \"user\", \"content\": [\n", "            {\"type\": \"image\",\"image\": sample[\"image\"]},\n", "            {\"type\": \"text\", \"text\": user_prompt.format(product=sample[\"product_name\"], category=sample[\"category\"])},\n", "        ]},\n", "    ]\n", "    text = processor.apply_chat_template(\n", "        messages, tokenize=False, add_generation_prompt=True\n", "    )\n", "    print(text)\n", "    # Extract the images from the messages\n", "    image_inputs = process_vision_info(messages)\n", "    # Tokenize the text and process the images\n", "    inputs = processor(\n", "        text=[text],\n", "        images=image_inputs,\n", "        padding=True,\n", "        return_tensors=\"pt\",\n", "    )\n", "    # Move the inputs to the device\n", "    inputs 
= inputs.to(model.device)\n", "\n", "    # Generate the output\n", "    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids(\"\")]\n", "    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)\n", "    # Trim the prompt tokens from the generation and decode the output to text\n", "    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]\n", "    output_text = processor.batch_decode(\n", "        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n", "    )\n", "    return output_text[0]\n", "\n", "# Generate the description\n", "description = generate_description(sample, model, processor)\n", "print(\"MODEL OUTPUT>> \\n\")\n", "print(description)" ] }, { "cell_type": "markdown", "metadata": { "id": "6f8ff452" }, "source": [ "## Summary and next steps\n", "\n", "This tutorial covered how to fine-tune a Gemma model for vision tasks using TRL and QLoRA, specifically for generating product descriptions. 
Check out the following docs next:\n", "\n", "* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n", "* Learn how to [fine-tune Gemma for text tasks using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora).\n", "* Learn how to perform [full model fine-tuning using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune).\n", "* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n", "* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n", "* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "huggingface_vision_finetune_qlora.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }