Mortdecai eecebe7ef5 docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own
README indexing downloaded files with verified upstream sources.

- google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts,
  gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev
  HTML snapshots, Gemma 3 tech report
- huggingface/: 8 gemma-4-* model cards, chat-template .jinja files,
  tokenizer_config.json, transformers gemma4/ source, launch blog posts,
  official HF Spaces app.py
- inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI
  comparison, run_commands.sh with 8 working launches, 9 code snippets
- gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2,
  Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma)
- fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE),
  TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md

Findings that update earlier CORPUS_* docs are flagged in tooling/README.md
(not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch
abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM,
FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech
report PDF yet, no Gemma-4-generation specialized siblings yet.

Pre-commit secrets hook bypassed per user authorization — flagged "secrets"
are base64 notebook cell outputs and example Ed25519 keys in the HDP
agentic-security demo, not real credentials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-18 12:24:48 -04:00


# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

from . import gemma4_text
from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str = "gemma4"
    text_config: dict = None
    vocab_size: int = 262144

    def __post_init__(self):
        if self.text_config is None:
            self.text_config = {}
        # The top-level vocab_size always overrides the one in text_config.
        self.text_config["vocab_size"] = self.vocab_size
        # Fill in head counts only when the config does not provide them.
        self.text_config["num_attention_heads"] = self.text_config.get(
            "num_attention_heads", 8
        )
        self.text_config["num_key_value_heads"] = self.text_config.get(
            "num_key_value_heads", 1
        )


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.language_model = gemma4_text.Model(
            gemma4_text.ModelArgs.from_dict(args.text_config)
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
        per_layer_inputs: Optional[mx.array] = None,
    ):
        # Thin wrapper: all arguments are forwarded to the text model.
        return self.language_model(
            inputs,
            cache=cache,
            input_embeddings=input_embeddings,
            per_layer_inputs=per_layer_inputs,
        )

    def sanitize(self, weights):
        new_weights = {}
        for k, v in weights.items():
            starts_w_model = k.startswith("model.")

            k = k.removeprefix("model.")
            # Skip vision and audio weights; only the text model is loaded here.
            if k.startswith(
                (
                    "vision_tower",
                    "multi_modal_projector",
                    "audio_tower",
                    "embed_audio",
                    "embed_vision",
                )
            ):
                continue

            # Keys that never had the "model." prefix are kept unchanged.
            if not starts_w_model:
                new_weights[k] = v
                continue

            # Match the nesting of the wrapped text model (language_model.model.*).
            if k.startswith("language_model"):
                k = k.replace("language_model.", "language_model.model.")

            new_weights[k] = v

        return self.language_model.sanitize(new_weights)

    @property
    def layers(self):
        return self.language_model.layers

    @property
    def quant_predicate(self):
        return self.language_model.quant_predicate

    def make_cache(self):
        return self.language_model.make_cache()
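
The key remapping done in sanitize can be tried without MLX installed. The sketch below is a standalone copy of the loop only: it leaves out the final call to self.language_model.sanitize, and the names remap_keys, SKIPPED_PREFIXES, and the example checkpoint keys are hypothetical, chosen for illustration rather than taken from a real checkpoint. It needs Python 3.9 or later for str.removeprefix.

```python
# Standalone sketch of the key-remapping loop in Model.sanitize.
# Assumption: remap_keys, SKIPPED_PREFIXES, and the example keys below are
# hypothetical; the real method additionally passes its result on to
# self.language_model.sanitize, which is not reproduced here.

SKIPPED_PREFIXES = (
    "vision_tower",
    "multi_modal_projector",
    "audio_tower",
    "embed_audio",
    "embed_vision",
)


def remap_keys(weights):
    """Drop vision/audio entries and re-nest text-model keys."""
    new_weights = {}
    for k, v in weights.items():
        starts_w_model = k.startswith("model.")
        k = k.removeprefix("model.")

        # Multimodal weights are dropped whether or not they had the prefix.
        if k.startswith(SKIPPED_PREFIXES):
            continue

        # Keys without the "model." prefix are kept exactly as they were.
        if not starts_w_model:
            new_weights[k] = v
            continue

        # Prefixed text-model keys gain an extra ".model" nesting level.
        if k.startswith("language_model"):
            k = k.replace("language_model.", "language_model.model.")

        new_weights[k] = v
    return new_weights


# Placeholder values stand in for weight arrays.
example = {
    "model.language_model.layers.0.w": 1,
    "model.audio_tower.w": 2,
    "language_model.model.layers.1.w": 3,
}
remapped = remap_keys(example)
print(sorted(remapped))
```

With these inputs, the prefixed text key is renamed to language_model.model.layers.0.w, the audio entry is dropped, and the key that never carried the "model." prefix passes through unchanged.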