eecebe7ef5
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
from typing import Optional
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
from ..base import InputEmbeddingsFeatures
|
|
from .audio import AudioEncoder
|
|
from .config import ModelConfig
|
|
from .language import LanguageModel, RMSNormNoScale
|
|
from .vision import VisionModel
|
|
|
|
|
|
def masked_scatter(input_tensor, mask, source):
    """Fill the masked positions of ``input_tensor`` with values from ``source``.

    All arrays are processed in flattened order. The k-th (zero-based) set
    position of ``mask`` receives element ``k % source.size`` of the
    flattened ``source``; every other position keeps its value from
    ``input_tensor``. The result is reshaped to ``input_tensor.shape``.

    NOTE(review): assumes ``mask`` has as many elements as ``input_tensor``
    and that ``source`` is non-empty (the modulo uses ``source.size``) —
    confirm against callers.
    """
    target_shape = input_tensor.shape
    selector = mask.flatten().astype(mx.int32)

    # A running count of set positions, shifted to be zero-based, tells each
    # masked slot which source element it should take. The modulo keeps the
    # lookup in range when there are more set positions than source elements.
    source_slots = (mx.cumsum(selector) - 1) % source.size
    candidates = source.flatten()[source_slots]

    merged = mx.where(selector, candidates, input_tensor.flatten())
    return merged.reshape(target_shape)
|
|
|
|
|
|
class MultimodalEmbedder(nn.Module):
    """Maps vision/audio soft-token embeddings into the language model's space.

    Inputs are passed through ``RMSNormNoScale`` and then through a
    bias-free linear layer from ``embedding_dim`` to ``text_hidden_size``.
    """

    def __init__(self, embedding_dim: int, text_hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # NOTE(review): the attribute names below presumably have to match
        # checkpoint weight keys — keep them unchanged.
        projection = nn.Linear(embedding_dim, text_hidden_size, bias=False)
        self.embedding_projection = projection
        pre_norm = RMSNormNoScale(embedding_dim, eps=eps)
        self.embedding_pre_projection_norm = pre_norm

    def __call__(self, inputs_embeds: mx.array) -> mx.array:
        """Normalize ``inputs_embeds``, then project to the text hidden size."""
        return self.embedding_projection(
            self.embedding_pre_projection_norm(inputs_embeds)
        )
|
|
|
|
|
|
class Model(nn.Module):
|
|
def __init__(self, config: ModelConfig):
|
|
super().__init__()
|
|
self.model_type = config.model_type
|
|
self.config = config
|
|
|
|
# Text
|
|
self.language_model = LanguageModel(config.text_config)
|
|
self.vocab_size = config.text_config.vocab_size
|
|
|
|
# Vision
|
|
self.vision_tower = VisionModel(config.vision_config)
|
|
self.embed_vision = MultimodalEmbedder(
|
|
embedding_dim=config.vision_config.hidden_size,
|
|
text_hidden_size=config.text_config.hidden_size,
|
|
eps=config.vision_config.rms_norm_eps,
|
|
)
|
|
|
|
# Audio
|
|
if config.audio_config is not None:
|
|
self.audio_tower = AudioEncoder(config.audio_config)
|
|
audio_output_dim = (
|
|
config.audio_config.output_proj_dims or config.audio_config.hidden_size
|
|
)
|