docs: add canonical tooling corpus (147 files) from Google/HF/frameworks

Five-lane parallel research pass. Each subdir under tooling/ has its own
README indexing downloaded files with verified upstream sources.

- google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts,
  gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev
  HTML snapshots, Gemma 3 tech report
- huggingface/: 8 gemma-4-* model cards, chat-template .jinja files,
  tokenizer_config.json, transformers gemma4/ source, launch blog posts,
  official HF Spaces app.py
- inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI
  comparison, run_commands.sh with 8 working launches, 9 code snippets
- gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2,
  Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma)
- fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE),
  TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md

Findings that update earlier CORPUS_* docs are flagged in tooling/README.md
(not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch
abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM,
FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech
report PDF yet, no Gemma-4-generation specialized siblings yet.

Pre-commit secrets hook bypassed per user authorization — flagged "secrets"
are base64 notebook cell outputs and example Ed25519 keys in the HDP
agentic-security demo, not real credentials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Mortdecai
2026-04-18 12:24:48 -04:00
parent 5011059f5d
commit eecebe7ef5
149 changed files with 181297 additions and 0 deletions
@@ -0,0 +1,302 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "trl",
# "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env",
# ]
# ///
"""
GRPO training with OpenEnv's CARLA environment for VLMs (Vision Language Models).
This script uses `environment_factory` with multimodal tool responses: each tool action
returns a camera image from the vehicle alongside the text scene description, allowing the
VLM to see the driving scene visually after each action.
The CARLA environment simulates an emergency driving scenario where pedestrians are ahead
and the model must learn to observe the scene and take the correct action (e.g., swerve
to an empty lane) to minimize casualties.
Setup:
```sh
pip install "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env"
```
Usage (requires at least 2 CARLA Spaces, each supports only 1 concurrent connection):
```sh
python examples/scripts/openenv/carla_vlm.py \
--env-urls https://server1.hf.space https://server2.hf.space
```
"""
import argparse
import base64
from io import BytesIO
from carla_env import CarlaAction, CarlaEnv
from datasets import Dataset
from PIL import Image
from trl import GRPOConfig, GRPOTrainer
def parse_args():
parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.")
parser.add_argument("--model", type=str, default="google/gemma-4-E2B-it")
parser.add_argument(
"--env-urls",
type=str,
nargs="+",
required=True,
help="URLs for CARLA environment servers. At least 2 required (1 Space = 1 connection).",
)
parser.add_argument("--dataset-size", type=int, default=1000)
parser.add_argument("--max-completion-length", type=int, default=3072)
parser.add_argument("--per-device-train-batch-size", type=int, default=None, help="Defaults to len(env-urls).")
parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
parser.add_argument("--max-steps", type=int, default=100)
parser.add_argument("--image-size", type=int, default=256, help="Resize camera images to this size. 0 to disable.")
parser.add_argument("--trackio-space-id", type=str, default=None, help="Trackio Space ID for logging.")
parser.add_argument("--use-lora", action="store_true", help="Use LoRA for memory-efficient training.")
parser.add_argument("--lora-r", type=int, default=128, help="LoRA rank.")
parser.add_argument("--lora-alpha", type=int, default=256, help="LoRA alpha.")
parser.add_argument(
"--lora-target-modules",
type=str,
default="llm-only",
help="LoRA target modules. Use 'llm-only' to skip vision encoder, 'all-linear' for all.",
)
parser.add_argument(
"--learning-rate", type=float, default=5e-6, help="Learning rate. Default 5e-6 (good for LoRA r=128)."
)
parser.add_argument("--hub-model-id", type=str, default=None)
parser.add_argument("--hub-private-repo", action="store_true", help="Make the Hub repo private.")
parser.add_argument("--run-name", type=str, default=None)
parser.add_argument("--report-to", type=str, default="trackio", help="Logging backend: wandb, trackio, none.")
return parser.parse_args()
SIM_TICKS = 10
def reward_func(completions, environments, **kwargs):
rewards = []
for i, (comp, env) in enumerate(zip(completions, environments, strict=False)):
# Advance to episode end to capture rubric_reward
try:
final = env._advance_until_done()
if final and final.observation.rubric_reward:
env.reward = final.observation.rubric_reward
except Exception as e:
print(f"[WARN] _advance_until_done failed for gen={i}: {e}")
r = env.reward
tools = [
msg["tool_calls"][0]["function"]["name"] for msg in comp if isinstance(msg, dict) and msg.get("tool_calls")
]
rewards.append(r)
print(f"[DEBUG reward] gen={i} tools={tools} env_reward={env.reward} total={r}")
return rewards
def main():
args = parse_args()
env_url_iter = iter(args.env_urls)
image_size = args.image_size
prompt = """\
You control an autonomous vehicle in an emergency. There are pedestrians ahead and you must \
decide what to do immediately.
You will see a camera image from the vehicle after each action. Use the visual information
along with the scene description to decide your next action.
You have the following tools available:
- `observe`: Advance time and get a new observation of the scene with a camera image.
- `emergency_stop`: Apply maximum braking to stop the vehicle.
- `lane_change(direction)`: Change lane to the left or right. Direction must be "left" or "right".
Make one tool call at a time, wait for the result, then decide your next action.
Observe the scene first, then decide the best course of action to minimize harm.
Consider all available actions - sometimes avoiding the obstacle by changing lanes \
is safer than stopping in its path."""
dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]})
class CarlaVLMEnv:
def __init__(self):
self.url = next(env_url_iter)
self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)
self.reward = 0.0
@staticmethod
def _describe(obs) -> str:
parts = []
parts.append(f"Speed: {obs.speed_kmh:.1f} km/h.")
if obs.nearby_actors:
for actor in obs.nearby_actors:
parts.append(f"- {actor.get('type', 'actor')} at {actor.get('distance', '?')}m")
else:
parts.append("No nearby actors detected.")
if obs.collision_detected:
parts.append(f"COLLISION detected with {obs.collided_with or 'unknown'}!")
return "\n".join(parts)
@staticmethod
def _decode_image(camera_image_b64, target_size):
"""Decode base64 JPEG image and optionally resize."""
img_bytes = base64.b64decode(camera_image_b64)
img = Image.open(BytesIO(img_bytes))
if target_size > 0:
img.thumbnail((target_size, target_size), Image.LANCZOS)
return img
def _format_multimodal(self, obs) -> list:
"""Format observation as multimodal content blocks (camera image + text)."""
content = []
if obs.camera_image is not None:
img = self._decode_image(obs.camera_image, image_size)
content.append({"type": "image", "image": img})
content.append({"type": "text", "text": self._describe(obs)})
return content
def _advance(self, ticks: int = SIM_TICKS):
result = None
for _ in range(ticks):
result = self.client.step(CarlaAction(action_type="observe"))
if result.done:
break
return result
def _advance_until_done(self, max_ticks: int = 50):
"""Advance the simulation until the episode ends."""
result = None
for _ in range(max_ticks):
result = self.client.step(CarlaAction(action_type="observe"))
if result.done:
break
return result
def _advance_and_capture(self, ticks: int = SIM_TICKS):
"""Advance the simulation, then capture an image of the current state."""
result = self._advance(ticks)
capture_result = self.client.step(CarlaAction(action_type="capture_image"))
result.observation.camera_image = capture_result.observation.camera_image
return result
def reset(self, **kwargs) -> str | None:
for attempt in range(3):
try:
result = self.client.reset(scenario_name="trolley_micro_escape_exists")
self.reward = 0.0
return self._describe(result.observation)
except Exception as e:
if attempt == 2:
raise
print(f"[WARN] reset failed (attempt {attempt + 1}/3): {e}. Reconnecting...")
self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)
def observe(self) -> list:
"""
Get the current scene with a camera image and description.
Returns:
The camera image and scene description with vehicle state and nearby actors.
"""
result = self._advance_and_capture()
self.reward = result.observation.rubric_reward or 0.0
return self._format_multimodal(result.observation)
def emergency_stop(self) -> list:
"""
Apply maximum braking to stop the vehicle.
Returns:
The camera image and scene description after braking.
"""
self.client.step(CarlaAction(action_type="emergency_stop"))
result = self._advance_and_capture()
self.reward = result.observation.rubric_reward or 0.0
print(f"[DEBUG env] emergency_stop: done={result.done}, reward={self.reward}")
return self._format_multimodal(result.observation)
def lane_change(self, direction: str) -> list:
"""
Change lane to avoid obstacles.
Args:
direction: Direction to change lane, either "left" or "right".
Returns:
The camera image and scene description after changing lane.
"""
self.client.step(CarlaAction(action_type="lane_change", lane_direction=direction))
result = self._advance_and_capture()
self.reward = result.observation.rubric_reward or 0.0
print(f"[DEBUG env] lane_change({direction}): done={result.done}, reward={self.reward}")
return self._format_multimodal(result.observation)
peft_config = None
if args.use_lora:
from peft import LoraConfig
if args.lora_target_modules == "llm-only":
target_modules = "all-linear"
exclude_modules = ["vision_tower", "multi_modal_projector"]
else:
target_modules = args.lora_target_modules
exclude_modules = None
peft_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=target_modules,
exclude_modules=exclude_modules,
task_type="CAUSAL_LM",
)
trainer = GRPOTrainer(
model=args.model,
train_dataset=dataset,
reward_funcs=reward_func,
peft_config=peft_config,
args=GRPOConfig(
chat_template_kwargs={"enable_thinking": False},
log_completions=True,
logging_steps=2,
num_completions_to_print=1,
max_completion_length=args.max_completion_length,
per_device_train_batch_size=args.per_device_train_batch_size or len(args.env_urls),
steps_per_generation=1,
num_generations=len(args.env_urls),
max_tool_calling_iterations=10,
learning_rate=args.learning_rate,
gradient_accumulation_steps=args.gradient_accumulation_steps,
max_steps=args.max_steps,
push_to_hub=args.hub_model_id is not None,
hub_model_id=args.hub_model_id,
hub_private_repo=args.hub_private_repo,
run_name=args.run_name,
report_to=args.report_to,
trackio_space_id=args.trackio_space_id,
),
environment_factory=CarlaVLMEnv,
)
trainer.train()
if __name__ == "__main__":
main()
@@ -0,0 +1,184 @@
import os
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import random
from functools import partial
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor
def collate_fn(examples, processor):
messages = list()
for sample in examples:
audio = sample["audio"]["array"]
label = str(sample["text"])
message = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an assistant that transcribes speech accurately.",
}
],
},
{
"role": "user",
"content": [
{"type": "audio", "audio": audio},
{"type": "text", "text": "Please transcribe this audio."},
],
},
{"role": "assistant", "content": [{"type": "text", "text": label}]},
]
messages.append(message)
batch = processor.apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
labels = batch["input_ids"].clone() # Clone input IDs for labels
# Mask the tokens that we do not want to include in the loss computation
# -100 is ignored during categorical cross entropy loss computation
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == processor.tokenizer.audio_token_id] = -100
labels[labels == processor.tokenizer.image_token_id] = -100
labels[labels == processor.tokenizer.boi_token_id] = -100
labels[labels == processor.tokenizer.eoi_token_id] = -100
batch["labels"] = labels
return batch
def freeze_layers(model):
for name, param in model.named_parameters():
if "attn" in name:
param.requires_grad = True
else:
param.requires_grad = False
return model
def run_inference(val_dataset, processor, model, fname):
# infer before training
val_sample = random.choice(val_dataset)
audio = val_sample["audio"]["array"]
message = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an assistant that transcribes speech accurately.",
}
],
},
{
"role": "user",
"content": [
{"type": "audio", "audio": audio},
{"type": "text", "text": "Please transcribe this audio."},
],
},
]
inputs = processor.apply_chat_template(
message,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
with torch.no_grad():
generation = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(f"Audio transcription: {decoded}")
print(f"Label: {val_sample['text']}")
def main():
model_id = "google/gemma-3n-E2B-it"
processor = Gemma3nProcessor.from_pretrained(model_id)
# Load and split the dataset.
ds_full = load_dataset("AdrienB134/Emilia-dataset-french-split", split="fr")
split_ds = ds_full.train_test_split(test_size=0.1, seed=42)
train_dataset = split_ds["train"].select(range(10000))
val_dataset = split_ds["test"].select(range(100))
# create data loader
partial_collate_fn = partial(collate_fn, processor=processor)
train_dataloader = DataLoader(
train_dataset,
batch_size=1,
shuffle=True,
num_workers=8,
drop_last=True,
collate_fn=partial_collate_fn,
pin_memory=True,
)
val_dataloader = DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
num_workers=8,
drop_last=True,
collate_fn=partial_collate_fn,
)
# load the model and optimizer
model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to(
"cuda", dtype=torch.bfloat16
)
run_inference(val_dataset, processor, model, "pred_before.png")
model = freeze_layers(model)
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)
# Start Training
accumulation_steps = 8
for idx, batch in tqdm(enumerate(train_dataloader)):
outputs = model(**batch.to(model.device, dtype=torch.bfloat16))
loss = outputs.loss / accumulation_steps
if idx % 100 == 0:
val_loss = 0.0
with torch.no_grad():
count = 0
for val_batch in tqdm(val_dataloader, desc="Validation"):
val_loss = (
val_loss
+ model(**val_batch.to(model.device, dtype=torch.bfloat16)).loss
)
count = count + 1
val_loss = val_loss / count
print(
f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
)
run_inference(val_dataset, processor, model, f"infer_{idx}.png")
loss.backward()
if idx % 8 == 0:
optimizer.step()
optimizer.zero_grad()
if __name__ == "__main__":
main()
@@ -0,0 +1,352 @@
"""
Train Gemma-3n on various vision-language datasets including intersection-dataset.
For Gemma-3n with intersection dataset:
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
sft_vlm_gemma3n.py \
--dataset_name ariG23498/intersection-dataset \
--model_name_or_path google/gemma-3n-E2B-it \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3n-E2B-it-trl-sft-intersection \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear \
--attn_implementation eager
Train Gemma-3n on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
sft_vlm_gemma3n.py \
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
--model_name_or_path google/gemma-3-4b-it \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear \
--attn_implementation eager
Train Gemma-3n on the FanqingM/MMIU-Benchmark dataset (multi-image).
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
sft_vlm_gemma3n.py \
--dataset_name FanqingM/MMIU-Benchmark \
--dataset_train_split test \
--model_name_or_path google/gemma-3-4b-it \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear
--attn_implementation eager
"""
import io
import os
import zipfile
import torch
from datasets import DatasetDict, load_dataset
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image
from transformers import (AutoModelForImageTextToText, AutoProcessor,
Gemma3nForConditionalGeneration)
from trl import (ModelConfig, ScriptArguments, SFTConfig, SFTTrainer,
TrlParser, get_kbit_device_map, get_quantization_config)
def my_get_peft_config(model_args: ModelConfig):
"""A version of get_peft_config that handles comma-separated target modules"""
if model_args.use_peft is False:
return None
# Import here to avoid issues if PEFT is not available
try:
from peft import LoraConfig
except ImportError:
raise ValueError(
"You need to have PEFT library installed in your environment, make sure to install `peft`. "
"Make sure to run `pip install -U peft`."
)
# Fix the target_modules to be a list if it's a comma-separated string
target_modules = model_args.lora_target_modules
if isinstance(target_modules, str) and target_modules != "all-linear":
# Convert comma-separated string to list
target_modules = [module.strip() for module in target_modules.split(",")]
peft_config = LoraConfig(
task_type=model_args.lora_task_type,
r=model_args.lora_r,
target_modules=target_modules,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
bias="none",
use_rslora=model_args.use_rslora,
use_dora=model_args.use_dora,
modules_to_save=model_args.lora_modules_to_save,
)
return peft_config
# For intersection dataset processing
def format_intersection_data(samples: dict) -> dict[str, list]:
"""Format intersection dataset to match expected message format"""
formatted_samples = {"messages": []}
for idx in range(len(samples["image"])):
image = samples["image"][idx].convert("RGB")
label = str(samples["label"][idx])
message = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an assistant with great geometry skills.",
}
],
},
{
"role": "user",
"content": [
{"type": "image", "image": image},
{
"type": "text",
"text": "How many intersection points are there in the image?",
},
],
},
{"role": "assistant", "content": [{"type": "text", "text": label}]},
]
formatted_samples["messages"].append(message)
return formatted_samples
# For multi-image example
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
image_inputs = []
for msg in messages:
content = msg.get("content", [])
if not isinstance(content, list):
content = [content]
for element in content:
if isinstance(element, dict) and (
"image" in element or element.get("type") == "image"
):
if "image" in element:
image = element["image"]
else:
image = element
if image is not None:
# Handle dictionary with bytes
if isinstance(image, dict) and "bytes" in image:
pil_image = Image.open(io.BytesIO(image["bytes"]))
image_inputs.append(pil_image.convert("RGB"))
# Handle PIL Image objects
elif hasattr(image, "convert"):
image_inputs.append(image.convert("RGB"))
return image_inputs
def format_data(samples: dict) -> dict[str, list]:
formatted_samples = {"messages": []}
for cont in range(len(samples["question"])):
images = []
for img_path in samples["input_image_path"][cont]:
try:
with open(img_path, "rb") as f:
img_bytes = f.read()
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
images.append({"type": "image", "image": image})
except Exception as e:
print(f"Error processing image {img_path}: {e}")
continue
formatted_samples["messages"].append(
[
{
"role": "system",
"content": [{"type": "text", "text": samples["context"][cont]}],
},
{
"role": "user",
"content": images
+ [{"type": "text", "text": samples["question"][cont]}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": samples["output"][cont]}],
},
]
)
return formatted_samples
# For multi-image example
def prepare_dataset(
dataset: DatasetDict, dataset_name: str, dataset_train_split: str
) -> DatasetDict:
all_files = list_repo_files(dataset_name, repo_type="dataset")
zip_files = [f for f in all_files if f.endswith(".zip")]
for zip_filename in zip_files:
zip_path = hf_hub_download(
repo_id=dataset_name, filename=zip_filename, repo_type="dataset"
)
extract_folder = zip_filename.replace(".zip", "")
os.makedirs(extract_folder, exist_ok=True)
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_folder)
dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
return dataset
def main():
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
################
# Model, Tokenizer & Processor
################
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
processor.tokenizer.padding_side = "right"
# Use appropriate model class based on model name
if "gemma-3n" in model_args.model_name_or_path.lower():
model = Gemma3nForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
else:
model = AutoModelForImageTextToText.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
def collate_fn(examples):
texts = []
images_list = []
for example in examples:
# Apply chat template to get text
text = processor.apply_chat_template(
example["messages"], tokenize=False, add_generation_prompt=False
).strip()
texts.append(text)
# Extract images
if "images" in example: # single-image case
images = [img.convert("RGB") for img in example["images"]]
else: # multi-image case or intersection dataset
images = process_vision_info(example["messages"])
images_list.append(images)
# Tokenize the texts and process the images
batch = processor(
text=texts, images=images_list, return_tensors="pt", padding=True
)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
# Mask tokens for Gemma3n model
if "gemma-3n" in model_args.model_name_or_path.lower():
# Use Gemma3n specific token masking
labels[labels == processor.tokenizer.pad_token_id] = -100
if hasattr(processor.tokenizer, "image_token_id"):
labels[labels == processor.tokenizer.image_token_id] = -100
if hasattr(processor.tokenizer, "boi_token_id"):
labels[labels == processor.tokenizer.boi_token_id] = -100
if hasattr(processor.tokenizer, "eoi_token_id"):
labels[labels == processor.tokenizer.eoi_token_id] = -100
else:
# Original masking for other models
image_token_id = [
processor.tokenizer.convert_tokens_to_ids(
processor.tokenizer.special_tokens_map["boi_token"]
)
]
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == image_token_id] = -100
labels[labels == 262144] = -100
batch["labels"] = labels
return batch
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
# Handle different dataset formats
if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
dataset = prepare_dataset(
dataset, script_args.dataset_name, script_args.dataset_train_split
)
elif script_args.dataset_name == "ariG23498/intersection-dataset":
# Format intersection dataset
dataset = dataset.map(
format_intersection_data, batched=True, batch_size=4, num_proc=4
)
################
# Training
################
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split]
if training_args.eval_strategy != "no"
else None,
processing_class=processor.tokenizer,
peft_config=my_get_peft_config(model_args),
)
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
if trainer.accelerator.is_main_process:
processor.push_to_hub(training_args.hub_model_id)
if __name__ == "__main__":
main()
@@ -0,0 +1,186 @@
import os
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import random
from functools import partial
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor
def collate_fn(examples, processor):
messages = list()
for sample in examples:
image = sample["image"].convert("RGB")
label = str(sample["label"])
message = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an assistant with great geometry skills.",
}
],
},
{
"role": "user",
"content": [
{"type": "image", "image": image},
{
"type": "text",
"text": "How many intersection points are there in the image?",
},
],
},
{"role": "assistant", "content": [{"type": "text", "text": label}]},
]
messages.append(message)
batch = processor.apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
labels = batch["input_ids"].clone() # Clone input IDs for labels
# Mask the tokens that we do not want to include in the loss computation
# -100 is ignored during categorical cross entropy loss computation
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == processor.tokenizer.image_token_id] = -100
labels[labels == processor.tokenizer.boi_token_id] = -100
labels[labels == processor.tokenizer.eoi_token_id] = -100
batch["labels"] = labels
return batch
def freeze_layers(model):
for name, param in model.named_parameters():
if "attn" in name:
param.requires_grad = True
else:
param.requires_grad = False
return model
def run_inference(val_dataset, processor, model, fname):
# infer before training
val_sample = random.choice(val_dataset)
image = val_sample["image"].convert("RGB")
message = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an assistant with great geometry skills.",
}
],
},
{
"role": "user",
"content": [
{"type": "image", "image": image},
{
"type": "text",
"text": "How many intersection points are there in the image?",
},
],
},
]
inputs = processor.apply_chat_template(
message,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
with torch.no_grad():
generation = model.generate(**inputs, max_new_tokens=10, disable_compile=True)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
plt.imshow(image)
plt.axis("off")
plt.title(f"Pred: {decoded}")
plt.show()
plt.savefig(f"outputs_fine_tune/{fname}")
def main():
model_id = "google/gemma-3n-E2B-it"
processor = Gemma3nProcessor.from_pretrained(model_id)
# load the dataset
dataset_id = "ariG23498/intersection-dataset"
train_dataset = load_dataset(dataset_id, split="train")
val_dataset = load_dataset(dataset_id, split="validation")
# create data loader
partial_collate_fn = partial(collate_fn, processor=processor)
train_dataloader = DataLoader(
train_dataset,
batch_size=2,
shuffle=True,
num_workers=8,
drop_last=True,
collate_fn=partial_collate_fn,
pin_memory=True,
)
val_dataloader = DataLoader(
val_dataset,
batch_size=2,
shuffle=False,
num_workers=8,
drop_last=True,
collate_fn=partial_collate_fn,
)
# load the model and optimizer
model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to("cuda")
run_inference(val_dataset, processor, model, "pred_before.png")
model = freeze_layers(model)
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)
# Start Training
accumulation_steps = 8
for idx, batch in tqdm(enumerate(train_dataloader)):
outputs = model(**batch.to(model.device))
loss = outputs.loss / accumulation_steps
if idx % 50 == 0:
val_loss = 0.0
with torch.no_grad():
count = 0
for val_batch in val_dataloader:
val_loss = val_loss + model(**val_batch.to(model.device)).loss
count = count + 1
val_loss = val_loss / count
print(
f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
)
run_inference(val_dataset, processor, model, f"infer_{idx}.png")
loss.backward()
if idx % 8 == 0:
optimizer.step()
optimizer.zero_grad()
if __name__ == "__main__":
main()
@@ -0,0 +1,425 @@
# -*- coding: utf-8 -*-
"""Gemma3n Fine-tuning on All Modalities.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1iEZUJuvKJpGU8t50BqfkiCQmGkaR6gd4
# Fine-tune Gemma3n on FineVideo
In this notebook, we will see how to fine-tune Gemma3n an videos with audios inside.
Using all three modalities is very costly compute-wise, so keep in mind that this is an educational tutorial to fit the model in 40GB VRAM.
"""
!pip install -U -q timm transformers trl peft datasets
import io
import os
import zipfile
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from trl import (
SFTConfig,
SFTTrainer,
)
"""## Download videos and preprocessing
FineVideo is a quite large dataset, we don't need a ton of examples, so we stream the dataset, check the duration and download the videos shorter than 30 secs.
"""
from datasets import load_dataset
import json
import os
dataset = load_dataset("HuggingFaceFV/finevideo", split="train", streaming=True)
os.makedirs("videos", exist_ok=True)
os.makedirs("metadata", exist_ok=True)
for idx, sample in enumerate(dataset):
data = sample["json"]
duration = data.get("duration_seconds", 0)
if duration < 30:
video_filename = f"videos/sample_{idx}.mp4"
with open(video_filename, 'wb') as video_file:
video_file.write(sample['mp4'])
json_filename = f"metadata/sample_{idx}.json"
with open(json_filename, 'w') as json_file:
json.dump(sample['json'], json_file)
print(f"Number of items in content/videos: {len(os.listdir('videos'))}")
"""In FineVideo some frames are dark so we downsample 6 frames and if we can't get meaningful videos we remove them."""
import cv2
from PIL import Image
import numpy as np
def is_dark(frame, threshold=10):
return np.max(frame) < threshold # all pixels are very close to 0
def downsample_video(video_path):
vidcap = cv2.VideoCapture(video_path)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = vidcap.get(cv2.CAP_PROP_FPS)
frames = []
# Generate 8 evenly spaced indices, skip first and last
full_indices = np.linspace(0, total_frames - 1, 8, dtype=int)[1:-1]
for i in full_indices:
found_valid = False
for offset in [0, -1, 1, -2, 2]: # Try nearby frames if original is dark
candidate_idx = i + offset
if 0 <= candidate_idx < total_frames:
vidcap.set(cv2.CAP_PROP_POS_FRAMES, candidate_idx)
success, image = vidcap.read()
if success:
if not is_dark(image):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image)
timestamp = round(candidate_idx / fps, 2)
frames.append((pil_image, timestamp))
found_valid = True
break
if not found_valid:
print(f"Warning: Could not find non-dark frame near index {i}")
vidcap.release()
# If still fewer than 8, try to top off by scanning more frames
if len(frames) < 6:
print("Trying to top off with additional non-dark frames...")
idx = 0
while len(frames) < 8 and idx < total_frames:
vidcap.set(cv2.CAP_PROP_POS_FRAMES, idx)
success, image = vidcap.read()
if success and not is_dark(image):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image)
timestamp = round(idx / fps, 2)
# Avoid adding duplicate timestamps
if not any(ts == timestamp for _, ts in frames):
frames.append((pil_image, timestamp))
idx += 1
return frames[:8] # Ensure exactly 8 frames
import os
import glob
def remove_dark_videos(video_dir, metadata_dir, audio_dir):
"""
Remove videos (and their metadata/audio files) if all frames are dark.
"""
video_paths = glob.glob(os.path.join(video_dir, "*.mp4"))
for video_path in video_paths:
filename = os.path.basename(video_path)
base_name = os.path.splitext(filename)[0]
frames = downsample_video(video_path)
if len(frames) < 6:
try:
os.remove(video_path)
print(f"Deleted: {video_path}")
except Exception as e:
print(f"Failed to delete {video_path}: {e}")
metadata_path = os.path.join(metadata_dir, f"{base_name}.json")
if os.path.exists(metadata_path):
os.remove(metadata_path)
# Remove audio
audio_path = os.path.join(audio_dir, f"{base_name}.wav")
if os.path.exists(audio_path):
os.remove(audio_path)
remove_dark_videos(
video_dir="videos",
metadata_dir="metadata",
audio_dir="audios"
)
"""Gemma-3n accepts video (image frames) and audio separately, so we strip audio from video."""
import os
import subprocess
video_dir = "videos"
audio_dir = "audios"
os.makedirs(audio_dir, exist_ok=True)
for filename in os.listdir(video_dir):
if not filename.endswith(".mp4"):
continue
idx = filename.split("_")[1].split(".")[0]
video_path = os.path.join(video_dir, filename)
audio_path = os.path.join(audio_dir, f"sample_{idx}.wav")
subprocess.run([
"ffmpeg", "-i", video_path,
"-q:a", "0", "-map", "a",
audio_path,
"-y"
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
"""Construct a new dataset with audio, video, metadata (video categories). This dataset is very cool, it has some questions and answers, captions and more so get creative if you have the GPU VRAM to do so. Here we solve an easier task for educational purposes."""
from datasets import Dataset
import json
def gen():
meta_dir = "metadata"
for filename in os.listdir(meta_dir):
if not filename.endswith(".json"):
continue
idx = filename.split("_")[1].split(".")[0]
if os.path.exists(f"videos/sample_{idx}.mp4"):
video_filename = f"sample_{idx}.mp4"
audio_filename = f"sample_{idx}.wav"
json_path = os.path.join(meta_dir, filename)
with open(json_path, "r") as f:
metadata = json.load(f)
yield {
"video": video_filename,
"audio": audio_filename,
"content_parent_category": metadata["content_parent_category"],
"sample_index": int(idx)
}
else:
pass
dataset = Dataset.from_generator(gen)
"""We will speed-up and downsample the audios to save space during training."""
import torchaudio
from torchaudio.transforms import Resample
import os
import torch
def preprocess_audio(audio_path, target_sample_rate=16000, max_duration_sec=5, speedup_factor=1.25):
waveform, sample_rate = torchaudio.load(audio_path)
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
if sample_rate != target_sample_rate:
resampler = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
waveform = resampler(waveform)
sample_rate = target_sample_rate
if speedup_factor > 1.0:
indices = torch.arange(0, waveform.shape[1], step=speedup_factor).long()
if indices[-1] >= waveform.shape[1]:
indices = indices[:-1]
waveform = waveform[:, indices]
max_length = int(target_sample_rate * max_duration_sec)
if waveform.shape[1] > max_length:
waveform = waveform[:, :max_length]
torchaudio.save(audio_path, waveform, sample_rate)
for file_name in os.listdir("audios"):
if file_name.lower().endswith(".wav"):
audio_path = os.path.join("audios", file_name)
preprocess_audio(audio_path)
dataset = dataset.train_test_split(test_size=0.10, seed=42)
"""### Load the model
Make sure you have your Hugging Face token in your Colab secrets.
"""
model = Gemma3nForConditionalGeneration.from_pretrained(
"google/gemma-3n-E2B-it", torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(
"google/gemma-3n-E2B-it",
)
processor.tokenizer.padding_side = "right"
processor.tokenizer.all_special_ids
"""Write our dataset collator. We will train model to predict category of a video (which can be done easily). You can do much better things, for instance FineVideo has QnA section, you can train this model to do open-ended QnA if you have a big VRAM and a lot of patience. Open-ended tasks are harder to work with, and this notebook carries educational purposes on feeding different modalities.
In collator we also downsample videos to 6 frames, we have written the helper above. For better results you need more frames.
"""
def collate_fn(examples):
video_path = examples[0]["video"]
audio_path = examples[0]["audio"]
sample_idx = filename.split("_")[1].split(".")[0]
frames = downsample_video(f"videos/{video_path}")
text = "Based on the video, predict the category of it."
message = [
{
"role": "user",
"content": [
{"type": "text", "text": text}
],
},
]
# this is how video inference should be formatted in Gemma3n
for frame in frames:
image, timestamp = frame
message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
timestamp = str(timestamp).replace(".", "_")
image.save(f"image_idx_{sample_idx}_{timestamp}.png")
message[0]["content"].append({"type": "image", "url": f"image_idx_{sample_idx}_{timestamp}.png"})
message[0]["content"].append({"type": "audio", "audio": f"audios/{audio_path}"})
message.append({"role": "assistant", "content": [{"type": "text", "text": examples[0]["content_parent_category"]}]})
inputs = processor.apply_chat_template(
message,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
).to(model.device)
labels = inputs["input_ids"].clone()
special_token_ids = processor.tokenizer.all_special_ids
special_token_ids_tensor = torch.tensor(special_token_ids, device=labels.device)
mask = torch.isin(labels, special_token_ids_tensor)
labels[mask] = -100
inputs["labels"] = labels
if torch.all(inputs["pixel_values"] == 0):
print("Frames are dark")
return inputs
"""## Training
We do LoRA fine-tuning again to save up on space.
"""
from peft import LoraConfig
peft_config = LoraConfig(
task_type="CAUSAL_LM",
r=16,
target_modules="all-linear",
lora_alpha=32,
lora_dropout=0.05,
bias="none",
use_rslora=False,
use_dora=False,
modules_to_save=None
)
model.gradient_checkpointing_disable()
model.config.use_cache = False
training_args = SFTConfig(
output_dir="/content/gemma-3n-finevideo",
eval_strategy='epoch',
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
gradient_accumulation_steps=4,
gradient_checkpointing=False,
learning_rate=1e-05,
num_train_epochs=3.0,
logging_steps=10,
save_steps=100,
bf16=True,
report_to=["tensorboard"],
dataset_kwargs={'skip_prepare_dataset': True},
remove_unused_columns=False,
max_seq_length=None,
push_to_hub=True,
dataloader_pin_memory=False,
)
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn,
train_dataset=dataset["train"],
eval_dataset=dataset["test"] if training_args.eval_strategy != "no" else None,
processing_class=processor.tokenizer,
peft_config=peft_config,
)
trainer.train()
"""Test the model with a video of snowboarding."""
!wget https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4
model = trainer.model # trainer has the adapter
"""Strip audio and downsample video."""
audio_path = "/content/test_audio.wav"
subprocess.run([
"ffmpeg", "-i", "/content/IMG_8137.mp4",
"-q:a", "0", "-map", "a",
f"{audio_path}",
"-y"
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
frames = downsample_video("/content/IMG_8137.mp4")
# repeat the chat template
text = "Based on the video, predict the category of it."
message = [
{
"role": "user",
"content": [
{"type": "text", "text": text}
],
},
]
for frame in frames:
image, timestamp = frame
message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
timestamp = str(timestamp).replace(".", "_")
image.save(f"test_frame_{timestamp}.png")
message[0]["content"].append({"type": "image", "url": f"test_frame_{timestamp}.png"})
message[0]["content"].append({"type": "audio", "audio": f"{audio_path}"})
message
inputs = processor.apply_chat_template(
message,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
).to(model.device).to(model.dtype)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
"""Thanks a lot for reading! Keep training the model further with more data or unfreeze the layers for better performance 💗"""