gemma4-research/tooling/inference-frameworks/snippets/mlx_vlm_gemma4_head_60.py
Mortdecai eecebe7ef5 docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own
README indexing downloaded files with verified upstream sources.

- google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts,
  gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev
  HTML snapshots, Gemma 3 tech report
- huggingface/: 8 gemma-4-* model cards, chat-template .jinja files,
  tokenizer_config.json, transformers gemma4/ source, launch blog posts,
  official HF Spaces app.py
- inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI
  comparison, run_commands.sh with 8 working launches, 9 code snippets
- gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2,
  Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma)
- fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE),
  TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md

Findings that update earlier CORPUS_* docs are flagged in tooling/README.md
(not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch
abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM,
FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech
report PDF yet, no Gemma-4-generation specialized siblings yet.

Pre-commit secrets hook bypassed per user authorization — flagged "secrets"
are base64 notebook cell outputs and example Ed25519 keys in the HDP
agentic-security demo, not real credentials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-18 12:24:48 -04:00

61 lines
2.0 KiB
Python

from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from ..base import InputEmbeddingsFeatures
from .audio import AudioEncoder
from .config import ModelConfig
from .language import LanguageModel, RMSNormNoScale
from .vision import VisionModel
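

def masked_scatter(input_tensor, mask, source):
    """Fill the positions of input_tensor where mask is true with consecutive
    elements of source (in flattened order); other positions are unchanged."""
    mask_flat = mask.flatten().astype(mx.int32)
    # The k-th true position in the mask maps to source index k - 1.
    indices = mx.cumsum(mask_flat) - 1
    # The modulo keeps every index in range, wrapping if source is shorter.
    aligned = source.flatten()[indices % source.size]
    return mx.where(mask_flat, aligned, input_tensor.flatten()).reshape(
        input_tensor.shape
    )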


class MultimodalEmbedder(nn.Module):
    """Projects soft tokens from vision/audio into language model space."""

    def __init__(self, embedding_dim: int, text_hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.embedding_projection = nn.Linear(
            embedding_dim, text_hidden_size, bias=False
        )
        self.embedding_pre_projection_norm = RMSNormNoScale(embedding_dim, eps=eps)

    def __call__(self, inputs_embeds: mx.array) -> mx.array:
        normed = self.embedding_pre_projection_norm(inputs_embeds)
        return self.embedding_projection(normed)


class Model(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.model_type = config.model_type
        self.config = config

        # Text
        self.language_model = LanguageModel(config.text_config)
        self.vocab_size = config.text_config.vocab_size

        # Vision
        self.vision_tower = VisionModel(config.vision_config)
        self.embed_vision = MultimodalEmbedder(
            embedding_dim=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            eps=config.vision_config.rms_norm_eps,
        )

        # Audio
        if config.audio_config is not None:
            self.audio_tower = AudioEncoder(config.audio_config)
            audio_output_dim = (
                config.audio_config.output_proj_dims or config.audio_config.hidden_size
            )
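
The `masked_scatter` helper near the top of the snippet is the least obvious part of the listing, so here is a small sketch of what it appears to compute. This is an illustration under stated assumptions, not the library's code: it uses plain Python lists in place of `mx.array` so it runs without MLX, it works on already-flattened data, and the names `masked_scatter_lists`, `text_embeds`, `is_image_slot`, and `image_feats` are hypothetical.

```python
from itertools import accumulate


def masked_scatter_lists(input_flat, mask_flat, source_flat):
    """Pure-Python stand-in (hypothetical) for masked_scatter on flat data."""
    # Running count of true entries, minus one: the k-th true position
    # in the mask gets source index k - 1.
    indices = [count - 1 for count in accumulate(int(m) for m in mask_flat)]
    n = len(source_flat)
    # Modulo keeps every index in range, wrapping if the source is shorter
    # than the number of true positions.
    aligned = [source_flat[i % n] for i in indices]
    # Take the aligned source value where the mask is true, else keep the input.
    return [a if m else x for m, a, x in zip(mask_flat, aligned, input_flat)]


# Toy example: five token slots, three of which are placeholders for image features.
text_embeds = [0.0, 0.0, 0.0, 0.0, 0.0]
is_image_slot = [False, True, True, False, True]
image_feats = [10.0, 20.0, 30.0]

merged = masked_scatter_lists(text_embeds, is_image_slot, image_feats)
print(merged)  # [0.0, 10.0, 20.0, 0.0, 30.0]

# A source shorter than the number of true positions wraps around.
wrapped = masked_scatter_lists([1.0, 1.0, 1.0], [True, True, True], [7.0, 8.0])
print(wrapped)  # [7.0, 8.0, 7.0]
```

The cumulative sum is what lets the scatter run as a gather followed by a select: every position looks up a candidate source value, and the mask then decides whether that candidate or the original input is kept.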