docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
###############################################################################################
|
||||
# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py #
|
||||
###############################################################################################
|
||||
@@ -0,0 +1,320 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
# Full training
|
||||
```
|
||||
python examples/scripts/grpo_agent.py \
|
||||
--model_name_or_path Qwen/Qwen3-1.7B \
|
||||
--output_dir grpo_biogrid_qwen_3g-1.7b \
|
||||
--push_to_hub True \
|
||||
--use_vllm True \
|
||||
--vllm_mode colocate \
|
||||
--max_completion_length 1024 \
|
||||
--report_to trackio \
|
||||
--log_completions True \
|
||||
--max_steps 400
|
||||
```
|
||||
"""
|
||||
|
||||
import re
|
||||
import signal
|
||||
import sqlite3
|
||||
import textwrap
|
||||
from contextlib import contextmanager
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser
|
||||
|
||||
|
||||
def query_reward(completions, answer, **kwargs):
|
||||
"""
|
||||
Reward query strategy:
|
||||
- Penalize more than 2 queries
|
||||
- Penalize generic queries (LIMIT 1 / PRAGMA)
|
||||
- Reward usage of WHERE
|
||||
- Reward evidence supporting the final answer
|
||||
"""
|
||||
rewards = []
|
||||
|
||||
for completion, ans in zip(completions, answer, strict=False):
|
||||
reward = 0.0
|
||||
sql_queries = []
|
||||
tool_results = []
|
||||
|
||||
# collect all SQL queries and tool results
|
||||
for turn in completion:
|
||||
if turn.get("tool_calls"):
|
||||
for call in turn["tool_calls"]:
|
||||
sql = call["function"]["arguments"].get("sql_command", "").lower()
|
||||
sql_queries.append(sql)
|
||||
if turn.get("role") == "tool" and turn.get("content"):
|
||||
tool_results.append(turn["content"])
|
||||
|
||||
# --- penalize too many queries ---
|
||||
if len(sql_queries) > 3:
|
||||
reward -= 1.5
|
||||
|
||||
# --- check query quality ---
|
||||
where_count = 0
|
||||
for q in sql_queries:
|
||||
if "limit 1" in q:
|
||||
reward -= 1.0
|
||||
if " where " not in q:
|
||||
reward -= 0.5
|
||||
else:
|
||||
where_count += 1
|
||||
reward += min(where_count, 3) * 0.4 # small bonus for WHERE usage
|
||||
|
||||
# --- evidence check: do queries support the answer? ---
|
||||
combined_results = []
|
||||
error_detected = False
|
||||
|
||||
for res in tool_results:
|
||||
if isinstance(res, dict) and "error" in res:
|
||||
error_detected = True
|
||||
elif isinstance(res, list):
|
||||
combined_results.extend(res)
|
||||
|
||||
# if error detected, penalize heavily
|
||||
if error_detected:
|
||||
reward -= 2.0
|
||||
elif len(sql_queries) == 0:
|
||||
reward -= 1.5
|
||||
else:
|
||||
has_hits = len(combined_results) > 0
|
||||
correct_answer = ans.lower()
|
||||
if (has_hits and correct_answer == "yes") or (not has_hits and correct_answer == "no"):
|
||||
reward += 2.0
|
||||
else:
|
||||
reward -= 1.5
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
|
||||
def correctness_reward(completions, answer, **kwargs):
|
||||
"""
|
||||
Reward Yes/No correctness.
|
||||
Model must provide final answer enclosed in stars — *yes* or *no*.
|
||||
Does not reward informal yes/no buried in text.
|
||||
"""
|
||||
rewards = []
|
||||
for completion, ans in zip(completions, answer, strict=False):
|
||||
raw = completion[-1]["content"].lower()
|
||||
|
||||
# detect form *yes* or *no*
|
||||
match = re.search(r"\*(yes|no)\*", raw)
|
||||
guess = match.group(1) if match else None
|
||||
|
||||
reward = 0.0
|
||||
|
||||
if guess is None:
|
||||
reward -= 0.5 # invalid format
|
||||
elif guess == ans.lower():
|
||||
reward += 0.6 # correct under required format
|
||||
else:
|
||||
reward -= 1.0 # wrong answer
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
|
||||
def structure_reward(completions, **kwargs):
    """
    Score how well each completion follows the expected tool-use structure.

    A completion that contains both an assistant tool call and a tool response earns 0.1 when it also has other
    non-trivial text, and 0.05 when it does not. A tool call without any tool response scores -0.15. A completion
    without a tool call is neutral (0.0).
    """

    def _has_text(turn):
        # Text counts only if it is more than whitespace or a bare "<think>" marker
        text = turn.get("content")
        return bool(text) and text.strip() not in ["", "<think>"]

    scores = []
    for turns in completions:
        called = responded = extra = False

        for turn in turns:
            role = turn.get("role")
            if role == "assistant" and turn.get("tool_calls"):
                called = True
                continue
            if role == "tool":
                responded = True
                continue
            # Any other turn (plain assistant text, or another role) only matters if it has real content
            if _has_text(turn):
                extra = True

        if not called:
            score = 0.0  # neutral if no call
        elif not responded:
            score = -0.15
        elif extra:
            score = 0.1
        else:
            score = 0.05  # still positive even without extra text

        scores.append(score)

    return scores
|
||||
|
||||
|
||||
# ------------------------
|
||||
# Database tool function
|
||||
# ------------------------
|
||||
# NOTE(review): within this module the name below shadows the builtin TimeoutError; unlike the builtin it derives
# directly from Exception rather than from OSError.
class TimeoutError(Exception):
    """Error raised when a function call runs past its time limit."""
|
||||
|
||||
|
||||
@contextmanager
def timeout(seconds):
    """
    Context manager that raises TimeoutError if the body runs longer than `seconds`.

    The limit is enforced with SIGALRM, so this only works on platforms that provide `signal.alarm` and when
    entered from the main thread. The SIGALRM handler that was installed before entering is put back on exit, so
    the guard does not leak its handler into the rest of the program.

    Args:
        seconds: Whole number of seconds after which the alarm fires.

    Raises:
        TimeoutError: From inside the body, when the alarm fires before the body finishes.
    """

    def timeout_handler(signum, frame):
        raise TimeoutError(f"Operation timed out after {seconds} seconds")

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    try:
        # Arm the alarm inside the try so that the cleanup below also runs if arming itself fails.
        signal.alarm(seconds)
        yield
    finally:
        signal.alarm(0)  # cancel any pending alarm
        # `signal.signal` returns None when the previous handler was not installed from Python; it cannot be
        # re-installed in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
|
||||
|
||||
|
||||
def query_biogrid(sql_command: str) -> list[tuple]:
    """
    Execute a read-only SQL command on the BioGRID database.

    BioGRID is a curated biological database that compiles protein, genetic, and chemical interactions from multiple organisms. It provides researchers with experimentally verified interaction data to support studies in systems biology and functional genomics.

    Args:
        sql_command: The SQL command to execute.

    Returns:
        A list of tuples containing the query results.
    """
    # NOTE(review): this function is handed to the trainer through `tools=[...]`; the docstring above presumably
    # doubles as the tool description shown to the model, so its wording is left untouched — confirm before editing.
    with timeout(5):
        # `mode=ro` opens the database file read-only, so the SQL cannot modify it
        connection = sqlite3.connect("file:biogrid.db?mode=ro", uri=True)
        cur = connection.cursor()
        try:
            rows = cur.execute(sql_command).fetchall()
        finally:
            # Close the connection whether or not the query succeeded
            connection.close()
    return rows
|
||||
|
||||
|
||||
# ------------------------
|
||||
# Dataset formatting
|
||||
# ------------------------
|
||||
def format_example(example):
    """
    Turn a dataset row into a single-turn chat prompt.

    The user message is a fixed set of instructions about the BioGRID SQLite database followed by the row's
    question. Returns a dict with one key, "prompt", holding the list of chat messages.
    """
    instructions = textwrap.dedent("""\
        You have access to the BioGRID SQLite database.
        Use SQL queries to retrieve only the information needed to answer the question.

        Genes may appear in the database in columns `Alt_IDs_Interactor_A` `Alt_IDs_Interactor_B`, `Aliases_Interactor_A` and `Aliases_Interactor_B`,
        and each entry can contain multiple gene names or synonyms separated by '|', for example:
        'entrez gene/locuslink:JNKK(gene name synonym)|entrez gene/locuslink:MAPKK4(gene name synonym)|...'
        So a gene like 'JNKK' or 'MAPKK4' may appear inside one of these strings.

        If the database schema is unclear or you are unsure about column names:
        - First inspect the schema with `PRAGMA table_info(interactions);`
        - Or preview a few rows with `SELECT * FROM interactions LIMIT 1;`

        Otherwise, directly query the required data.

        Final answer must be enclosed in stars, e.g. *Yes* or *No*.
        Facts:
        - The NCBI Taxonomy identifier for humans is taxid:9606.
        """)
    user_message = {"role": "user", "content": f"{instructions}\nQuestion: {example['question']}"}
    return {"prompt": [user_message]}
|
||||
|
||||
|
||||
# ------------------------
|
||||
# Main
|
||||
# ------------------------
|
||||
if __name__ == "__main__":
    cli_parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = cli_parser.parse_args_and_config()

    # ------------------------
    # Build the SQLite database queried by the `query_biogrid` tool
    # ------------------------
    print("Creating biogrid.db...")
    interactions = load_dataset("qgallouedec/biogrid", split="train").to_pandas()

    # Replace spaces in column names with underscores
    interactions.columns = [name.replace(" ", "_") for name in interactions.columns]
    db = sqlite3.connect("biogrid.db")
    try:
        interactions.to_sql("interactions", db, if_exists="replace", index=False)
        print(f"biogrid.db created. Rows stored: {len(interactions)}")
    finally:
        db.close()

    # ------------------------
    # Question dataset: keep only the simple "Does the gene ..." questions and turn them into chat prompts
    # ------------------------
    def is_simple_question(example):
        return example["question"].startswith("Does the gene ")

    qa_dataset = load_dataset("qgallouedec/biogrid_qa", split="train")
    qa_dataset = qa_dataset.filter(is_simple_question)
    qa_dataset = qa_dataset.map(format_example, remove_columns=["question"])

    training_args.chat_template_kwargs = {"enable_thinking": False}

    # ------------------------
    # Trainer (no evaluation set by default; one can be added if needed)
    # ------------------------
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        reward_funcs=[correctness_reward, structure_reward, query_reward],
        tools=[query_biogrid],
        train_dataset=qa_dataset,
        eval_dataset=None,
    )
    trainer.train()

    # ------------------------
    # Persist the model, and publish it when requested
    # ------------------------
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
@@ -0,0 +1,157 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "Pillow",
|
||||
# "math-verify",
|
||||
# "latex2sympy2_extended",
|
||||
# "torchvision",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
pip install math_verify
|
||||
|
||||
# For Qwen/Qwen2.5-VL-3B-Instruct
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/grpo_vlm.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--output_dir grpo-Qwen2.5-VL-3B-Instruct \
|
||||
--learning_rate 1e-5 \
|
||||
--dtype bfloat16 \
|
||||
--max_completion_length 1024 \
|
||||
--use_vllm \
|
||||
--vllm_mode colocate \
|
||||
--use_peft \
|
||||
--lora_target_modules q_proj v_proj \
|
||||
--log_completions
|
||||
|
||||
# For HuggingFaceTB/SmolVLM2-2.2B-Instruct
|
||||
pip install num2words==0.5.14
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/grpo_vlm.py \
|
||||
--model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
|
||||
--output_dir grpo-SmolVLM2-2.2B-Instruct \
|
||||
--learning_rate 1e-5 \
|
||||
--dtype bfloat16 \
|
||||
--max_completion_length 1024 \
|
||||
--use_peft \
|
||||
--lora_target_modules q_proj v_proj \
|
||||
--log_completions \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_generations 2
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from trl import (
|
||||
GRPOConfig,
|
||||
GRPOTrainer,
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
if __name__ == "__main__":
    arg_parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = arg_parser.parse_args_and_config()

    # ---- Model loading options ----
    # "auto" and None are passed through unchanged; any other value names a torch dtype (e.g. "bfloat16")
    if model_args.dtype in ["auto", None]:
        dtype = model_args.dtype
    else:
        dtype = getattr(torch, model_args.dtype)

    init_kwargs = {
        "revision": model_args.model_revision,
        "attn_implementation": model_args.attn_implementation,
        "dtype": dtype,
    }
    training_args.model_init_kwargs = init_kwargs

    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        # Only set these keys when quantization is requested: an explicit None is not the same as omitting them.
        init_kwargs["device_map"] = get_kbit_device_map()
        init_kwargs["quantization_config"] = quantization_config

    # ---- Dataset ----
    dataset = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")
    dataset = dataset.train_test_split(test_size=100, seed=42)

    SYSTEM_PROMPT = (
        "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
        "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
        "reasoning.\n</think>\nThis is my answer."
    )

    def make_conversation(example):
        # System instructions followed by the problem statement as the user turn
        system_turn = {"role": "system", "content": SYSTEM_PROMPT}
        user_turn = {"role": "user", "content": example["problem"]}
        return {"prompt": [system_turn, user_turn]}

    def filter_big_images(example):
        # Keep only examples whose image is smaller than 512 pixels in both dimensions
        width, height = example["image"].size
        return width < 512 and height < 512

    def convert_to_rgb(example):
        picture = example["image"]
        if picture.mode != "RGB":
            picture = picture.convert("RGB")
        example["image"] = picture
        return example

    # Build prompts, filter out big images, then normalise every remaining image to RGB
    dataset = dataset.map(make_conversation)
    dataset = dataset.filter(filter_big_images)
    dataset = dataset.map(convert_to_rgb)

    train_dataset = dataset["train"]
    # The held-out split is only used when evaluation is enabled
    eval_dataset = None
    if training_args.eval_strategy != "no":
        eval_dataset = dataset["test"]

    # ---- Training ----
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=[think_format_reward, accuracy_reward],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    # ---- Save, and push to the Hub when requested ----
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
###############################################################################################
|
||||
# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py #
|
||||
###############################################################################################
|
||||
@@ -0,0 +1,69 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl",
|
||||
# "Pillow",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Train Gemma-3 on the Codeforces COTS dataset.
|
||||
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/sft_gemma3.py
|
||||
"""
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
|
||||
def main():
    """Fine-tune google/gemma-3-12b-it on open-r1/codeforces-cots with SFT and push the result to the Hub."""
    dataset_id = "open-r1/codeforces-cots"
    model_id = "google/gemma-3-12b-it"

    # Dataset: the "prompt" column is dropped before training
    # (presumably so that only the remaining conversational column is used — TODO confirm)
    train_dataset = load_dataset(dataset_id, split="train").remove_columns("prompt")

    # Model
    model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager")

    # Training configuration
    sft_config = SFTConfig(
        output_dir=f"{model_id}-codeforces-SFT",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        max_length=8192,
        bf16=True,
        use_liger_kernel=True,
        dataset_num_proc=32,
    )

    trainer = SFTTrainer(model=model, args=sft_config, train_dataset=train_dataset)
    trainer.train()

    # Publish the trained model
    trainer.push_to_hub(dataset_name=dataset_id)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,164 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "bitsandbytes",
