eecebe7ef5
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
185 lines
5.6 KiB
Python
185 lines
5.6 KiB
Python
import os
|
|
|
|
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
import random
|
|
from functools import partial
|
|
|
|
import torch
|
|
from datasets import load_dataset
|
|
from matplotlib import pyplot as plt
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor
|
|
|
|
|
|
def collate_fn(examples, processor):
    """Turn a list of dataset rows into one tokenized training batch.

    Every row becomes a three-turn conversation: a system instruction, a user
    turn carrying the audio plus a transcription request, and an assistant
    turn carrying the reference text. The conversations are rendered and
    tokenized together by the processor's chat template.

    Args:
        examples: Rows to batch; each is read as ``row["audio"]["array"]``
            and ``row["text"]``.
        processor: Object providing ``apply_chat_template`` and a
            ``tokenizer`` that exposes ``pad_token_id``, ``audio_token_id``,
            ``image_token_id``, ``boi_token_id`` and ``eoi_token_id``.

    Returns:
        The processor output with an extra ``"labels"`` entry: a copy of
        ``input_ids`` in which the token ids listed above are replaced
        by -100.
    """
    system_prompt = "You are an assistant that transcribes speech accurately."
    user_request = "Please transcribe this audio."

    conversations = [
        [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": row["audio"]["array"]},
                    {"type": "text", "text": user_request},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": str(row["text"])}],
            },
        ]
        for row in examples
    ]

    batch = processor.apply_chat_template(
        conversations,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Labels start as a copy of the inputs; positions set to -100 are skipped
    # by the cross-entropy loss, so padding and placeholder tokens do not
    # contribute to training.
    tokenizer = processor.tokenizer
    targets = batch["input_ids"].clone()
    ignored_ids = (
        tokenizer.pad_token_id,
        tokenizer.audio_token_id,
        tokenizer.image_token_id,
        tokenizer.boi_token_id,
        tokenizer.eoi_token_id,
    )
    for token_id in ignored_ids:
        targets[targets == token_id] = -100

    batch["labels"] = targets
    return batch
|
|
|
|
|
|
def freeze_layers(model):
    """Leave only parameters whose name contains ``"attn"`` trainable.

    Walks ``model.named_parameters()`` and sets ``requires_grad`` to ``True``
    for names containing the substring ``"attn"`` and to ``False`` for all
    others. The model is modified in place.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        The same ``model`` object that was passed in.
    """
    for param_name, weight in model.named_parameters():
        weight.requires_grad = "attn" in param_name
    return model
|
|
|
|
|
|
def run_inference(val_dataset, processor, model, fname):
    """Transcribe one randomly chosen validation sample and print the result.

    Builds a system + user conversation around the sample's audio, generates
    up to 100 new tokens without tracking gradients, and prints the decoded
    continuation next to the sample's reference text.

    Args:
        val_dataset: Sequence of rows read as ``row["audio"]["array"]`` and
            ``row["text"]``; one row is picked with ``random.choice``.
        processor: Object providing ``apply_chat_template`` and ``decode``.
        model: Object providing ``device`` and ``generate``.
        fname: Accepted for call compatibility; not used by this function.
    """
    system_prompt = "You are an assistant that transcribes speech accurately."
    user_request = "Please transcribe this audio."

    sample = random.choice(val_dataset)
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": sample["audio"]["array"]},
                {"type": "text", "text": user_request},
            ],
        },
    ]

    model_inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    model_inputs = model_inputs.to(model.device, dtype=torch.bfloat16)

    # The generated sequence begins with the prompt tokens; remember how many
    # there are so only the newly produced tokens get decoded.
    prompt_length = model_inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generated = model.generate(
            **model_inputs, max_new_tokens=100, disable_compile=True
        )
        new_tokens = generated[0][prompt_length:]

    transcript = processor.decode(new_tokens, skip_special_tokens=True)

    print(f"Audio transcription: {transcript}")
    print(f"Label: {sample['text']}")
|
|
|
|
|
|
def main():
    """Fine-tune the attention weights of Gemma 3n for French transcription.

    Loads ``google/gemma-3n-E2B-it`` and its processor, splits the
    ``AdrienB134/Emilia-dataset-french-split`` dataset 90/10 (seed 42), keeps
    10,000 training and 100 validation rows, freezes every parameter whose
    name does not contain ``"attn"``, and runs a single pass over the training
    data with AdamW (lr 1e-5) and gradient accumulation over 8 micro-batches.
    Every 100 iterations it prints the training loss, the mean validation
    loss and a sample transcription.

    Requires a CUDA device; the model is moved to ``"cuda"`` in bfloat16.
    """
    model_id = "google/gemma-3n-E2B-it"
    processor = Gemma3nProcessor.from_pretrained(model_id)

    # Load and split the dataset.
    ds_full = load_dataset("AdrienB134/Emilia-dataset-french-split", split="fr")
    split_ds = ds_full.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_ds["train"].select(range(10000))
    val_dataset = split_ds["test"].select(range(100))

    # Create the data loaders; both share the same collate function.
    partial_collate_fn = partial(collate_fn, processor=processor)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
    )

    # Load the model and print a baseline transcription before any update.
    model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to(
        "cuda", dtype=torch.bfloat16
    )
    run_inference(val_dataset, processor, model, "pred_before.png")

    model = freeze_layers(model)
    params_to_train = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)

    # Start training.
    accumulation_steps = 8
    eval_interval = 100
    optimizer.zero_grad()
    # Wrap the loader (not the enumerate object) so tqdm knows the total.
    for idx, batch in enumerate(tqdm(train_dataloader)):
        outputs = model(**batch.to(model.device, dtype=torch.bfloat16))

        # Scale each micro-batch so the gradients summed over one window
        # equal the gradient of the window's mean loss.
        loss = outputs.loss / accumulation_steps
        # Backpropagate before validating so the training graph is released
        # and not held in memory during the validation and generation passes.
        loss.backward()

        # Step only once a full window of micro-batches has been accumulated.
        # Counting from idx + 1 avoids an update after the very first batch.
        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # idx is a multiple of 100 here, so idx + 1 is odd and no optimizer
        # step happened this iteration: the weights evaluated below are the
        # ones that produced the training loss being reported.
        if idx % eval_interval == 0:
            val_total = 0.0
            count = 0
            with torch.no_grad():
                for val_batch in tqdm(val_dataloader, desc="Validation"):
                    val_out = model(
                        **val_batch.to(model.device, dtype=torch.bfloat16)
                    )
                    # Accumulate as Python floats rather than in the
                    # model's tensor dtype.
                    val_total += val_out.loss.item()
                    count += 1
            val_loss = val_total / count
            # Report the unscaled training loss so it is directly comparable
            # with the validation loss.
            print(
                f"Iter: {idx} Loss: {outputs.loss.item():.4f} Val Loss: {val_loss:.4f}"
            )
            run_inference(val_dataset, processor, model, f"infer_{idx}.png")


if __name__ == "__main__":
    main()
|