Files
Mortdecai eecebe7ef5 docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own
README indexing downloaded files with verified upstream sources.

- google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts,
  gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev
  HTML snapshots, Gemma 3 tech report
- huggingface/: 8 gemma-4-* model cards, chat-template .jinja files,
  tokenizer_config.json, transformers gemma4/ source, launch blog posts,
  official HF Spaces app.py
- inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI
  comparison, run_commands.sh with 8 working launches, 9 code snippets
- gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2,
  Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma)
- fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE),
  TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md

Findings that update earlier CORPUS_* docs are flagged in tooling/README.md
(not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch
abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM,
FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech
report PDF yet, no Gemma-4-generation specialized siblings yet.

Pre-commit secrets hook bypassed per user authorization — flagged "secrets"
are base64 notebook cell outputs and example Ed25519 keys in the HDP
agentic-security demo, not real credentials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-18 12:24:48 -04:00

268 lines
9.7 KiB
Python

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import os
import random
import socket
import sys
from typing import List, Union
import numpy as np
import torch
import torch.multiprocessing
from gemma.config import GemmaConfig, get_model_config
from gemma.model_xla import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import gemma.xla_model_parallel as xla_model_parallel
# Backend selection. USE_CUDA is read from the environment once at import time.
# The raw environment value is a string, so values such as "0" or "false" would
# be truthy if used directly; parse it into a real boolean instead. Unset,
# empty, "0", "false", "no" and "off" (case-insensitive) all mean "use XLA".
USE_CUDA = os.environ.get('USE_CUDA', '').strip().lower() not in (
    '', '0', 'false', 'no', 'off')
if not USE_CUDA:
    # torch_xla is only imported on the XLA path so CUDA runs do not need it.
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
else:
    # Choose an available port: binding to port 0 lets the OS pick a free one,
    # which is then published as MASTER_PORT for torch.distributed.
    with contextlib.closing(socket.socket(socket.AF_INET,
                                          socket.SOCK_STREAM)) as s:
        s.bind(('', 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        MASTER_PORT = str(s.getsockname()[1])
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Temporarily make `dtype` the default torch floating-point dtype.

    The default that was active on entry is restored on exit, including when
    the body of the `with` statement raises, so a failure while building or
    loading a model cannot leave the process-wide default dtype changed.

    Args:
        dtype: The dtype to install as the torch default inside the block.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore whatever was configured before, not a hard-coded float32.
        torch.set_default_dtype(previous_dtype)
def generate(
    i: int,
    model_config: GemmaConfig,
    ckpt_path: str,
    prompts: List[str],
    output_lens: List[int],
    temperatures: Union[List[float], None],
    top_ps: List[float],
    top_ks: List[int],
    seed: int
):
    """Worker entry point: sample completions for `prompts` and print them.

    Seeds the RNGs, selects the device (CUDA or XLA depending on USE_CUDA),
    loads the model from `ckpt_path`, runs token-by-token generation with
    per-layer KV caches, then decodes and prints one result per prompt.
    Workers whose model-parallel rank is greater than zero have their stdout
    redirected to os.devnull, so only rank 0 prints.

    Args:
        i: Process index; the spawn helpers in `main` pass it as the first
            positional argument. It is not used inside this function.
        model_config: Model configuration (dtype, tokenizer path, head counts,
            layer count, maximum position embeddings are read from it).
        ckpt_path: Path handed to `model.load_weights`.
        prompts: Prompt strings; one batch row per prompt. Must be non-empty.
        output_lens: Number of tokens to generate for each prompt.
        temperatures: Per-prompt sampling temperatures, or None; when None (or
            empty) the model is called with `temperatures=None`.
        top_ps: Per-prompt top-p values.
        top_ks: Per-prompt top-k values.
        seed: Seed applied to `random`, NumPy, torch and (on XLA) the device RNG.

    Raises:
        ValueError: If `prompts` is empty.
        AssertionError: If the per-prompt lists do not match the batch size,
            the longest sequence exceeds `max_position_embeddings`, or the
            KV-head count and world size are not divisible as required.
    """
    # Fail fast with a clear message instead of an opaque `min()` error after
    # the (expensive) model load.
    if not prompts:
        raise ValueError("prompts must not be empty")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if USE_CUDA:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = MASTER_PORT
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                "nccl",
                rank=int(os.environ.get("RANK", 0)),
                world_size=int(os.environ.get("WORLD_SIZE", 1)))
            xla_model_parallel.set_g_group()
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(local_rank)
    else:
        device = xm.xla_device()
        xm.set_rng_state(seed, device)

    rank = xla_model_parallel.get_model_parallel_rank()
    world_size = xla_model_parallel.get_model_parallel_world_size()
    if rank > 0:
        # Silence non-zero ranks for the rest of the worker's lifetime; the
        # handle is intentionally left open until the process exits.
        sys.stdout = open(os.devnull, 'w')

    # build, load and compile model.
    with _set_default_tensor_type(model_config.get_dtype()):
        model = GemmaForCausalLM(model_config, world_size, rank, device)
        model.load_weights(ckpt_path)
        model = model.to(device).eval()

    # create tokenizer.
    tokenizer = Tokenizer(model_config.tokenizer)

    prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
    min_prompt_len = min(len(p) for p in prompt_tokens)

    batch_size = len(prompts)
    if temperatures is not None:
        assert batch_size == len(temperatures)
    assert batch_size == len(top_ps)
    assert batch_size == len(top_ks)
    # Without this check `zip` below silently truncates and a short
    # `output_lens` only surfaces as an IndexError during detokenization.
    assert batch_size == len(output_lens)
    max_seq_len = max(
        len(p) + o for p, o in zip(prompt_tokens, output_lens))
    assert max_seq_len <= model_config.max_position_embeddings

    # Number of KV heads held by this worker: at least one, otherwise an even
    # share of the configured KV heads.
    if model_config.num_key_value_heads < world_size:
        assert world_size % model_config.num_key_value_heads == 0
        n_local_heads = 1
    else:
        assert model_config.num_key_value_heads % world_size == 0
        n_local_heads = model_config.num_key_value_heads // world_size

    # build KV caches: one zero-initialised (key, value) pair per hidden layer.
    kv_caches = []
    for _ in range(model_config.num_hidden_layers):
        k_cache = torch.zeros(
            size=(batch_size, max_seq_len, n_local_heads,
                  model_config.head_dim),
            dtype=model_config.get_dtype(),
            device=device,
        )
        v_cache = torch.zeros(
            size=(batch_size, max_seq_len, n_local_heads,
                  model_config.head_dim),
            dtype=model_config.get_dtype(),
            device=device,
        )
        kv_caches.append((k_cache, v_cache))

    # prepare inputs
    # token_ids_tensor holds the full sequence per row (prompt, then generated
    # tokens), padded with pad_id; input_token_ids_tensor holds only the first
    # min_prompt_len prompt tokens of every row for the initial forward pass.
    token_ids_tensor = torch.full((batch_size, max_seq_len),
                                  tokenizer.pad_id,
                                  dtype=torch.int64)
    input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                        tokenizer.pad_id,
                                        dtype=torch.int64)
    for row, p in enumerate(prompt_tokens):
        token_ids_tensor[row, :len(p)] = torch.tensor(p)
        input_token_ids_tensor[row, :min_prompt_len] = torch.tensor(
            p[:min_prompt_len])
    token_ids_tensor = token_ids_tensor.to(device)
    # True wherever a position is already filled by a prompt token.
    prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id
    input_token_ids_tensor = input_token_ids_tensor.to(device)
    input_positions_tensor = torch.arange(0, min_prompt_len,
                                          dtype=torch.int64).to(device)
    # Upper-triangular mask filled with a very large negative value above the
    # diagonal (presumably consumed as an additive causal attention mask by
    # the model - confirm against the model implementation).
    mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
                             -2.3819763e38).to(torch.float)
    mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
    curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
    output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
    temperatures_tensor = None if not temperatures else torch.FloatTensor(
        temperatures).to(device)
    top_ps_tensor = torch.FloatTensor(top_ps).to(device)
    top_ks_tensor = torch.LongTensor(top_ks).to(device)
    output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
    if not USE_CUDA:
        xm.mark_step()

    # Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output.
    for _ in range(max_seq_len - min_prompt_len):
        next_token_ids, _ = model(
            input_token_ids=input_token_ids_tensor,
            input_positions=input_positions_tensor,
            kv_write_indices=None,
            kv_caches=kv_caches,
            mask=curr_mask_tensor,
            output_positions=output_positions_tensor,
            temperatures=temperatures_tensor,
            top_ps=top_ps_tensor,
            top_ks=top_ks_tensor,
        )

        # Where the current position still belongs to a (longer) prompt, keep
        # the prompt token; otherwise take the sampled token.
        curr_prompt_mask = prompt_mask_tensor.index_select(
            1, output_index).squeeze(dim=1)
        curr_token_ids = token_ids_tensor.index_select(
            1, output_index).squeeze(dim=1)
        output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                       next_token_ids).unsqueeze(dim=1)
        token_ids_tensor.index_copy_(1, output_index, output_token_ids)

        # Next step feeds only the token just written, at its position.
        input_token_ids_tensor = output_token_ids
        input_positions_tensor = output_index.unsqueeze(dim=-1)
        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
        output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
        output_index = output_index + 1
        if not USE_CUDA:
            xm.mark_step()

    # Detokenization.
    token_ids = token_ids_tensor.tolist()
    results = []
    for row, tokens in enumerate(token_ids):
        # Keep only the generated span for this row, cut at the first EOS.
        start = len(prompt_tokens[row])
        trimmed_output = tokens[start:start + output_lens[row]]
        if tokenizer.eos_id in trimmed_output:
            eos_index = trimmed_output.index(tokenizer.eos_id)
            trimmed_output = trimmed_output[:eos_index]
        results.append(tokenizer.decode(trimmed_output))

    for prompt, result in zip(prompts, results):
        print('======================================')
        print(f'PROMPT: {prompt}')
        print(f'RESULT: {result}')
        print('======================================')