|
||||
# "liger-kernel",
|
||||
# "trackio",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Teach tool calling to CohereLabs/tiny-aya-global using SFT with QLoRA on the bebechien/SimpleToolCalling dataset.
|
||||
|
||||
The model used in this script does not have native tool-calling support. We extend its existing Jinja2 chat template to
|
||||
serialize tool schemas into the system preamble and render tool calls as structured <tool_call> XML inside the model's
|
||||
native <|START_RESPONSE|> / <|END_RESPONSE|> delimiters. The modified template is saved with the tokenizer, so
|
||||
inference only requires loading the tokenizer from the output directory and calling apply_chat_template with
|
||||
tools=TOOLS — no manual system-prompt construction needed.
|
||||
|
||||
Example:
|
||||
|
||||
python examples/scripts/sft_tiny_aya_tool_calling.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
|
||||
def _search_tool_schema(name, description):
    """Build the schema of a function tool that takes a single required string argument, `query`."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string", "description": "query string"}},
                "required": ["query"],
            },
            "return": {"type": "string"},
        },
    }


# Schemas of the tools that appear in the dataset
TOOLS = [
    _search_tool_schema("search_knowledge_base", "Search internal company documents, policies and project data."),
    _search_tool_schema("search_google", "Search public information."),
]
|
||||
|
||||
|
||||
def create_conversation(sample):
    """
    Convert a raw dataset row into a prompt/completion pair with a single tool call.

    The user text becomes the prompt; the completion is one assistant turn calling `sample["tool_name"]` with the
    arguments decoded from the JSON string in `sample["tool_arguments"]`. The shared tool schemas are attached
    under "tools".
    """
    user_turn = {"role": "user", "content": sample["user_content"]}
    tool_call = {
        "type": "function",
        "function": {
            "name": sample["tool_name"],
            # Arguments are stored as a JSON string in the dataset
            "arguments": json.loads(sample["tool_arguments"]),
        },
    }
    assistant_turn = {"role": "assistant", "tool_calls": [tool_call]}
    return {"prompt": [user_turn], "completion": [assistant_turn], "tools": TOOLS}
|
||||
|
||||
|
||||
def main():
    """Fine-tune CohereLabs/tiny-aya-global for tool calling with QLoRA SFT, then save and push the result."""
    model_id = "CohereLabs/tiny-aya-global"
    dataset_name = "bebechien/SimpleToolCalling"
    output_dir = "tiny-aya-global-tool-calling-SFT"

    # Dataset: convert every row to prompt/completion form, then split it in half
    raw_dataset = load_dataset(dataset_name, split="train")
    conversations = raw_dataset.map(create_conversation, remove_columns=raw_dataset.features)
    splits = conversations.train_test_split(test_size=0.5, shuffle=True)

    # Model, loaded in 4-bit NF4 with double quantization and float16 compute
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="sdpa",
        dtype=torch.float16,
        quantization_config=quantization,
    )

    # LoRA adapters on the attention and MLP projection layers
    lora = LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    # The tool-aware chat template is read from a file next to this script
    template_path = Path(__file__).parent / "tiny_aya_chat_template.jinja"
    sft_config = SFTConfig(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        chat_template_path=str(template_path),
        warmup_steps=5,
        learning_rate=2e-4,
        optim="paged_adamw_8bit",
        logging_steps=1,
        report_to="trackio",
        trackio_space_id=output_dir,
        max_length=1024,
        use_liger_kernel=True,
        activation_offloading=True,
        push_to_hub=True,
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=splits["train"],
        peft_config=lora,
    )
    trainer.train()

    # Save model and tokenizer (the tokenizer carries the updated chat template), then publish
    trainer.save_model(output_dir)
    trainer.push_to_hub(dataset_name=dataset_name)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "Pillow>=9.4.0",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
pip install pillow
|
||||
|
||||
# Tested on 8x H100 GPUs
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm.py \
|
||||
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--model_name_or_path llava-hf/llava-1.5-7b-hf \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--output_dir LLaVA-1.5-7B-SFT \
|
||||
--dtype bfloat16
|
||||
|
||||
For LLaVA-NeXT, use:
|
||||
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf
|
||||
|
||||
For meta-llama/Llama-3.2-11B-Vision-Instruct, use:
|
||||
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm.py \
|
||||
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--model_name_or_path HuggingFaceTB/SmolVLM-Instruct \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir SmolVLM-SFT \
|
||||
--dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules down_proj o_proj k_proj q_proj gate_proj up_proj v_proj
|
||||
"""
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTConfig,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    cli_parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = cli_parser.parse_args_and_config()
    # Any max_length given on the command line is overridden with None
    # (presumably to disable truncation for image inputs — TODO confirm)
    training_args.max_length = None

    # ---- Model ----
    # "auto" and None are passed through unchanged; any other value names a torch dtype (e.g. "bfloat16")
    if model_args.dtype in ["auto", None]:
        dtype = model_args.dtype
    else:
        dtype = getattr(torch, model_args.dtype)

    model_kwargs = {
        "revision": model_args.model_revision,
        "attn_implementation": model_args.attn_implementation,
        "dtype": dtype,
    }
    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        # Only set these keys when quantization is requested: an explicit None is not the same as omitting them.
        model_kwargs.update(device_map=get_kbit_device_map(), quantization_config=quantization_config)

    model = AutoModelForImageTextToText.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )

    # ---- Dataset ----
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    train_split = dataset[script_args.dataset_train_split]
    # The test split is only looked up when evaluation is enabled
    eval_split = None
    if training_args.eval_strategy != "no":
        eval_split = dataset[script_args.dataset_test_split]

    # ---- Training ----
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=eval_split,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    # ---- Save, and push to the Hub when requested ----
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
@@ -0,0 +1,189 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl[peft]",
|
||||
# "Pillow>=9.4.0",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Train Gemma 3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm_gemma3.py \
|
||||
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--output_dir Gemma-3-4B-SFT-MMIU \
|
||||
--dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
|
||||
Train Gemma 3 on the FanqingM/MMIU-Benchmark dataset (multi-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm_gemma3.py \
|
||||
--dataset_name FanqingM/MMIU-Benchmark \
|
||||
--dataset_train_split test \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--output_dir Gemma-3-4B-SFT-MMIU \
|
||||
--dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTConfig,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect every image referenced in a conversation as an RGB PIL image.

    Args:
        messages: Chat messages. Each message's ``"content"`` is either a list of content
            elements or a single element (which is wrapped in a list). A dict element
            contributes an image when it has an ``"image"`` key or when its ``"type"`` is
            ``"image"``; in the latter case the element itself is used as the image payload.

    Returns:
        The images in conversation order, each converted to RGB. An image payload may be an
        already-decoded ``PIL.Image.Image`` or a mapping whose ``"bytes"`` entry holds the
        encoded image data. Payloads that are ``None`` are skipped.
    """
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if not isinstance(element, dict):
                continue
            if "image" in element:
                image = element["image"]
            elif element.get("type") == "image":
                image = element
            else:
                continue
            if image is None:
                continue

            # Already-decoded images (e.g. the PIL objects built by `format_data`) cannot be
            # subscripted with "bytes"; use them directly instead of raising TypeError.
            if not isinstance(image, Image.Image):
                image = Image.open(io.BytesIO(image["bytes"]))
            image_inputs.append(image.convert("RGB"))
    return image_inputs
|
||||
|
||||
|
||||
def format_data(samples: dict[str, list]) -> dict[str, list]:
    """Turn a column-oriented batch of rows into chat-style ``messages`` conversations.

    Args:
        samples: Batch mapping column names to per-row lists. The columns read are
            ``"question"``, ``"context"``, ``"output"`` and ``"input_image_path"``; the last
            one holds, for each row, an iterable of image file paths.

    Returns:
        A dict with a single ``"messages"`` key containing one conversation per row. Each
        conversation has a system turn (the context), a user turn (the loaded images followed
        by the question) and an assistant turn (the output).

    Images that cannot be read or decoded are reported with ``print`` and left out of the
    user turn; the row itself is still emitted.
    """
    formatted_samples = {"messages": []}
    for idx, question in enumerate(samples["question"]):
        images = []
        for img_path in samples["input_image_path"][idx]:
            try:
                with open(img_path, "rb") as f:
                    img_bytes = f.read()
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            except Exception as e:
                # Deliberately best-effort: one unreadable image must not abort the whole batch.
                print(f"Error processing image {img_path}: {e}")
                continue
            images.append({"type": "image", "image": image})

        formatted_samples["messages"].append(
            [
                {"role": "system", "content": [{"type": "text", "text": samples["context"][idx]}]},
                {"role": "user", "content": images + [{"type": "text", "text": question}]},
                {"role": "assistant", "content": [{"type": "text", "text": samples["output"][idx]}]},
            ]
        )
    return formatted_samples
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetDict:
    """Download and unpack the dataset repo's zip archives, then format the rows as conversations.

    Args:
        dataset: The loaded dataset whose rows are passed through ``format_data``.
        dataset_name: Hub dataset repository id; its ``.zip`` files are downloaded and
            extracted into folders relative to the current working directory.

    Returns:
        The dataset after mapping ``format_data`` over it in batches.
    """
    repo_files = list_repo_files(dataset_name, repo_type="dataset")

    for zip_filename in repo_files:
        if not zip_filename.endswith(".zip"):
            continue
        zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
        # Strip only the trailing extension; `replace` would also remove ".zip" occurring
        # elsewhere in the path and extract into the wrong folder.
        extract_folder = zip_filename.removesuffix(".zip")
        os.makedirs(extract_folder, exist_ok=True)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_folder)

    return dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
|
||||
|
||||
|
||||
def main():
    """Parse arguments, load the model and dataset, run SFT, then save and optionally push the model."""
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    # NOTE(review): presumably disables sequence-length truncation — confirm against SFTConfig.
    training_args.max_length = None

    # ---------------- Model ----------------
    # "auto" and None are forwarded untouched; any other value names a torch dtype attribute.
    requested_dtype = model_args.dtype
    if requested_dtype in ["auto", None]:
        dtype = requested_dtype
    else:
        dtype = getattr(torch, requested_dtype)

    model_kwargs = {
        "revision": model_args.model_revision,
        "attn_implementation": model_args.attn_implementation,
        "dtype": dtype,
    }

    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
        model_kwargs["device_map"] = get_kbit_device_map()
        model_kwargs["quantization_config"] = quantization_config

    model = AutoModelForImageTextToText.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )

    # ---------------- Dataset ----------------
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    # Only the multi-image benchmark goes through the extra preparation step.
    if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
        dataset = prepare_dataset(dataset, script_args.dataset_name)

    # ---------------- Training ----------------
    train_split = dataset[script_args.dataset_train_split]
    # The evaluation split is only looked up when evaluation is enabled.
    if training_args.eval_strategy != "no":
        eval_split = dataset[script_args.dataset_test_split]
    else:
        eval_split = None

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=eval_split,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
|
||||
|
||||
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,156 @@
|
||||
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl",
|
||||
# "peft",
|
||||
# "trackio",
|
||||
# "kernels",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
# Full training
|
||||
```
|
||||
python trl/scripts/sft.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B \
|
||||
--dataset_name trl-lib/Capybara \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--eos_token '<|im_end|>' \
|
||||
--eval_strategy steps \
|
||||
--eval_steps 100 \
|
||||
--output_dir Qwen2-0.5B-SFT \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
# LoRA
|
||||
```
|
||||
python trl/scripts/sft.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B \
|
||||
--dataset_name trl-lib/Capybara \
|
||||
--learning_rate 2.0e-4 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--eos_token '<|im_end|>' \
|
||||
--eval_strategy steps \
|
||||
--eval_steps 100 \
|
||||
--use_peft \
|
||||
--lora_r 32 \
|
||||
--lora_alpha 16 \
|
||||
--output_dir Qwen2-0.5B-SFT \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args, dataset_args):
    """Run supervised fine-tuning end to end: load model and data, train, save, optionally push.

    Args:
        script_args: Provides ``dataset_name``, ``dataset_config``, ``dataset_streaming``,
            ``dataset_train_split`` and ``dataset_test_split``.
        training_args: Trainer configuration; ``eval_strategy``, ``output_dir`` and
            ``push_to_hub`` are read here and the whole object is handed to ``SFTTrainer``.
        model_args: Provides the model id, revision, ``trust_remote_code``, attention
            implementation and dtype, and feeds the quantization / PEFT config helpers.
        dataset_args: Dataset-mixture configuration; when its ``datasets`` field is set it
            takes precedence over ``script_args.dataset_name``.

    Raises:
        ValueError: If neither ``dataset_args.datasets`` nor ``script_args.dataset_name`` is set.
    """
    # Heavy dependencies are imported at call time rather than at module import.
    from accelerate import logging
    from datasets import load_dataset
    from transformers import AutoConfig, AutoModelForCausalLM
    from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

    from trl import SFTTrainer, get_dataset, get_kbit_device_map, get_peft_config, get_quantization_config

    logger = logging.get_logger(__name__)

    ################
    # Model init kwargs
    ################
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        dtype=model_args.dtype,
    )
    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
        model_kwargs["device_map"] = get_kbit_device_map()
        model_kwargs["quantization_config"] = quantization_config

    # Create model
    # Inspect the config of the same revision (and with the same remote-code policy) that the
    # model itself is loaded with, so the architecture check matches the weights.
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    valid_image_text_architectures = set(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values())

    if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    # Load the dataset: a dataset mixture wins over a single named dataset.
    if dataset_args.datasets:
        if script_args.dataset_name:
            logger.warning(
                "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
                "dataset and `dataset_name` will be ignored."
            )
        dataset = get_dataset(dataset_args)
    elif script_args.dataset_name:
        dataset = load_dataset(
            script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
        )
    else:
        raise ValueError("Either `datasets` or `dataset_name` must be provided.")

    # Initialize the SFT trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
    )

    # Train the model
    trainer.train()

    # Log training complete
    trainer.accelerator.print("✅ Training completed.")

    # Save and push to Hub
    trainer.save_model(training_args.output_dir)
    trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")

    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
|
||||
|
||||
|
||||
def make_parser(subparsers: argparse._SubParsersAction | None = None, prog: str | None = None):
    """Build the argument parser for the SFT script.

    Args:
        subparsers: When given, an ``"sft"`` sub-command is registered on it and that
            sub-parser is returned.
        prog: Program name forwarded to ``TrlParser``; only used when ``subparsers`` is None.

    Returns:
        The parser configured with the script, training, model and dataset-mixture dataclasses.
    """
    from trl import DatasetMixtureConfig, ModelConfig, ScriptArguments, SFTConfig, TrlParser

    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)
    # Standalone use: no parent parser to attach to, so create a fresh one.
    if subparsers is None:
        return TrlParser(dataclass_types, prog=prog)
    return subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
|
||||
|
||||
|
||||
# Script entry point: parse the command line / config file and start training.
if __name__ == "__main__":
    parser = make_parser()
    # Unknown arguments are tolerated instead of aborting the run.
    (
        script_args,
        training_args,
        model_args,
        dataset_args,
    ) = parser.parse_args_and_config(fail_with_unknown_args=False)
    main(script_args, training_args, model_args, dataset_args)
|
||||
Reference in New Issue
Block a user