eecebe7ef5
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
93 lines
2.4 KiB
Python
93 lines
2.4 KiB
Python
# Copyright © 2025 Apple Inc.
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from mlx.utils import tree_flatten, tree_unflatten
|
|
|
|
from . import gemma4_text
|
|
from .base import BaseModelArgs
|
|
|
|
|
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Top-level configuration for the Gemma 4 wrapper model.

    The language model is built from ``text_config`` alone (see ``Model``);
    ``vocab_size`` is injected into that sub-config, and defaults are filled
    in for the attention-head counts when the sub-config omits them.
    """

    model_type: str = "gemma4"
    # Sub-config handed to gemma4_text.ModelArgs.from_dict; None means "empty".
    text_config: Optional[dict] = None
    # Always copied into text_config, overriding any value already there.
    vocab_size: int = 262144

    def __post_init__(self):
        # Build a private copy so constructing ModelArgs never mutates the
        # caller's dict (e.g. the "text_config" entry of a loaded config).
        text_config = dict(self.text_config or {})
        text_config["vocab_size"] = self.vocab_size
        # Fallbacks used only when the sub-config does not specify them.
        text_config.setdefault("num_attention_heads", 8)
        text_config.setdefault("num_key_value_heads", 1)
        self.text_config = text_config
|
|
|
|
|
|
class Model(nn.Module):
    """Wrapper that builds and runs only the Gemma 4 text model.

    The wrapped ``gemma4_text.Model`` is constructed from
    ``args.text_config``. Forward passes, cache creation, ``layers`` and the
    quantization predicate are all delegated to it. Vision, audio and
    multimodal-projector weights found in a checkpoint are discarded by
    :meth:`sanitize`.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.language_model = gemma4_text.Model(
            gemma4_text.ModelArgs.from_dict(args.text_config)
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
        per_layer_inputs: Optional[mx.array] = None,
    ):
        """Run the text model, forwarding every argument unchanged.

        Returns whatever ``gemma4_text.Model.__call__`` returns.
        """
        return self.language_model(
            inputs,
            cache=cache,
            input_embeddings=input_embeddings,
            per_layer_inputs=per_layer_inputs,
        )

    def sanitize(self, weights):
        """Map checkpoint weight names onto this module's parameter tree.

        Args:
            weights: mapping of parameter name -> array.

        Returns:
            The result of ``self.language_model.sanitize`` applied to the
            renamed mapping.

        Renaming rules:
          * A leading ``model.`` is stripped.
          * Keys then starting with a vision/audio/projector prefix are dropped.
          * Keys that did NOT carry ``model.`` are otherwise kept verbatim.
          * Keys that did carry ``model.`` and now start with
            ``language_model.`` have that leading prefix re-rooted to
            ``language_model.model.``.
        """
        skip_prefixes = (
            "vision_tower",
            "multi_modal_projector",
            "audio_tower",
            "embed_audio",
            "embed_vision",
        )
        lm_prefix = "language_model."

        new_weights = {}
        for k, v in weights.items():
            starts_w_model = k.startswith("model.")
            k = k.removeprefix("model.")
            if k.startswith(skip_prefixes):
                continue
            if starts_w_model and k.startswith(lm_prefix):
                # Rewrite only the leading prefix; str.replace would also
                # rewrite any later "language_model." occurring in the key.
                k = "language_model.model." + k[len(lm_prefix):]
            new_weights[k] = v

        return self.language_model.sanitize(new_weights)

    @property
    def layers(self):
        """The wrapped text model's ``layers``."""
        return self.language_model.layers

    @property
    def quant_predicate(self):
        """The wrapped text model's ``quant_predicate``."""
        return self.language_model.quant_predicate

    def make_cache(self):
        """Create the inference cache via the wrapped text model."""
        return self.language_model.make_cache()
|