def main(args):
    """Build the model config from CLI arguments and spawn `generate` workers.

    The single prompt from `args.prompt` forms a batch of one. Sampling
    settings are fixed here: temperature 0.95, top-p 1.0, top-k 100. On the
    CUDA path the process group is initialised (if it is not already) before
    workers are spawned with `torch.multiprocessing.spawn`; otherwise
    `xmp.spawn` is used.

    Args:
        args: Parsed command-line namespace providing `variant`, `quant`,
            `prompt`, `output_len`, `ckpt` and `seed`.
    """
    model_config = get_model_config(args.variant)
    model_config.quant = args.quant

    prompts = [args.prompt]
    batch = len(prompts)

    # Positional arguments for `generate`; the spawn helper prepends the
    # process index as the first argument.
    worker_args = (
        model_config,
        args.ckpt,
        prompts,
        [args.output_len] * batch,
        [0.95] * batch,
        [1.0] * batch,
        [100] * batch,
        args.seed,
    )

    if not USE_CUDA:
        xmp.spawn(generate, args=worker_args)
        return

    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = MASTER_PORT
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            "nccl",
            rank=int(os.environ.get("RANK", 0)),
            world_size=int(os.environ.get("WORLD_SIZE", 1)))
        xla_model_parallel.set_g_group()
    torch.multiprocessing.spawn(generate, args=worker_args)
if __name__ == '__main__':
    # Command-line interface: only --ckpt is mandatory, every other flag has
    # a default.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument(
        "--variant",
        type=str,
        default="2b",
        choices=["2b", "2b-v2", "7b", "9b", "27b"],
    )
    parser.add_argument("--output_len", type=int, default=4)
    parser.add_argument("--seed", type=int, default=12345)
    # Boolean switch; copied onto model_config.quant by main().
    parser.add_argument("--quant", action='store_true')
    parser.add_argument(
        "--prompt",
        type=str,
        default="The meaning of life is",
    )

    args = parser.parse_args()
    main(args)