, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n",
"\n",
"\n",
"{context}\n",
"\n",
"\n",
"\n",
"{question}\n",
"\n",
"\"\"\"\n",
"def create_conversation(sample):\n",
" return {\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt.format(question=sample[\"sql_prompt\"], context=sample[\"sql_context\"])},\n",
" {\"role\": \"assistant\", \"content\": sample[\"sql\"]}\n",
" ]\n",
" }\n",
"\n",
"# Load dataset from the hub\n",
"dataset = load_dataset(\"philschmid/gretel-synthetic-text-to-sql\", split=\"train\")\n",
"dataset = dataset.shuffle().select(range(12500))\n",
"\n",
"# Convert dataset to OAI messages\n",
"dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)\n",
"# split dataset into 80% training samples and 20% test samples\n",
"dataset = dataset.train_test_split(test_size=0.2)\n",
"\n",
"# Print formatted user prompt\n",
"for item in dataset[\"train\"][0][\"messages\"]:\n",
" print(item)\n"
]
},
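{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above keeps 12,500 shuffled examples and holds out 20% of them for testing, which is why the trainer later tokenizes 10,000 training and 2,500 evaluation examples. The following pure-Python sketch mirrors that arithmetic on a list of integers. `shuffle_and_split` is a hypothetical helper for illustration, not the `datasets` implementation. It rounds the test share up, and the library's rounding for sizes that don't divide evenly may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import random\n",
"\n",
"def shuffle_and_split(items, n_keep, test_size, seed=42):\n",
"    # Hypothetical helper: shuffle, keep the first n_keep items, then split into (train, test)\n",
"    rng = random.Random(seed)  # seeded so the split is reproducible\n",
"    shuffled = list(items)\n",
"    rng.shuffle(shuffled)\n",
"    kept = shuffled[:n_keep]\n",
"    n_test = math.ceil(len(kept) * test_size)  # round the test share up\n",
"    return kept[n_test:], kept[:n_test]\n",
"\n",
"toy_train, toy_test = shuffle_and_split(range(100_000), n_keep=12_500, test_size=0.2)\n",
"print(len(toy_train), len(toy_test))  # 10000 2500"
]
},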
{
"cell_type": "markdown",
"metadata": {
"id": "c0eb2e06"
},
"source": [
"## Fine-tune Gemma using TRL and the SFTTrainer\n",
"\n",
"You are now ready to fine-tune your model. Hugging Face TRL [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) makes it straightforward to supervise fine-tune open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality of life features, including:\n",
"\n",
"* Dataset formatting, including conversational and instruction formats\n",
"* Training on completions only, ignoring prompts\n",
"* Packing datasets for more efficient training\n",
"* Parameter-efficient fine-tuning (PEFT) support including QloRA\n",
"* Preparing the model and tokenizer for conversational fine-tuning (such as adding special tokens)\n",
"\n",
"The following code loads the Gemma model and tokenizer from Hugging Face and initializes the quantization configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "18069ed2"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0b17e7e80e884df59a0bea8b6f6802e9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f5cfbb54cfec4e7d93ed2eb0d5b2e62a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/10.2G [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a5f8ae73ccd3478985fbc37e95b89de8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/2011 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d9a1c13e560c4790b626ab3fd045e1b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/181 [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb68212e51dc480d99a66d131838858e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b4a245124cc74c4db7b6ad73a1b65f33",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a464b2885d6649b586c73e74fcca0f07",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/32.2M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a6525a790f4440ff989d5c815dd94da7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"chat_template.jinja: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForImageTextToText, BitsAndBytesConfig\n",
"\n",
"# Hugging Face model id\n",
"model_id = \"google/gemma-4-E2B\" # @param [\"google/gemma-4-E2B\",\"google/gemma-4-E4B\"] {\"allow-input\":true}\n",
"\n",
"# Check if GPU benefits from bfloat16\n",
"if torch.cuda.get_device_capability()[0] >= 8:\n",
" torch_dtype = torch.bfloat16\n",
"else:\n",
" torch_dtype = torch.float16\n",
"\n",
"# Define model init arguments\n",
"model_kwargs = dict(\n",
" dtype=torch_dtype,\n",
" device_map=\"auto\", # Let torch decide how to load the model\n",
")\n",
"\n",
"# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage\n",
"model_kwargs[\"quantization_config\"] = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_quant_type='nf4',\n",
" bnb_4bit_compute_dtype=model_kwargs['dtype'],\n",
" bnb_4bit_quant_storage=model_kwargs['dtype'],\n",
")\n",
"\n",
"# Load model and tokenizer\n",
"model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)\n",
"tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-4-E2B-it\") # Load the Instruction Tokenizer to use the official Gemma template"
]
},
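{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `bnb_4bit_use_double_quant=True`, the quantization constants (one scale per block of weights) are themselves quantized. The arithmetic below follows the example given in the QLoRA paper: 64-weight blocks with one 32-bit constant each and, for double quantization, 8-bit constants grouped into blocks of 256 that share one 32-bit constant. These block sizes come from the paper and serve only as an illustration; the internal layout used by bitsandbytes may differ, and the helper functions are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def constant_bits_per_param(block_size=64, const_bits=32):\n",
"    # One full-precision scale per block of weights\n",
"    return const_bits / block_size\n",
"\n",
"def double_quant_bits_per_param(block_size=64, const_bits=8, block_size2=256, const_bits2=32):\n",
"    # 8-bit first-level constants, plus one 32-bit constant per block_size2 of them\n",
"    return const_bits / block_size + const_bits2 / (block_size * block_size2)\n",
"\n",
"def weight_memory_gib(num_params, bits_per_param):\n",
"    # Convert bits to bytes, then bytes to GiB\n",
"    return num_params * bits_per_param / 8 / 1024**3\n",
"\n",
"single = constant_bits_per_param()\n",
"double = double_quant_bits_per_param()\n",
"print(f'constants: {single:.3f} -> {double:.3f} bits per parameter')  # constants: 0.500 -> 0.127 bits per parameter\n",
"\n",
"# Hypothetical 1B-parameter model stored as 4-bit weights plus double-quantized constants\n",
"print(f'approx. weight memory: {weight_memory_gib(1_000_000_000, 4 + double):.2f} GiB')"
]
},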
{
"cell_type": "markdown",
"metadata": {
"id": "37ec1d1b"
},
"source": [
"The `SFTTrainer` supports a built-in integration with `peft`, which makes it straightforward to efficiently tune LLMs using QLoRA. You only need to create a `LoraConfig` and provide it to the trainer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ed00e846"
},
"outputs": [],
"source": [
"from peft import LoraConfig\n",
"\n",
"peft_config = LoraConfig(\n",
" lora_alpha=16,\n",
" lora_dropout=0.05,\n",
" r=16,\n",
" bias=\"none\",\n",
" target_modules=\"all-linear\",\n",
" task_type=\"CAUSAL_LM\",\n",
" modules_to_save=[\"lm_head\", \"embed_tokens\"], # make sure to save the lm_head and embed_tokens as you train the special tokens\n",
" ensure_weight_tying=True,\n",
")"
]
},
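{
"cell_type": "markdown",
"metadata": {},
"source": [
"With LoRA, each targeted linear layer keeps its frozen weight `W` (`d_out x d_in`) and learns two small matrices, `A` (`r x d_in`) and `B` (`d_out x r`), whose product is scaled by `lora_alpha / r` (PEFT's default; `use_rslora=True` switches to `lora_alpha / sqrt(r)`). With `r=16` and `lora_alpha=16` as configured above, the scale is 1.0. The layer size in this sketch is hypothetical, not one of Gemma's actual dimensions. Keep in mind that `modules_to_save` trains full copies of `lm_head` and `embed_tokens`, which can add many more trainable parameters than the adapters because those matrices grow with the vocabulary size."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def lora_trainable_params(d_in, d_out, r):\n",
"    # A has shape (r, d_in) and B has shape (d_out, r)\n",
"    return r * d_in + d_out * r\n",
"\n",
"def lora_scaling(lora_alpha, r):\n",
"    # Default PEFT scale applied to B @ A\n",
"    return lora_alpha / r\n",
"\n",
"d_in, d_out, rank = 2048, 8192, 16  # hypothetical projection size\n",
"full = d_in * d_out\n",
"adapter = lora_trainable_params(d_in, d_out, rank)\n",
"print(f'full weight: {full:,} params, adapter: {adapter:,} params ({adapter / full:.2%})')\n",
"print('scaling:', lora_scaling(lora_alpha=16, r=16))"
]
},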
{
"cell_type": "markdown",
"metadata": {
"id": "bbd9fc1b"
},
"source": [
"Before you can start your training, you need to define the hyperparameter you want to use in a `SFTConfig` instance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "989be3c1"
},
"outputs": [],
"source": [
"import torch\n",
"from trl import SFTConfig\n",
"\n",
"args = SFTConfig(\n",
" output_dir=\"gemma-text-to-sql\", # directory to save and repository id\n",
" max_length=512, # max length for model and packing of the dataset\n",
" num_train_epochs=3, # number of training epochs\n",
" per_device_train_batch_size=1, # batch size per device during training\n",
" optim=\"adamw_torch_fused\", # use fused adamw optimizer\n",
" logging_steps=10, # log every 10 steps\n",
" save_strategy=\"epoch\", # save checkpoint every epoch\n",
" eval_strategy=\"epoch\", # evaluate checkpoint every epoch\n",
" learning_rate=5e-5, # learning rate\n",
" fp16=True if model.dtype == torch.float16 else False, # use float16 precision\n",
" bf16=True if model.dtype == torch.bfloat16 else False, # use bfloat16 precision\n",
" max_grad_norm=0.3, # max gradient norm based on QLoRA paper\n",
" lr_scheduler_type=\"constant\", # use constant learning rate scheduler\n",
" push_to_hub=True, # push model to hub\n",
" report_to=\"tensorboard\", # report metrics to tensorboard\n",
" dataset_kwargs={\n",
" \"add_special_tokens\": False, # Template with special tokens\n",
" \"append_concat_token\": True, # Add EOS token as separator token between examples\n",
" }\n",
")"
]
},
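{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `append_concat_token` option refers to appending an EOS token as a separator when examples are concatenated. The sketch below shows the general idea behind packing: join tokenized examples, each followed by an EOS token, and cut the resulting stream into blocks of at most `max_length` tokens. It uses made-up token IDs and a hypothetical `pack` helper; it is not TRL's implementation, and this configuration does not explicitly enable packing. Naive slicing like this can split one example across two blocks, which is why some packing implementations bin-pack whole examples instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def pack(sequences, max_length, eos_id):\n",
"    # Concatenate every example into one token stream, appending EOS after each\n",
"    stream = []\n",
"    for seq in sequences:\n",
"        stream.extend(seq)\n",
"        stream.append(eos_id)\n",
"    # Slice the stream into consecutive blocks; the last block may be shorter\n",
"    return [stream[i:i + max_length] for i in range(0, len(stream), max_length)]\n",
"\n",
"examples = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]  # hypothetical token IDs\n",
"blocks = pack(examples, max_length=4, eos_id=1)\n",
"print(blocks)  # [[11, 12, 13, 1], [21, 22, 1, 31], [32, 33, 34, 1]]"
]
},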
{
"cell_type": "markdown",
"metadata": {
"id": "dd88e798"
},
"source": [
"You now have every building block you need to create your `SFTTrainer` to start the training of your model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ade95df7"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9061644033864e22a5cd8905051b6637",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Tokenizing train dataset: 0%| | 0/10000 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f63804866860487cb9135f5729d76f01",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Tokenizing eval dataset: 0%| | 0/2500 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from trl import SFTTrainer\n",
"\n",
"# Create Trainer object\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" args=args,\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"test\"],\n",
" peft_config=peft_config,\n",
" processing_class=tokenizer,\n",
")"
]
},
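{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the features listed earlier is training on completions only: the loss is computed on the assistant's tokens while the prompt tokens are ignored. A common way to express this is to set the label at every prompt position to -100, the default `ignore_index` of PyTorch's cross-entropy loss, which `transformers` causal language models follow. The sketch below uses made-up token IDs and a hypothetical `mask_prompt_labels` helper. It is not TRL's code, and whether TRL masks prompt tokens for your dataset depends on your TRL version and `SFTConfig` settings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default\n",
"\n",
"def mask_prompt_labels(prompt_ids, completion_ids):\n",
"    # The model sees prompt and completion as one input sequence\n",
"    input_ids = list(prompt_ids) + list(completion_ids)\n",
"    # Labels mirror the inputs, but prompt positions are hidden from the loss\n",
"    # (causal LMs in transformers shift labels by one position internally)\n",
"    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)\n",
"    return input_ids, labels\n",
"\n",
"# Hypothetical token IDs: 5 prompt tokens followed by 3 completion tokens\n",
"input_ids, labels = mask_prompt_labels([2, 105, 7, 8, 9], [40, 41, 1])\n",
"print(input_ids)  # [2, 105, 7, 8, 9, 40, 41, 1]\n",
"print(labels)     # [-100, -100, -100, -100, -100, 40, 41, 1]"
]
},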
{
"cell_type": "markdown",
"metadata": {
"id": "fad61a6a"
},
"source": [
"Start training by calling the `train()` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "995e7e38"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [1875/1875 28:32, Epoch 3/3]\n",
"
\n",
" \n",
" \n",
" \n",
" | Epoch | \n",
" Training Loss | \n",
" Validation Loss | \n",
"
\n",
" \n",
" \n",
" \n",
" | 1 | \n",
" 0.536652 | \n",
" 0.530056 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.430735 | \n",
" 0.464053 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.386358 | \n",
" 0.443147 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Start training, the model will be automatically saved to the Hub and the output directory\n",
"trainer.train()\n",
"\n",
"# Save the final model again to the Hugging Face Hub\n",
"trainer.save_model()"
]
},
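{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training log reports mean token-level cross-entropy, so `exp(loss)` converts a loss into perplexity, and the gap between validation and training loss shows how well the model generalizes to data it was not trained on. The sketch below applies both to the epoch values from the training log above. It assumes the logged values are per-token means in natural-log units, which is the usual convention for causal language models in `transformers`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# (epoch, training loss, validation loss) copied from the training log above\n",
"history = [\n",
"    (1, 0.536652, 0.530056),\n",
"    (2, 0.430735, 0.464053),\n",
"    (3, 0.386358, 0.443147),\n",
"]\n",
"\n",
"for epoch, train_loss, val_loss in history:\n",
"    gap = val_loss - train_loss  # a steadily widening gap can be an early sign of overfitting\n",
"    print(f'epoch {epoch}: validation perplexity {math.exp(val_loss):.3f}, validation - training loss {gap:+.3f}')"
]
},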
{
"cell_type": "markdown",
"metadata": {
"id": "b47b9733"
},
"source": [
"Before you can test your model, make sure to free the memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "40a32ed7"
},
"outputs": [],
"source": [
"# free the memory again\n",
"del model\n",
"del trainer\n",
"torch.cuda.empty_cache()"
]
},
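{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both `model` and `trainer` are deleted because the trainer keeps its own reference to the model. An object is only freed once no Python reference to it remains, and only then can `torch.cuda.empty_cache()` hand the cached GPU memory back. The stand-in classes below are hypothetical and use `weakref` to show this reference behavior without a GPU. In notebooks, IPython's output history (for example, `Out` and `_`) can also keep objects alive."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gc\n",
"import weakref\n",
"\n",
"class FakeModel:\n",
"    # Hypothetical stand-in for a large model\n",
"    pass\n",
"\n",
"class FakeTrainer:\n",
"    # Hypothetical stand-in that, like a trainer, stores the model it trains\n",
"    def __init__(self, model):\n",
"        self.model = model\n",
"\n",
"fake_model = FakeModel()\n",
"fake_trainer = FakeTrainer(fake_model)\n",
"model_ref = weakref.ref(fake_model)  # observe the object without keeping it alive\n",
"\n",
"del fake_model\n",
"gc.collect()\n",
"alive_after_model_del = model_ref() is not None  # True: fake_trainer.model still references it\n",
"\n",
"del fake_trainer\n",
"gc.collect()\n",
"alive_after_trainer_del = model_ref() is not None  # False: no references remain\n",
"\n",
"print(alive_after_model_del, alive_after_trainer_del)  # True False"
]
},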
{
"cell_type": "markdown",
"metadata": {
"id": "862e9728"
},
"source": [
"When using QLoRA, you only train adapters and not the full model. This means when saving the model during training you only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with serving stacks like vLLM or TGI, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This saves a default model, which can be used for inference.\n",
"\n",
"Note: It requires more than 30GB of CPU Memory when you want to merge the adapter into the model. You can skip this and continue with Test Model Inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "761e324b"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b58cae40ed3d40d89be8b4065548a69d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/2011 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2d9db55b847041a5a3b446001239202a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Writing model shards: 0%| | 0/5 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"('merged_model/tokenizer_config.json',\n",
" 'merged_model/chat_template.jinja',\n",
" 'merged_model/tokenizer.json')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from peft import PeftModel\n",
"\n",
"# Load Model base model\n",
"model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)\n",
"\n",
"# Merge LoRA and base model and save\n",
"peft_model = PeftModel.from_pretrained(model, args.output_dir)\n",
"merged_model = peft_model.merge_and_unload()\n",
"merged_model.save_pretrained(\"merged_model\", safe_serialization=True, max_shard_size=\"2GB\")\n",
"\n",
"processor = AutoTokenizer.from_pretrained(args.output_dir)\n",
"processor.save_pretrained(\"merged_model\")"
]
},
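{
"cell_type": "markdown",
"metadata": {},
"source": [
"`merge_and_unload` folds each adapter into its base weight, so a merged layer computes `(W + s * B @ A) x`, where `s` is the LoRA scale (`lora_alpha / r`). Because matrix multiplication distributes over addition, this equals the base output `W x` plus the adapter output `s * B (A x)`, so merging leaves the model's outputs unchanged (up to floating-point rounding) while removing the separate adapter computation. The pure-Python sketch below checks that identity on tiny hypothetical matrices. LoRA dropout is inactive at inference, so it is omitted. The sketch illustrates the math only; it is not PEFT's implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def matmul(a, b):\n",
"    # (n x k) @ (k x m) for nested Python lists\n",
"    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]\n",
"            for i in range(len(a))]\n",
"\n",
"def add(a, b, scale=1.0):\n",
"    # Element-wise a + scale * b\n",
"    return [[x + scale * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]\n",
"\n",
"W = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight, d_out x d_in\n",
"A = [[2.0, 1.0]]              # r x d_in, with a hypothetical rank r = 1\n",
"B = [[0.5], [-1.0]]           # d_out x r\n",
"s = 1.0                       # lora_alpha / r (16 / 16 in the LoraConfig above)\n",
"x = [[1.0], [1.0]]            # one input, as a column vector\n",
"\n",
"unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), s)  # base path + adapter path\n",
"merged = matmul(add(W, matmul(B, A), s), x)               # single merged weight\n",
"print(unmerged, merged)  # [[4.5], [4.0]] [[4.5], [4.0]]"
]
},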
{
"cell_type": "markdown",
"metadata": {
"id": "bf86e31d"
},
"source": [
"## Test Model Inference and generate SQL queries\n",
"\n",
"After the training is done, you'll want to evaluate and test your model. You can load different samples from the test dataset and evaluate the model on those samples.\n",
"\n",
"Note: Evaluating generative AI models is not a trivial task since one input can have multiple correct outputs. This guide only focuses on manual evaluation and vibe checks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aab1c5c5"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "224c4db7e94445d9adb369eeac3c0bd2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/2012 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The tied weights mapping and config for this model specifies to tie model.language_model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints with different values, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning.\n"
]
}
],
"source": [
"import torch\n",
"from transformers import pipeline\n",
"\n",
"model_id = \"merged_model\"\n",
"\n",
"# Load Model with PEFT adapter\n",
"model = AutoModelForImageTextToText.from_pretrained(\n",
" model_id,\n",
" device_map=\"auto\",\n",
" dtype=\"auto\",\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3dccb57c"
},
"source": [
"Let's load a random sample from the test dataset and generate a SQL command."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1fd887f4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|turn>system\n",
"You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\n",
"<|turn>user\n",
"Given the and the , generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n",
"\n",
"\n",
"CREATE TABLE broadband_plans (plan_id INT, plan_name VARCHAR(255), download_speed INT, upload_speed INT, price DECIMAL(5,2));\n",
"\n",
"\n",
"\n",
"Delete a broadband plan from the 'broadband_plans' table\n",
"\n",
"<|turn>model\n",
"\n",
"Context:\n",
" CREATE TABLE broadband_plans (plan_id INT, plan_name VARCHAR(255), download_speed INT, upload_speed INT, price DECIMAL(5,2));\n",
"Query:\n",
" Delete a broadband plan from the 'broadband_plans' table\n",
"Original Answer:\n",
"DELETE FROM broadband_plans WHERE plan_id = 3001;\n",
"Generated Answer:\n",
"DELETE FROM broadband_plans\n",
"WHERE plan_name = 'Basic';\n"
]
}
],
"source": [
"from random import randint\n",
"import re\n",
"from transformers import pipeline, GenerationConfig\n",
"\n",
"config = GenerationConfig.from_pretrained(model_id)\n",
"config.max_new_tokens = 256\n",
"\n",
"# Load the model and tokenizer into the pipeline\n",
"pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n",
"\n",
"# Load a random sample from the test dataset\n",
"rand_idx = randint(0, len(dataset[\"test\"]))\n",
"test_sample = dataset[\"test\"][rand_idx]\n",
"\n",
"# Convert as test example into a prompt with the Gemma template\n",
"prompt = pipe.tokenizer.apply_chat_template(test_sample[\"messages\"][:2], tokenize=False, add_generation_prompt=True)\n",
"print(prompt)\n",
"\n",
"# Generate our SQL query.\n",
"outputs = pipe(prompt, generation_config=config)\n",
"\n",
"# Extract the user query and original answer\n",
"print(f\"Context:\\n\", re.search(r'\\n(.*?)\\n', test_sample['messages'][1]['content'], re.DOTALL).group(1).strip())\n",
"print(f\"Query:\\n\", re.search(r'\\n(.*?)\\n', test_sample['messages'][1]['content'], re.DOTALL).group(1).strip())\n",
"print(f\"Original Answer:\\n{test_sample['messages'][2]['content']}\")\n",
"print(f\"Generated Answer:\\n{outputs[0]['generated_text'][len(prompt):].strip()}\")"
]
},
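{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because one input can have several correct SQL answers, comparing query text is too strict: in the sample output stored above, the generated `DELETE` filters on `plan_name` while the reference filters on `plan_id`. One complementary check is execution-based comparison: run both queries against the same small database and compare the returned rows or, for statements such as `DELETE`, the resulting database contents. The sketch below uses Python's built-in `sqlite3`, a hypothetical `execution_match` helper, and made-up rows. SQLite's SQL dialect differs from other databases, and agreement on toy data does not prove two queries are equivalent: the two `DELETE` statements below only match because the toy table has a single 'Basic' plan, with ID 3001."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"def run_query(schema_sql, seed_sql, query):\n",
"    # Use a fresh in-memory database per query so statements cannot affect each other\n",
"    conn = sqlite3.connect(':memory:')\n",
"    try:\n",
"        conn.executescript(schema_sql + '\\n' + seed_sql)\n",
"        cur = conn.execute(query)\n",
"        if cur.description is not None:\n",
"            # SELECT-like statement: compare returned rows, ignoring row order\n",
"            return sorted(cur.fetchall(), key=repr)\n",
"        conn.commit()\n",
"        # Data-modifying statement: compare the full resulting database state\n",
"        return list(conn.iterdump())\n",
"    except (sqlite3.Error, sqlite3.Warning):\n",
"        return None  # SQL that fails to run never counts as a match\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"def execution_match(schema_sql, seed_sql, reference_sql, generated_sql):\n",
"    reference = run_query(schema_sql, seed_sql, reference_sql)\n",
"    generated = run_query(schema_sql, seed_sql, generated_sql)\n",
"    return reference is not None and reference == generated\n",
"\n",
"schema = 'CREATE TABLE broadband_plans (plan_id INT, plan_name VARCHAR(255), download_speed INT, upload_speed INT, price DECIMAL(5,2));'\n",
"# Made-up rows for illustration only\n",
"seed = \"INSERT INTO broadband_plans VALUES (3001, 'Basic', 100, 10, 29.99), (3002, 'Pro', 500, 50, 59.99);\"\n",
"\n",
"same_effect = execution_match(schema, seed,\n",
"                              'DELETE FROM broadband_plans WHERE plan_id = 3001;',\n",
"                              \"DELETE FROM broadband_plans\\nWHERE plan_name = 'Basic';\")\n",
"different_rows = execution_match(schema, seed,\n",
"                                 'SELECT plan_name FROM broadband_plans WHERE price < 50',\n",
"                                 'SELECT plan_name FROM broadband_plans')\n",
"print(same_effect, different_rows)  # True False"
]
},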
{
"cell_type": "markdown",
"metadata": {
"id": "6f8ff452"
},
"source": [
"## Summary and next steps\n",
"\n",
"This tutorial covered how to fine-tune a Gemma model using TRL and QLoRA. Check out the following docs next:\n",
"\n",
"* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
"* Learn how to [fine-tune Gemma for vision tasks using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora).\n",
"* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
"* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
"* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "huggingface_text_finetune_qlora.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}