docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,389 @@
|
||||
# Gemma4_(E2B)-Multimodal.ipynb — extracted cells
|
||||
# Source: https://github.com/huggingface/huggingface-gemma-recipes/blob/main/notebooks/Gemma4_(E2B)-Multimodal.ipynb
|
||||
|
||||
# ===== CELL 0 (markdown) =====
|
||||
# This notebook has vibe test examples to test image, text, audio capabilities of Gemma-4 model. To get started, let's install latest stable release of transformers.
|
||||
|
||||
# ===== CELL 1 (code) =====
|
||||
!pip install -U transformers
|
||||
|
||||
# ===== CELL 2 (markdown) =====
|
||||
# We can load model into `AutoModelForMultimodalLM` to make use of all capabilities.
|
||||
|
||||
# ===== CELL 3 (code) =====
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoModelForMultimodalLM, AutoProcessor
|
||||
#model_list = ["google/gemma-4-26B-A4B-it", "google/gemma-4-E4B-it",
|
||||
# "google/gemma-4-E2B-it", "google/gemma-4-31B-it"]
|
||||
model_id = "google/gemma-4-E2B-it"
|
||||
model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
# ===== CELL 4 (markdown) =====
|
||||
# ## Code completion
|
||||
|
||||
# ===== CELL 5 (markdown) =====
|
||||
# We give Gemma-4 a website screenshot to reproduce the code.
|
||||
|
||||
# ===== CELL 6 (code) =====
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png",
|
||||
},
|
||||
{"type": "text", "text": "Write HTML code for this page."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=4000)
|
||||
|
||||
# ===== CELL 7 (code) =====
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
|
||||
print(result["content"])
|
||||
|
||||
# ===== CELL 8 (markdown) =====
|
||||
# ## Video Inference
|
||||
|
||||
# ===== CELL 9 (markdown) =====
|
||||
# We test Gemma-4 on video understanding. If you want to run this example with larger models which don't take audio input, disable `load_audio_from_video`.
|
||||
|
||||
# ===== CELL 10 (code) =====
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"},
|
||||
{"type": "text", "text": "What is happening in the video? What is the song about?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
load_audio_from_video=True,
|
||||
).to(model.device)
|
||||
output = model.generate(**inputs, max_new_tokens=200)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
|
||||
# ===== CELL 11 (code) =====
|
||||
print(result["content"])
|
||||
|
||||
# ===== CELL 12 (markdown) =====
|
||||
# ## Multimodal Function Calling
|
||||
|
||||
# ===== CELL 13 (code) =====
|
||||
import re
|
||||
|
||||
WEATHER_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Gets the current weather for a specific location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
tools = [WEATHER_TOOL]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": [
|
||||
{"type": "image", "image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg"},
|
||||
{"type": "text", "text": "What is the city in this image? Check the weather there right now."},
|
||||
]},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tools=[WEATHER_TOOL],
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
).to(model.device)
|
||||
|
||||
# ===== CELL 14 (code) =====
|
||||
output = model.generate(**inputs, max_new_tokens=1000)
|
||||
|
||||
# ===== CELL 15 (code) =====
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
|
||||
# ===== CELL 16 (code) =====
|
||||
print(result["content"])
|
||||
|
||||
# ===== CELL 17 (markdown) =====
|
||||
# # Any-to-any inference
|
||||
|
||||
# ===== CELL 18 (markdown) =====
|
||||
# We can also run the model with `any-to-any` pipeline.
|
||||
|
||||
# ===== CELL 19 (code) =====
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline("any-to-any", model="google/gemma-4-e2b-it")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is happening in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# ===== CELL 20 (code) =====
|
||||
pipe(messages)#, load_audio_from_video=True)
|
||||
|
||||
# ===== CELL 21 (code) =====
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"image": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is happening in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
|
||||
# ===== CELL 22 (markdown) =====
|
||||
# # Object detection and pointing
|
||||
|
||||
# ===== CELL 23 (code) =====
|
||||
import re
|
||||
import torch
|
||||
from transformers.image_utils import load_image
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
import json
|
||||
|
||||
# ===== CELL 24 (code) =====
|
||||
image_url = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png"
|
||||
image = load_image(image_url)
|
||||
|
||||
# ===== CELL 25 (code) =====
|
||||
def resize_to_48_multiple(image):
|
||||
w, h = image.size
|
||||
new_w = (w // 48) * 48
|
||||
new_h = (h // 48) * 48
|
||||
return image.crop((0, 0, new_w, new_h))
|
||||
|
||||
# ===== CELL 26 (code) =====
|
||||
def inputs_for_object_detection(image, what_object):
|
||||
messages = [
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "image": image},
|
||||
{"type": "text", "text": f"What's the bounding box for the {what_object} in the image?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
return inputs.to(model.device)
|
||||
|
||||
# ===== CELL 27 (code) =====
|
||||
def extract_json(text: str):
|
||||
text = text.strip()
|
||||
|
||||
text = re.sub(r"^```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"\s*```$", "", text)
|
||||
|
||||
# Try direct parse first
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Fallback: extract first JSON object or array
|
||||
match = re.search(r'(\{.*\}|\[.*\])', text, re.DOTALL)
|
||||
if match:
|
||||
candidate = match.group(1)
|
||||
return json.loads(candidate)
|
||||
|
||||
raise ValueError("No valid JSON found")
|
||||
|
||||
# ===== CELL 28 (code) =====
|
||||
def detect_object(image_url, what_object):
|
||||
image = load_image(image_url)
|
||||
image = resize_to_48_multiple(image)
|
||||
inputs = inputs_for_object_detection(image, what_object)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
generated_outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
|
||||
generated = processor.decode(generated_outputs[0, input_len:])
|
||||
parsed_json = extract_json(generated)[0]
|
||||
return parsed_json
|
||||
|
||||
# ===== CELL 29 (code) =====
|
||||
def draw_pascal_voc_boxes(i, image, box, label, resize_shape=(1000,1000)):
|
||||
dpi = 72
|
||||
width, height = image.size
|
||||
fig, ax = plt.subplots(1, figsize=[width/dpi, height/dpi], tight_layout={'pad':0})
|
||||
|
||||
ax.imshow(image)
|
||||
|
||||
ymin, xmin, ymax, xmax = box
|
||||
re_h, re_w = resize_shape if resize_shape is not None else (height, width)
|
||||
xmin = (xmin / re_w) * width
|
||||
ymin = (ymin/ re_h) * height
|
||||
xmax = (xmax / re_w) * width
|
||||
ymax = (ymax/ re_h) * height
|
||||
|
||||
w = xmax - xmin
|
||||
h = ymax - ymin
|
||||
|
||||
rect = patches.Rectangle(
|
||||
(xmin, ymin),
|
||||
w,
|
||||
h,
|
||||
linewidth=10,
|
||||
edgecolor="green",
|
||||
facecolor="none"
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
|
||||
if label is not None:
|
||||
ax.text(xmin, ymin-25, label, fontsize=24, bbox=dict(facecolor="yellow", alpha=0.5))
|
||||
|
||||
plt.axis("off")
|
||||
plt.savefig(f"boxes_{i}.png")
|
||||
plt.close(fig)
|
||||
display(fig)
|
||||
|
||||
# ===== CELL 30 (code) =====
|
||||
def display_detected_object(image_url, what_object):
|
||||
image = load_image(image_url)
|
||||
image = resize_to_48_multiple(image)
|
||||
detection = detect_object(image_url, what_object)
|
||||
box = detection["box_2d"]
|
||||
label = detection.get("label", f"{what_object}")
|
||||
draw_pascal_voc_boxes("1000", image, box, label)
|
||||
|
||||
# ===== CELL 31 (code) =====
|
||||
display_detected_object("https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png", "bike")
|
||||
|
||||
# ===== CELL 32 (markdown) =====
|
||||
# ## Captioning
|
||||
|
||||
# ===== CELL 33 (code) =====
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png"},
|
||||
{"type": "text", "text": "Write single detailed caption for this image."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=512)
|
||||
input_len = inputs.input_ids.shape[-1]
|
||||
generated_text_ids = output[0][input_len:]
|
||||
generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)
|
||||
result = processor.parse_response(generated_text)
|
||||
print(result["content"])
|
||||
|
||||
# ===== CELL 34 (markdown) =====
|
||||
# ## Audio Understanding
|
||||
|
||||
# ===== CELL 35 (code) =====
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"},
|
||||
{"type": "text", "text": "Can you describe this audio in detail?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1000,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
@@ -0,0 +1,595 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"This notebook has vibe test examples to test image, text, audio capabilities of Gemma-4 model. To get started, let's install latest stable release of transformers."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install -U transformers"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can load model into `AutoModelForMultimodalLM` to make use of all capabilities."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"from transformers import AutoModelForMultimodalLM, AutoProcessor\n",
|
||||
"#model_list = [\"google/gemma-4-26B-A4B-it\", \"google/gemma-4-E4B-it\",\n",
|
||||
"# \"google/gemma-4-E2B-it\", \"google/gemma-4-31B-it\"]\n",
|
||||
"model_id = \"google/gemma-4-E2B-it\"\n",
|
||||
"model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map=\"auto\")\n",
|
||||
"processor = AutoProcessor.from_pretrained(model_id)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Code completion"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We give Gemma-4 a website screenshot to reproduce the code."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"image\",\n",
|
||||
" \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/landing_page.png\",\n",
|
||||
" },\n",
|
||||
" {\"type\": \"text\", \"text\": \"Write HTML code for this page.\"},\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" enable_thinking=True,\n",
|
||||
").to(model.device)\n",
|
||||
"\n",
|
||||
"output = model.generate(**inputs, max_new_tokens=4000)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"input_len = inputs.input_ids.shape[-1]\n",
|
||||
"generated_text_ids = output[0][input_len:]\n",
|
||||
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
|
||||
"result = processor.parse_response(generated_text)\n",
|
||||
"\n",
|
||||
"print(result[\"content\"])"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Video Inference"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We test Gemma-4 on video understanding. If you want to run this example with larger models which don't take audio input, disable `load_audio_from_video`."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"video\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4\"},\n",
|
||||
" {\"type\": \"text\", \"text\": \"What is happening in the video? What is the song about?\"},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
"]\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" load_audio_from_video=True,\n",
|
||||
").to(model.device)\n",
|
||||
"output = model.generate(**inputs, max_new_tokens=200)\n",
|
||||
"input_len = inputs.input_ids.shape[-1]\n",
|
||||
"generated_text_ids = output[0][input_len:]\n",
|
||||
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
|
||||
"result = processor.parse_response(generated_text)\n"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(result[\"content\"])"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Multimodal Function Calling"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import re\n",
|
||||
"\n",
|
||||
"WEATHER_TOOL = {\n",
|
||||
" \"type\": \"function\",\n",
|
||||
" \"function\": {\n",
|
||||
" \"name\": \"get_weather\",\n",
|
||||
" \"description\": \"Gets the current weather for a specific location.\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"city\": {\"type\": \"string\", \"description\": \"The city name\"},\n",
|
||||
" },\n",
|
||||
" \"required\": [\"city\"],\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"tools = [WEATHER_TOOL]\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" {\"role\": \"user\", \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/thailand.jpg\"},\n",
|
||||
" {\"type\": \"text\", \"text\": \"What is the city in this image? Check the weather there right now.\"},\n",
|
||||
" ]},\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tools=[WEATHER_TOOL],\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" enable_thinking=True,\n",
|
||||
").to(model.device)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"output = model.generate(**inputs, max_new_tokens=1000)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"input_len = inputs.input_ids.shape[-1]\n",
|
||||
"generated_text_ids = output[0][input_len:]\n",
|
||||
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
|
||||
"result = processor.parse_response(generated_text)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"print(result[\"content\"])"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Any-to-any inference"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"We can also run the model with `any-to-any` pipeline."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\"any-to-any\", model=\"google/gemma-4-e2b-it\")\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"video\",\n",
|
||||
" \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n",
|
||||
" },\n",
|
||||
" {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
"]\n"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"pipe(messages)#, load_audio_from_video=True)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\n",
|
||||
" \"type\": \"video\",\n",
|
||||
" \"image\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/rockets.mp4\",\n",
|
||||
" },\n",
|
||||
" {\"type\": \"text\", \"text\": \"What is happening in this video?\"},\n",
|
||||
" ],\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\"\n",
|
||||
")\n",
|
||||
"inputs = inputs.to(model.device)\n",
|
||||
"\n",
|
||||
"generated_ids = model.generate(**inputs, max_new_tokens=128)\n",
|
||||
"generated_ids_trimmed = [\n",
|
||||
" out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
|
||||
"]\n",
|
||||
"output_text = processor.batch_decode(\n",
|
||||
" generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
|
||||
")\n",
|
||||
"print(output_text)\n"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Object detection and pointing"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import re\n",
|
||||
"import torch\n",
|
||||
"from transformers.image_utils import load_image\n",
|
||||
"from PIL import Image\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.patches as patches\n",
|
||||
"import json"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"image_url = \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\"\n",
|
||||
"image = load_image(image_url)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def resize_to_48_multiple(image):\n",
|
||||
" w, h = image.size\n",
|
||||
" new_w = (w // 48) * 48\n",
|
||||
" new_h = (h // 48) * 48\n",
|
||||
" return image.crop((0, 0, new_w, new_h))"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def inputs_for_object_detection(image, what_object):\n",
|
||||
" messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\", \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": image},\n",
|
||||
" {\"type\": \"text\", \"text\": f\"What's the bounding box for the {what_object} in the image?\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" enable_thinking=False,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return inputs.to(model.device)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def extract_json(text: str):\n",
|
||||
" text = text.strip()\n",
|
||||
"\n",
|
||||
" text = re.sub(r\"^```(?:json)?\\s*\", \"\", text)\n",
|
||||
" text = re.sub(r\"\\s*```$\", \"\", text)\n",
|
||||
"\n",
|
||||
" # Try direct parse first\n",
|
||||
" try:\n",
|
||||
" return json.loads(text)\n",
|
||||
" except json.JSONDecodeError:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" # Fallback: extract first JSON object or array\n",
|
||||
" match = re.search(r'(\\{.*\\}|\\[.*\\])', text, re.DOTALL)\n",
|
||||
" if match:\n",
|
||||
" candidate = match.group(1)\n",
|
||||
" return json.loads(candidate)\n",
|
||||
"\n",
|
||||
" raise ValueError(\"No valid JSON found\")"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def detect_object(image_url, what_object):\n",
|
||||
" image = load_image(image_url)\n",
|
||||
" image = resize_to_48_multiple(image)\n",
|
||||
" inputs = inputs_for_object_detection(image, what_object)\n",
|
||||
" input_len = inputs[\"input_ids\"].shape[-1]\n",
|
||||
" generated_outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=False)\n",
|
||||
" generated = processor.decode(generated_outputs[0, input_len:])\n",
|
||||
" parsed_json = extract_json(generated)[0]\n",
|
||||
" return parsed_json"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def draw_pascal_voc_boxes(i, image, box, label, resize_shape=(1000,1000)):\n",
|
||||
" dpi = 72\n",
|
||||
" width, height = image.size\n",
|
||||
" fig, ax = plt.subplots(1, figsize=[width/dpi, height/dpi], tight_layout={'pad':0})\n",
|
||||
"\n",
|
||||
" ax.imshow(image)\n",
|
||||
"\n",
|
||||
" ymin, xmin, ymax, xmax = box\n",
|
||||
" re_h, re_w = resize_shape if resize_shape is not None else (height, width)\n",
|
||||
" xmin = (xmin / re_w) * width\n",
|
||||
" ymin = (ymin/ re_h) * height\n",
|
||||
" xmax = (xmax / re_w) * width\n",
|
||||
" ymax = (ymax/ re_h) * height\n",
|
||||
"\n",
|
||||
" w = xmax - xmin\n",
|
||||
" h = ymax - ymin\n",
|
||||
"\n",
|
||||
" rect = patches.Rectangle(\n",
|
||||
" (xmin, ymin),\n",
|
||||
" w,\n",
|
||||
" h,\n",
|
||||
" linewidth=10,\n",
|
||||
" edgecolor=\"green\",\n",
|
||||
" facecolor=\"none\"\n",
|
||||
" )\n",
|
||||
" ax.add_patch(rect)\n",
|
||||
"\n",
|
||||
" if label is not None:\n",
|
||||
" ax.text(xmin, ymin-25, label, fontsize=24, bbox=dict(facecolor=\"yellow\", alpha=0.5))\n",
|
||||
"\n",
|
||||
" plt.axis(\"off\")\n",
|
||||
" plt.savefig(f\"boxes_{i}.png\")\n",
|
||||
" plt.close(fig)\n",
|
||||
" display(fig)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def display_detected_object(image_url, what_object):\n",
|
||||
" image = load_image(image_url)\n",
|
||||
" image = resize_to_48_multiple(image)\n",
|
||||
" detection = detect_object(image_url, what_object)\n",
|
||||
" box = detection[\"box_2d\"]\n",
|
||||
" label = detection.get(\"label\", f\"{what_object}\")\n",
|
||||
" draw_pascal_voc_boxes(\"1000\", image, box, label)"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"display_detected_object(\"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bike.png\", \"bike\")"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"##\u00a0Captioning"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"image\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/bird.png\"},\n",
|
||||
" {\"type\": \"text\", \"text\": \"Write single detailed caption for this image.\"},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
").to(model.device)\n",
|
||||
"\n",
|
||||
"output = model.generate(**inputs, max_new_tokens=512)\n",
|
||||
"input_len = inputs.input_ids.shape[-1]\n",
|
||||
"generated_text_ids = output[0][input_len:]\n",
|
||||
"generated_text = processor.decode(generated_text_ids, skip_special_tokens=True)\n",
|
||||
"result = processor.parse_response(generated_text)\n",
|
||||
"print(result[\"content\"])"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Audio Understanding"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"audio\", \"url\": \"https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3\"},\n",
|
||||
" {\"type\": \"text\", \"text\": \"Can you describe this audio in detail?\"},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
").to(model.device)\n",
|
||||
"\n",
|
||||
"output = model.generate(\n",
|
||||
" **inputs,\n",
|
||||
" max_new_tokens=1000,\n",
|
||||
" do_sample=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(processor.decode(output[0], skip_special_tokens=True))\n"
|
||||
],
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl",
|
||||
# "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
|
||||
"""
|
||||
GRPO training with OpenEnv's CARLA environment for VLMs (Vision Language Models).
|
||||
|
||||
This script uses `environment_factory` with multimodal tool responses: each tool action
|
||||
returns a camera image from the vehicle alongside the text scene description, allowing the
|
||||
VLM to see the driving scene visually after each action.
|
||||
|
||||
The CARLA environment simulates an emergency driving scenario where pedestrians are ahead
|
||||
and the model must learn to observe the scene and take the correct action (e.g., swerve
|
||||
to an empty lane) to minimize casualties.
|
||||
|
||||
Setup:
|
||||
```sh
|
||||
pip install "openenv-carla-env @ git+https://huggingface.co/spaces/sergiopaniego/carla_env"
|
||||
```
|
||||
|
||||
Usage (requires at least 2 CARLA Spaces, each supports only 1 concurrent connection):
|
||||
```sh
|
||||
python examples/scripts/openenv/carla_vlm.py \
|
||||
--env-urls https://server1.hf.space https://server2.hf.space
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from carla_env import CarlaAction, CarlaEnv
|
||||
from datasets import Dataset
|
||||
from PIL import Image
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Run GRPO VLM training with CARLA environment.")
|
||||
parser.add_argument("--model", type=str, default="google/gemma-4-E2B-it")
|
||||
parser.add_argument(
|
||||
"--env-urls",
|
||||
type=str,
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="URLs for CARLA environment servers. At least 2 required (1 Space = 1 connection).",
|
||||
)
|
||||
parser.add_argument("--dataset-size", type=int, default=1000)
|
||||
parser.add_argument("--max-completion-length", type=int, default=3072)
|
||||
parser.add_argument("--per-device-train-batch-size", type=int, default=None, help="Defaults to len(env-urls).")
|
||||
parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
|
||||
parser.add_argument("--max-steps", type=int, default=100)
|
||||
parser.add_argument("--image-size", type=int, default=256, help="Resize camera images to this size. 0 to disable.")
|
||||
parser.add_argument("--trackio-space-id", type=str, default=None, help="Trackio Space ID for logging.")
|
||||
parser.add_argument("--use-lora", action="store_true", help="Use LoRA for memory-efficient training.")
|
||||
parser.add_argument("--lora-r", type=int, default=128, help="LoRA rank.")
|
||||
parser.add_argument("--lora-alpha", type=int, default=256, help="LoRA alpha.")
|
||||
parser.add_argument(
|
||||
"--lora-target-modules",
|
||||
type=str,
|
||||
default="llm-only",
|
||||
help="LoRA target modules. Use 'llm-only' to skip vision encoder, 'all-linear' for all.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate", type=float, default=5e-6, help="Learning rate. Default 5e-6 (good for LoRA r=128)."
|
||||
)
|
||||
parser.add_argument("--hub-model-id", type=str, default=None)
|
||||
parser.add_argument("--hub-private-repo", action="store_true", help="Make the Hub repo private.")
|
||||
parser.add_argument("--run-name", type=str, default=None)
|
||||
parser.add_argument("--report-to", type=str, default="trackio", help="Logging backend: wandb, trackio, none.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
SIM_TICKS = 10
|
||||
|
||||
|
||||
def reward_func(completions, environments, **kwargs):
|
||||
rewards = []
|
||||
for i, (comp, env) in enumerate(zip(completions, environments, strict=False)):
|
||||
# Advance to episode end to capture rubric_reward
|
||||
try:
|
||||
final = env._advance_until_done()
|
||||
if final and final.observation.rubric_reward:
|
||||
env.reward = final.observation.rubric_reward
|
||||
except Exception as e:
|
||||
print(f"[WARN] _advance_until_done failed for gen={i}: {e}")
|
||||
r = env.reward
|
||||
tools = [
|
||||
msg["tool_calls"][0]["function"]["name"] for msg in comp if isinstance(msg, dict) and msg.get("tool_calls")
|
||||
]
|
||||
rewards.append(r)
|
||||
print(f"[DEBUG reward] gen={i} tools={tools} env_reward={env.reward} total={r}")
|
||||
return rewards
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
env_url_iter = iter(args.env_urls)
|
||||
image_size = args.image_size
|
||||
|
||||
prompt = """\
|
||||
You control an autonomous vehicle in an emergency. There are pedestrians ahead and you must \
|
||||
decide what to do immediately.
|
||||
|
||||
You will see a camera image from the vehicle after each action. Use the visual information
|
||||
along with the scene description to decide your next action.
|
||||
|
||||
You have the following tools available:
|
||||
- `observe`: Advance time and get a new observation of the scene with a camera image.
|
||||
- `emergency_stop`: Apply maximum braking to stop the vehicle.
|
||||
- `lane_change(direction)`: Change lane to the left or right. Direction must be "left" or "right".
|
||||
|
||||
Make one tool call at a time, wait for the result, then decide your next action.
|
||||
Observe the scene first, then decide the best course of action to minimize harm.
|
||||
Consider all available actions - sometimes avoiding the obstacle by changing lanes \
|
||||
is safer than stopping in its path."""
|
||||
|
||||
dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]})
|
||||
|
||||
class CarlaVLMEnv:
|
||||
def __init__(self):
|
||||
self.url = next(env_url_iter)
|
||||
self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)
|
||||
self.reward = 0.0
|
||||
|
||||
@staticmethod
|
||||
def _describe(obs) -> str:
|
||||
parts = []
|
||||
parts.append(f"Speed: {obs.speed_kmh:.1f} km/h.")
|
||||
if obs.nearby_actors:
|
||||
for actor in obs.nearby_actors:
|
||||
parts.append(f"- {actor.get('type', 'actor')} at {actor.get('distance', '?')}m")
|
||||
else:
|
||||
parts.append("No nearby actors detected.")
|
||||
if obs.collision_detected:
|
||||
parts.append(f"COLLISION detected with {obs.collided_with or 'unknown'}!")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _decode_image(camera_image_b64, target_size):
|
||||
"""Decode base64 JPEG image and optionally resize."""
|
||||
img_bytes = base64.b64decode(camera_image_b64)
|
||||
img = Image.open(BytesIO(img_bytes))
|
||||
if target_size > 0:
|
||||
img.thumbnail((target_size, target_size), Image.LANCZOS)
|
||||
return img
|
||||
|
||||
def _format_multimodal(self, obs) -> list:
|
||||
"""Format observation as multimodal content blocks (camera image + text)."""
|
||||
content = []
|
||||
if obs.camera_image is not None:
|
||||
img = self._decode_image(obs.camera_image, image_size)
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": self._describe(obs)})
|
||||
return content
|
||||
|
||||
def _advance(self, ticks: int = SIM_TICKS):
|
||||
result = None
|
||||
for _ in range(ticks):
|
||||
result = self.client.step(CarlaAction(action_type="observe"))
|
||||
if result.done:
|
||||
break
|
||||
return result
|
||||
|
||||
def _advance_until_done(self, max_ticks: int = 50):
|
||||
"""Advance the simulation until the episode ends."""
|
||||
result = None
|
||||
for _ in range(max_ticks):
|
||||
result = self.client.step(CarlaAction(action_type="observe"))
|
||||
if result.done:
|
||||
break
|
||||
return result
|
||||
|
||||
def _advance_and_capture(self, ticks: int = SIM_TICKS):
|
||||
"""Advance the simulation, then capture an image of the current state."""
|
||||
result = self._advance(ticks)
|
||||
capture_result = self.client.step(CarlaAction(action_type="capture_image"))
|
||||
result.observation.camera_image = capture_result.observation.camera_image
|
||||
return result
|
||||
|
||||
def reset(self, **kwargs) -> str | None:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
result = self.client.reset(scenario_name="trolley_micro_escape_exists")
|
||||
self.reward = 0.0
|
||||
return self._describe(result.observation)
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
raise
|
||||
print(f"[WARN] reset failed (attempt {attempt + 1}/3): {e}. Reconnecting...")
|
||||
self.client = CarlaEnv(base_url=self.url, connect_timeout_s=30, message_timeout_s=120)
|
||||
|
||||
def observe(self) -> list:
|
||||
"""
|
||||
Get the current scene with a camera image and description.
|
||||
|
||||
Returns:
|
||||
The camera image and scene description with vehicle state and nearby actors.
|
||||
"""
|
||||
result = self._advance_and_capture()
|
||||
self.reward = result.observation.rubric_reward or 0.0
|
||||
return self._format_multimodal(result.observation)
|
||||
|
||||
def emergency_stop(self) -> list:
|
||||
"""
|
||||
Apply maximum braking to stop the vehicle.
|
||||
|
||||
Returns:
|
||||
The camera image and scene description after braking.
|
||||
"""
|
||||
self.client.step(CarlaAction(action_type="emergency_stop"))
|
||||
result = self._advance_and_capture()
|
||||
self.reward = result.observation.rubric_reward or 0.0
|
||||
print(f"[DEBUG env] emergency_stop: done={result.done}, reward={self.reward}")
|
||||
return self._format_multimodal(result.observation)
|
||||
|
||||
def lane_change(self, direction: str) -> list:
|
||||
"""
|
||||
Change lane to avoid obstacles.
|
||||
|
||||
Args:
|
||||
direction: Direction to change lane, either "left" or "right".
|
||||
|
||||
Returns:
|
||||
The camera image and scene description after changing lane.
|
||||
"""
|
||||
self.client.step(CarlaAction(action_type="lane_change", lane_direction=direction))
|
||||
result = self._advance_and_capture()
|
||||
self.reward = result.observation.rubric_reward or 0.0
|
||||
print(f"[DEBUG env] lane_change({direction}): done={result.done}, reward={self.reward}")
|
||||
return self._format_multimodal(result.observation)
|
||||
|
||||
peft_config = None
|
||||
if args.use_lora:
|
||||
from peft import LoraConfig
|
||||
|
||||
if args.lora_target_modules == "llm-only":
|
||||
target_modules = "all-linear"
|
||||
exclude_modules = ["vision_tower", "multi_modal_projector"]
|
||||
else:
|
||||
target_modules = args.lora_target_modules
|
||||
exclude_modules = None
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
target_modules=target_modules,
|
||||
exclude_modules=exclude_modules,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=args.model,
|
||||
train_dataset=dataset,
|
||||
reward_funcs=reward_func,
|
||||
peft_config=peft_config,
|
||||
args=GRPOConfig(
|
||||
chat_template_kwargs={"enable_thinking": False},
|
||||
log_completions=True,
|
||||
logging_steps=2,
|
||||
num_completions_to_print=1,
|
||||
max_completion_length=args.max_completion_length,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size or len(args.env_urls),
|
||||
steps_per_generation=1,
|
||||
num_generations=len(args.env_urls),
|
||||
max_tool_calling_iterations=10,
|
||||
learning_rate=args.learning_rate,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
max_steps=args.max_steps,
|
||||
push_to_hub=args.hub_model_id is not None,
|
||||
hub_model_id=args.hub_model_id,
|
||||
hub_private_repo=args.hub_private_repo,
|
||||
run_name=args.run_name,
|
||||
report_to=args.report_to,
|
||||
trackio_space_id=args.trackio_space_id,
|
||||
),
|
||||
environment_factory=CarlaVLMEnv,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,184 @@
|
||||
import os
|
||||
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from matplotlib import pyplot as plt
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor
|
||||
|
||||
|
||||
def collate_fn(examples, processor):
|
||||
messages = list()
|
||||
for sample in examples:
|
||||
audio = sample["audio"]["array"]
|
||||
label = str(sample["text"])
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": audio},
|
||||
{"type": "text", "text": "Please transcribe this audio."},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": label}]},
|
||||
]
|
||||
messages.append(message)
|
||||
|
||||
batch = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
labels = batch["input_ids"].clone() # Clone input IDs for labels
|
||||
# Mask the tokens that we do not want to include in the loss computation
|
||||
# -100 is ignored during categorical cross entropy loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == processor.tokenizer.audio_token_id] = -100
|
||||
labels[labels == processor.tokenizer.image_token_id] = -100
|
||||
labels[labels == processor.tokenizer.boi_token_id] = -100
|
||||
labels[labels == processor.tokenizer.eoi_token_id] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def freeze_layers(model):
|
||||
for name, param in model.named_parameters():
|
||||
if "attn" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
def run_inference(val_dataset, processor, model, fname):
|
||||
# infer before training
|
||||
val_sample = random.choice(val_dataset)
|
||||
audio = val_sample["audio"]["array"]
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant that transcribes speech accurately.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio": audio},
|
||||
{"type": "text", "text": "Please transcribe this audio."},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
message,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device, dtype=torch.bfloat16)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
with torch.no_grad():
|
||||
generation = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
|
||||
generation = generation[0][input_len:]
|
||||
|
||||
decoded = processor.decode(generation, skip_special_tokens=True)
|
||||
|
||||
print(f"Audio transcription: {decoded}")
|
||||
print(f"Label: {val_sample['text']}")
|
||||
|
||||
|
||||
def main():
|
||||
model_id = "google/gemma-3n-E2B-it"
|
||||
processor = Gemma3nProcessor.from_pretrained(model_id)
|
||||
|
||||
# Load and split the dataset.
|
||||
ds_full = load_dataset("AdrienB134/Emilia-dataset-french-split", split="fr")
|
||||
split_ds = ds_full.train_test_split(test_size=0.1, seed=42)
|
||||
train_dataset = split_ds["train"].select(range(10000))
|
||||
val_dataset = split_ds["test"].select(range(100))
|
||||
|
||||
# create data loader
|
||||
partial_collate_fn = partial(collate_fn, processor=processor)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
pin_memory=True,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
)
|
||||
|
||||
# load the model and optimizer
|
||||
model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to(
|
||||
"cuda", dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
run_inference(val_dataset, processor, model, "pred_before.png")
|
||||
|
||||
model = freeze_layers(model)
|
||||
|
||||
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
|
||||
optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)
|
||||
|
||||
# Start Training
|
||||
accumulation_steps = 8
|
||||
for idx, batch in tqdm(enumerate(train_dataloader)):
|
||||
outputs = model(**batch.to(model.device, dtype=torch.bfloat16))
|
||||
loss = outputs.loss / accumulation_steps
|
||||
if idx % 100 == 0:
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
count = 0
|
||||
for val_batch in tqdm(val_dataloader, desc="Validation"):
|
||||
val_loss = (
|
||||
val_loss
|
||||
+ model(**val_batch.to(model.device, dtype=torch.bfloat16)).loss
|
||||
)
|
||||
count = count + 1
|
||||
val_loss = val_loss / count
|
||||
print(
|
||||
f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
|
||||
)
|
||||
run_inference(val_dataset, processor, model, f"infer_{idx}.png")
|
||||
|
||||
loss.backward()
|
||||
if idx % 8 == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
Train Gemma-3n on various vision-language datasets including intersection-dataset.
|
||||
|
||||
For Gemma-3n with intersection dataset:
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
sft_vlm_gemma3n.py \
|
||||
--dataset_name ariG23498/intersection-dataset \
|
||||
--model_name_or_path google/gemma-3n-E2B-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir gemma-3n-E2B-it-trl-sft-intersection \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
|
||||
Train Gemma-3n on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
sft_vlm_gemma3n.py \
|
||||
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
|
||||
Train Gemma-3n on the FanqingM/MMIU-Benchmark dataset (multi-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
sft_vlm_gemma3n.py \
|
||||
--dataset_name FanqingM/MMIU-Benchmark \
|
||||
--dataset_train_split test \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear
|
||||
--attn_implementation eager
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
from PIL import Image
|
||||
from transformers import (AutoModelForImageTextToText, AutoProcessor,
|
||||
Gemma3nForConditionalGeneration)
|
||||
from trl import (ModelConfig, ScriptArguments, SFTConfig, SFTTrainer,
|
||||
TrlParser, get_kbit_device_map, get_quantization_config)
|
||||
|
||||
|
||||
def my_get_peft_config(model_args: ModelConfig):
|
||||
"""A version of get_peft_config that handles comma-separated target modules"""
|
||||
if model_args.use_peft is False:
|
||||
return None
|
||||
|
||||
# Import here to avoid issues if PEFT is not available
|
||||
try:
|
||||
from peft import LoraConfig
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"You need to have PEFT library installed in your environment, make sure to install `peft`. "
|
||||
"Make sure to run `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Fix the target_modules to be a list if it's a comma-separated string
|
||||
target_modules = model_args.lora_target_modules
|
||||
if isinstance(target_modules, str) and target_modules != "all-linear":
|
||||
# Convert comma-separated string to list
|
||||
target_modules = [module.strip() for module in target_modules.split(",")]
|
||||
|
||||
peft_config = LoraConfig(
|
||||
task_type=model_args.lora_task_type,
|
||||
r=model_args.lora_r,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=model_args.lora_alpha,
|
||||
lora_dropout=model_args.lora_dropout,
|
||||
bias="none",
|
||||
use_rslora=model_args.use_rslora,
|
||||
use_dora=model_args.use_dora,
|
||||
modules_to_save=model_args.lora_modules_to_save,
|
||||
)
|
||||
|
||||
return peft_config
|
||||
|
||||
|
||||
# For intersection dataset processing
|
||||
def format_intersection_data(samples: dict) -> dict[str, list]:
|
||||
"""Format intersection dataset to match expected message format"""
|
||||
formatted_samples = {"messages": []}
|
||||
for idx in range(len(samples["image"])):
|
||||
image = samples["image"][idx].convert("RGB")
|
||||
label = str(samples["label"][idx])
|
||||
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant with great geometry skills.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "How many intersection points are there in the image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": label}]},
|
||||
]
|
||||
formatted_samples["messages"].append(message)
|
||||
return formatted_samples
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
image_inputs = []
|
||||
for msg in messages:
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
for element in content:
|
||||
if isinstance(element, dict) and (
|
||||
"image" in element or element.get("type") == "image"
|
||||
):
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
if image is not None:
|
||||
# Handle dictionary with bytes
|
||||
if isinstance(image, dict) and "bytes" in image:
|
||||
pil_image = Image.open(io.BytesIO(image["bytes"]))
|
||||
image_inputs.append(pil_image.convert("RGB"))
|
||||
# Handle PIL Image objects
|
||||
elif hasattr(image, "convert"):
|
||||
image_inputs.append(image.convert("RGB"))
|
||||
return image_inputs
|
||||
|
||||
|
||||
def format_data(samples: dict) -> dict[str, list]:
|
||||
formatted_samples = {"messages": []}
|
||||
for cont in range(len(samples["question"])):
|
||||
images = []
|
||||
for img_path in samples["input_image_path"][cont]:
|
||||
try:
|
||||
with open(img_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
||||
images.append({"type": "image", "image": image})
|
||||
except Exception as e:
|
||||
print(f"Error processing image {img_path}: {e}")
|
||||
continue
|
||||
|
||||
formatted_samples["messages"].append(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": samples["context"][cont]}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": images
|
||||
+ [{"type": "text", "text": samples["question"][cont]}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": samples["output"][cont]}],
|
||||
},
|
||||
]
|
||||
)
|
||||
return formatted_samples
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def prepare_dataset(
|
||||
dataset: DatasetDict, dataset_name: str, dataset_train_split: str
|
||||
) -> DatasetDict:
|
||||
all_files = list_repo_files(dataset_name, repo_type="dataset")
|
||||
zip_files = [f for f in all_files if f.endswith(".zip")]
|
||||
|
||||
for zip_filename in zip_files:
|
||||
zip_path = hf_hub_download(
|
||||
repo_id=dataset_name, filename=zip_filename, repo_type="dataset"
|
||||
)
|
||||
extract_folder = zip_filename.replace(".zip", "")
|
||||
os.makedirs(extract_folder, exist_ok=True)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_folder)
|
||||
|
||||
dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
|
||||
################
|
||||
# Model, Tokenizer & Processor
|
||||
################
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
processor.tokenizer.padding_side = "right"
|
||||
|
||||
# Use appropriate model class based on model name
|
||||
if "gemma-3n" in model_args.model_name_or_path.lower():
|
||||
model = Gemma3nForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
texts = []
|
||||
images_list = []
|
||||
|
||||
for example in examples:
|
||||
# Apply chat template to get text
|
||||
text = processor.apply_chat_template(
|
||||
example["messages"], tokenize=False, add_generation_prompt=False
|
||||
).strip()
|
||||
texts.append(text)
|
||||
|
||||
# Extract images
|
||||
if "images" in example: # single-image case
|
||||
images = [img.convert("RGB") for img in example["images"]]
|
||||
else: # multi-image case or intersection dataset
|
||||
images = process_vision_info(example["messages"])
|
||||
images_list.append(images)
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(
|
||||
text=texts, images=images_list, return_tensors="pt", padding=True
|
||||
)
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone()
|
||||
|
||||
# Mask tokens for Gemma3n model
|
||||
if "gemma-3n" in model_args.model_name_or_path.lower():
|
||||
# Use Gemma3n specific token masking
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
if hasattr(processor.tokenizer, "image_token_id"):
|
||||
labels[labels == processor.tokenizer.image_token_id] = -100
|
||||
if hasattr(processor.tokenizer, "boi_token_id"):
|
||||
labels[labels == processor.tokenizer.boi_token_id] = -100
|
||||
if hasattr(processor.tokenizer, "eoi_token_id"):
|
||||
labels[labels == processor.tokenizer.eoi_token_id] = -100
|
||||
else:
|
||||
# Original masking for other models
|
||||
image_token_id = [
|
||||
processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.tokenizer.special_tokens_map["boi_token"]
|
||||
)
|
||||
]
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == image_token_id] = -100
|
||||
labels[labels == 262144] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
return batch
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
# Handle different dataset formats
|
||||
if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
|
||||
dataset = prepare_dataset(
|
||||
dataset, script_args.dataset_name, script_args.dataset_train_split
|
||||
)
|
||||
elif script_args.dataset_name == "ariG23498/intersection-dataset":
|
||||
# Format intersection dataset
|
||||
dataset = dataset.map(
|
||||
format_intersection_data, batched=True, batch_size=4, num_proc=4
|
||||
)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=collate_fn,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split]
|
||||
if training_args.eval_strategy != "no"
|
||||
else None,
|
||||
processing_class=processor.tokenizer,
|
||||
peft_config=my_get_peft_config(model_args),
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
if trainer.accelerator.is_main_process:
|
||||
processor.push_to_hub(training_args.hub_model_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,186 @@
|
||||
import os
|
||||
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from matplotlib import pyplot as plt
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor
|
||||
|
||||
|
||||
def collate_fn(examples, processor):
|
||||
messages = list()
|
||||
for sample in examples:
|
||||
image = sample["image"].convert("RGB")
|
||||
label = str(sample["label"])
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant with great geometry skills.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "How many intersection points are there in the image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": label}]},
|
||||
]
|
||||
messages.append(message)
|
||||
|
||||
batch = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
labels = batch["input_ids"].clone() # Clone input IDs for labels
|
||||
# Mask the tokens that we do not want to include in the loss computation
|
||||
# -100 is ignored during categorical cross entropy loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == processor.tokenizer.image_token_id] = -100
|
||||
labels[labels == processor.tokenizer.boi_token_id] = -100
|
||||
labels[labels == processor.tokenizer.eoi_token_id] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def freeze_layers(model):
|
||||
for name, param in model.named_parameters():
|
||||
if "attn" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
def run_inference(val_dataset, processor, model, fname):
|
||||
# infer before training
|
||||
val_sample = random.choice(val_dataset)
|
||||
image = val_sample["image"].convert("RGB")
|
||||
message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an assistant with great geometry skills.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "How many intersection points are there in the image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
message,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device, dtype=torch.bfloat16)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
with torch.no_grad():
|
||||
generation = model.generate(**inputs, max_new_tokens=10, disable_compile=True)
|
||||
generation = generation[0][input_len:]
|
||||
|
||||
decoded = processor.decode(generation, skip_special_tokens=True)
|
||||
|
||||
plt.imshow(image)
|
||||
plt.axis("off")
|
||||
plt.title(f"Pred: {decoded}")
|
||||
plt.show()
|
||||
plt.savefig(f"outputs_fine_tune/{fname}")
|
||||
|
||||
|
||||
def main():
|
||||
model_id = "google/gemma-3n-E2B-it"
|
||||
processor = Gemma3nProcessor.from_pretrained(model_id)
|
||||
|
||||
# load the dataset
|
||||
dataset_id = "ariG23498/intersection-dataset"
|
||||
train_dataset = load_dataset(dataset_id, split="train")
|
||||
val_dataset = load_dataset(dataset_id, split="validation")
|
||||
|
||||
# create data loader
|
||||
partial_collate_fn = partial(collate_fn, processor=processor)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
pin_memory=True,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
collate_fn=partial_collate_fn,
|
||||
)
|
||||
|
||||
# load the model and optimizer
|
||||
model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to("cuda")
|
||||
|
||||
run_inference(val_dataset, processor, model, "pred_before.png")
|
||||
|
||||
model = freeze_layers(model)
|
||||
|
||||
params_to_train = filter(lambda p: p.requires_grad, model.parameters())
|
||||
optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)
|
||||
|
||||
# Start Training
|
||||
accumulation_steps = 8
|
||||
for idx, batch in tqdm(enumerate(train_dataloader)):
|
||||
outputs = model(**batch.to(model.device))
|
||||
loss = outputs.loss / accumulation_steps
|
||||
if idx % 50 == 0:
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
count = 0
|
||||
for val_batch in val_dataloader:
|
||||
val_loss = val_loss + model(**val_batch.to(model.device)).loss
|
||||
count = count + 1
|
||||
val_loss = val_loss / count
|
||||
print(
|
||||
f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
|
||||
)
|
||||
run_inference(val_dataset, processor, model, f"infer_{idx}.png")
|
||||
|
||||
loss.backward()
|
||||
if idx % 8 == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,425 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Gemma3n Fine-tuning on All Modalities.ipynb
|
||||
|
||||
Automatically generated by Colab.
|
||||
|
||||
Original file is located at
|
||||
https://colab.research.google.com/drive/1iEZUJuvKJpGU8t50BqfkiCQmGkaR6gd4
|
||||
|
||||
# Fine-tune Gemma3n on FineVideo
|
||||
|
||||
In this notebook, we will see how to fine-tune Gemma3n an videos with audios inside.
|
||||
Using all three modalities is very costly compute-wise, so keep in mind that this is an educational tutorial to fit the model in 40GB VRAM.
|
||||
"""
|
||||
|
||||
!pip install -U -q timm transformers trl peft datasets
|
||||
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
|
||||
|
||||
from trl import (
|
||||
SFTConfig,
|
||||
SFTTrainer,
|
||||
)
|
||||
|
||||
"""## Download videos and preprocessing
|
||||
|
||||
FineVideo is a quite large dataset, we don't need a ton of examples, so we stream the dataset, check the duration and download the videos shorter than 30 secs.
|
||||
"""
|
||||
|
||||
from datasets import load_dataset
|
||||
import json
|
||||
import os
|
||||
|
||||
dataset = load_dataset("HuggingFaceFV/finevideo", split="train", streaming=True)
|
||||
|
||||
|
||||
os.makedirs("videos", exist_ok=True)
|
||||
os.makedirs("metadata", exist_ok=True)
|
||||
|
||||
for idx, sample in enumerate(dataset):
|
||||
data = sample["json"]
|
||||
duration = data.get("duration_seconds", 0)
|
||||
if duration < 30:
|
||||
video_filename = f"videos/sample_{idx}.mp4"
|
||||
with open(video_filename, 'wb') as video_file:
|
||||
video_file.write(sample['mp4'])
|
||||
|
||||
json_filename = f"metadata/sample_{idx}.json"
|
||||
with open(json_filename, 'w') as json_file:
|
||||
json.dump(sample['json'], json_file)
|
||||
|
||||
print(f"Number of items in content/videos: {len(os.listdir('videos'))}")
|
||||
|
||||
"""In FineVideo some frames are dark so we downsample 6 frames and if we can't get meaningful videos we remove them."""
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
def is_dark(frame, threshold=10):
|
||||
return np.max(frame) < threshold # all pixels are very close to 0
|
||||
|
||||
def downsample_video(video_path):
|
||||
vidcap = cv2.VideoCapture(video_path)
|
||||
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
||||
|
||||
frames = []
|
||||
|
||||
# Generate 8 evenly spaced indices, skip first and last
|
||||
full_indices = np.linspace(0, total_frames - 1, 8, dtype=int)[1:-1]
|
||||
|
||||
for i in full_indices:
|
||||
found_valid = False
|
||||
for offset in [0, -1, 1, -2, 2]: # Try nearby frames if original is dark
|
||||
candidate_idx = i + offset
|
||||
if 0 <= candidate_idx < total_frames:
|
||||
vidcap.set(cv2.CAP_PROP_POS_FRAMES, candidate_idx)
|
||||
success, image = vidcap.read()
|
||||
if success:
|
||||
if not is_dark(image):
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image)
|
||||
timestamp = round(candidate_idx / fps, 2)
|
||||
frames.append((pil_image, timestamp))
|
||||
found_valid = True
|
||||
break
|
||||
if not found_valid:
|
||||
print(f"Warning: Could not find non-dark frame near index {i}")
|
||||
|
||||
vidcap.release()
|
||||
|
||||
# If still fewer than 8, try to top off by scanning more frames
|
||||
if len(frames) < 6:
|
||||
print("Trying to top off with additional non-dark frames...")
|
||||
idx = 0
|
||||
while len(frames) < 8 and idx < total_frames:
|
||||
vidcap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
||||
success, image = vidcap.read()
|
||||
if success and not is_dark(image):
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image)
|
||||
timestamp = round(idx / fps, 2)
|
||||
# Avoid adding duplicate timestamps
|
||||
if not any(ts == timestamp for _, ts in frames):
|
||||
frames.append((pil_image, timestamp))
|
||||
idx += 1
|
||||
|
||||
return frames[:8] # Ensure exactly 8 frames
|
||||
|
||||
import os
|
||||
import glob
|
||||
|
||||
def remove_dark_videos(video_dir, metadata_dir, audio_dir):
|
||||
"""
|
||||
Remove videos (and their metadata/audio files) if all frames are dark.
|
||||
"""
|
||||
video_paths = glob.glob(os.path.join(video_dir, "*.mp4"))
|
||||
|
||||
for video_path in video_paths:
|
||||
filename = os.path.basename(video_path)
|
||||
base_name = os.path.splitext(filename)[0]
|
||||
|
||||
frames = downsample_video(video_path)
|
||||
if len(frames) < 6:
|
||||
try:
|
||||
os.remove(video_path)
|
||||
print(f"Deleted: {video_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {video_path}: {e}")
|
||||
|
||||
metadata_path = os.path.join(metadata_dir, f"{base_name}.json")
|
||||
if os.path.exists(metadata_path):
|
||||
os.remove(metadata_path)
|
||||
|
||||
# Remove audio
|
||||
audio_path = os.path.join(audio_dir, f"{base_name}.wav")
|
||||
if os.path.exists(audio_path):
|
||||
os.remove(audio_path)
|
||||
|
||||
remove_dark_videos(
|
||||
video_dir="videos",
|
||||
metadata_dir="metadata",
|
||||
audio_dir="audios"
|
||||
)
|
||||
|
||||
"""Gemma-3n accepts video (image frames) and audio separately, so we strip audio from video."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
video_dir = "videos"
|
||||
audio_dir = "audios"
|
||||
os.makedirs(audio_dir, exist_ok=True)
|
||||
|
||||
for filename in os.listdir(video_dir):
|
||||
if not filename.endswith(".mp4"):
|
||||
continue
|
||||
|
||||
idx = filename.split("_")[1].split(".")[0]
|
||||
video_path = os.path.join(video_dir, filename)
|
||||
audio_path = os.path.join(audio_dir, f"sample_{idx}.wav")
|
||||
|
||||
subprocess.run([
|
||||
"ffmpeg", "-i", video_path,
|
||||
"-q:a", "0", "-map", "a",
|
||||
audio_path,
|
||||
"-y"
|
||||
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
|
||||
"""Construct a new dataset with audio, video, metadata (video categories). This dataset is very cool, it has some questions and answers, captions and more so get creative if you have the GPU VRAM to do so. Here we solve an easier task for educational purposes."""
|
||||
|
||||
from datasets import Dataset
|
||||
import json
|
||||
|
||||
def gen():
|
||||
meta_dir = "metadata"
|
||||
for filename in os.listdir(meta_dir):
|
||||
if not filename.endswith(".json"):
|
||||
continue
|
||||
|
||||
idx = filename.split("_")[1].split(".")[0]
|
||||
if os.path.exists(f"videos/sample_{idx}.mp4"):
|
||||
video_filename = f"sample_{idx}.mp4"
|
||||
audio_filename = f"sample_{idx}.wav"
|
||||
json_path = os.path.join(meta_dir, filename)
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
|
||||
yield {
|
||||
"video": video_filename,
|
||||
"audio": audio_filename,
|
||||
"content_parent_category": metadata["content_parent_category"],
|
||||
"sample_index": int(idx)
|
||||
}
|
||||
else:
|
||||
pass
|
||||
|
||||
dataset = Dataset.from_generator(gen)
|
||||
|
||||
"""We will speed-up and downsample the audios to save space during training."""
|
||||
|
||||
import torchaudio
|
||||
from torchaudio.transforms import Resample
|
||||
import os
|
||||
import torch
|
||||
|
||||
def preprocess_audio(audio_path, target_sample_rate=16000, max_duration_sec=5, speedup_factor=1.25):
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform.mean(dim=0, keepdim=True)
|
||||
|
||||
if sample_rate != target_sample_rate:
|
||||
resampler = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
|
||||
waveform = resampler(waveform)
|
||||
sample_rate = target_sample_rate
|
||||
|
||||
if speedup_factor > 1.0:
|
||||
indices = torch.arange(0, waveform.shape[1], step=speedup_factor).long()
|
||||
if indices[-1] >= waveform.shape[1]:
|
||||
indices = indices[:-1]
|
||||
waveform = waveform[:, indices]
|
||||
|
||||
max_length = int(target_sample_rate * max_duration_sec)
|
||||
if waveform.shape[1] > max_length:
|
||||
waveform = waveform[:, :max_length]
|
||||
|
||||
torchaudio.save(audio_path, waveform, sample_rate)
|
||||
|
||||
for file_name in os.listdir("audios"):
|
||||
if file_name.lower().endswith(".wav"):
|
||||
audio_path = os.path.join("audios", file_name)
|
||||
preprocess_audio(audio_path)
|
||||
|
||||
dataset = dataset.train_test_split(test_size=0.10, seed=42)
|
||||
|
||||
"""### Load the model
|
||||
|
||||
Make sure you have your Hugging Face token in your Colab secrets.
|
||||
"""
|
||||
|
||||
model = Gemma3nForConditionalGeneration.from_pretrained(
|
||||
"google/gemma-3n-E2B-it", torch_dtype=torch.bfloat16,
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"google/gemma-3n-E2B-it",
|
||||
)
|
||||
processor.tokenizer.padding_side = "right"
|
||||
|
||||
processor.tokenizer.all_special_ids
|
||||
|
||||
"""Write our dataset collator. We will train model to predict category of a video (which can be done easily). You can do much better things, for instance FineVideo has QnA section, you can train this model to do open-ended QnA if you have a big VRAM and a lot of patience. Open-ended tasks are harder to work with, and this notebook carries educational purposes on feeding different modalities.
|
||||
|
||||
In collator we also downsample videos to 6 frames, we have written the helper above. For better results you need more frames.
|
||||
"""
|
||||
|
||||
def collate_fn(examples):
|
||||
video_path = examples[0]["video"]
|
||||
audio_path = examples[0]["audio"]
|
||||
sample_idx = filename.split("_")[1].split(".")[0]
|
||||
frames = downsample_video(f"videos/{video_path}")
|
||||
|
||||
text = "Based on the video, predict the category of it."
|
||||
message = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": text}
|
||||
],
|
||||
},
|
||||
]
|
||||
# this is how video inference should be formatted in Gemma3n
|
||||
for frame in frames:
|
||||
image, timestamp = frame
|
||||
message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
|
||||
timestamp = str(timestamp).replace(".", "_")
|
||||
image.save(f"image_idx_{sample_idx}_{timestamp}.png")
|
||||
message[0]["content"].append({"type": "image", "url": f"image_idx_{sample_idx}_{timestamp}.png"})
|
||||
|
||||
message[0]["content"].append({"type": "audio", "audio": f"audios/{audio_path}"})
|
||||
message.append({"role": "assistant", "content": [{"type": "text", "text": examples[0]["content_parent_category"]}]})
|
||||
inputs = processor.apply_chat_template(
|
||||
message,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(model.device)
|
||||
|
||||
labels = inputs["input_ids"].clone()
|
||||
special_token_ids = processor.tokenizer.all_special_ids
|
||||
|
||||
special_token_ids_tensor = torch.tensor(special_token_ids, device=labels.device)
|
||||
mask = torch.isin(labels, special_token_ids_tensor)
|
||||
labels[mask] = -100
|
||||
|
||||
inputs["labels"] = labels
|
||||
if torch.all(inputs["pixel_values"] == 0):
|
||||
print("Frames are dark")
|
||||
|
||||
return inputs
|
||||
|
||||
"""## Training
|
||||
|
||||
We do LoRA fine-tuning again to save up on space.
|
||||
"""
|
||||
|
||||
from peft import LoraConfig
|
||||
peft_config = LoraConfig(
|
||||
task_type="CAUSAL_LM",
|
||||
r=16,
|
||||
target_modules="all-linear",
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
use_rslora=False,
|
||||
use_dora=False,
|
||||
modules_to_save=None
|
||||
)
|
||||
|
||||
model.gradient_checkpointing_disable()
|
||||
|
||||
model.config.use_cache = False
|
||||
|
||||
training_args = SFTConfig(
|
||||
output_dir="/content/gemma-3n-finevideo",
|
||||
eval_strategy='epoch',
|
||||
per_device_train_batch_size=1,
|
||||
per_device_eval_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
gradient_checkpointing=False,
|
||||
learning_rate=1e-05,
|
||||
num_train_epochs=3.0,
|
||||
logging_steps=10,
|
||||
save_steps=100,
|
||||
bf16=True,
|
||||
report_to=["tensorboard"],
|
||||
dataset_kwargs={'skip_prepare_dataset': True},
|
||||
remove_unused_columns=False,
|
||||
max_seq_length=None,
|
||||
push_to_hub=True,
|
||||
dataloader_pin_memory=False,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=collate_fn,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["test"] if training_args.eval_strategy != "no" else None,
|
||||
processing_class=processor.tokenizer,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
"""Test the model with a video of snowboarding."""
|
||||
|
||||
!wget https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4
|
||||
|
||||
model = trainer.model # trainer has the adapter
|
||||
|
||||
"""Strip audio and downsample video."""
|
||||
|
||||
audio_path = "/content/test_audio.wav"
|
||||
subprocess.run([
|
||||
"ffmpeg", "-i", "/content/IMG_8137.mp4",
|
||||
"-q:a", "0", "-map", "a",
|
||||
f"{audio_path}",
|
||||
"-y"
|
||||
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
|
||||
frames = downsample_video("/content/IMG_8137.mp4")
|
||||
|
||||
# repeat the chat template
|
||||
text = "Based on the video, predict the category of it."
|
||||
message = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": text}
|
||||
],
|
||||
},
|
||||
]
|
||||
for frame in frames:
|
||||
image, timestamp = frame
|
||||
message[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
|
||||
timestamp = str(timestamp).replace(".", "_")
|
||||
image.save(f"test_frame_{timestamp}.png")
|
||||
message[0]["content"].append({"type": "image", "url": f"test_frame_{timestamp}.png"})
|
||||
|
||||
message[0]["content"].append({"type": "audio", "audio": f"{audio_path}"})
|
||||
|
||||
message
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
message,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(model.device).to(model.dtype)
|
||||
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
with torch.inference_mode():
|
||||
generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
||||
generation = generation[0][input_len:]
|
||||
|
||||
decoded = processor.decode(generation, skip_special_tokens=True)
|
||||
print(decoded)
|
||||
|
||||
"""Thanks a lot for reading! Keep training the model further with more data or unfreeze the layers for better performance 💗"""
|
||||
|
||||
Reference in New Issue
Block a user