Files
Mortdecai eecebe7ef5 docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own
README indexing downloaded files with verified upstream sources.

- google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts,
  gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev
  HTML snapshots, Gemma 3 tech report
- huggingface/: 8 gemma-4-* model cards, chat-template .jinja files,
  tokenizer_config.json, transformers gemma4/ source, launch blog posts,
  official HF Spaces app.py
- inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI
  comparison, run_commands.sh with 8 working launches, 9 code snippets
- gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2,
  Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma)
- fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE),
  TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md

Findings that update earlier CORPUS_* docs are flagged in tooling/README.md
(not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch
abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM,
FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech
report PDF yet, no Gemma-4-generation specialized siblings yet.

Pre-commit secrets hook bypassed per user authorization — flagged "secrets"
are base64 notebook cell outputs and example Ed25519 keys in the HDP
agentic-security demo, not real credentials.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-18 12:24:48 -04:00

564 lines
21 KiB
Python

# === HEADER (license + imports) ===
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
import torch
from torch import nn
from torch.nn import functional as F
from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...integrations import use_kernelized_func
from ...masking_utils import (
create_bidirectional_mask,
create_causal_mask,
create_masks_for_generate,
create_sliding_window_causal_mask,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_accelerate_available,
logging,
torch_compilable_check,
)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..auto.modeling_auto import AutoModel
from ..gemma3.modeling_gemma3 import (
Gemma3Attention,
Gemma3DecoderLayer,
Gemma3ForCausalLM,
Gemma3MLP,
Gemma3RotaryEmbedding,
Gemma3TextModel,
Gemma3TextScaledWordEmbedding,
)
from ..gemma3n.modeling_gemma3n import (
Gemma3nCausalLMOutputWithPast,
Gemma3nForConditionalGeneration,
Gemma3nModel,
Gemma3nModelOutputWithPast,
Gemma3nMultimodalEmbedder,
Gemma3nPreTrainedModel,
Gemma3nRMSNorm,
apply_rotary_pos_emb,
eager_attention_forward,
)
from ..llama.modeling_llama import LlamaRotaryEmbedding
from ..mixtral.modeling_mixtral import MixtralExperts
from ..moonshine_streaming.modeling_moonshine_streaming import sliding_window_mask_function
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
if is_accelerate_available():
pass
# === CLASS/FUNCTION OUTLINE (signatures + short body) ===
# Gemma 4 name for the Gemma3n base-model output container; nothing is added or overridden here.
class Gemma4ModelOutputWithPast(Gemma3nModelOutputWithPast):
pass
# Gemma 4 name for the Gemma3n causal-LM output container; nothing is added or overridden here.
class Gemma4CausalLMOutputWithPast(Gemma3nCausalLMOutputWithPast):
pass
# Output container for the audio encoder, built on `BaseModelOutputWithPooling`.
# The docstring documents an `attention_mask` entry (True = valid frame, False = padding);
# presumably the matching field is declared further down in the class body — confirm.
@dataclass
@auto_docstring
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
r"""
attention_mask (`torch.BoolTensor`, *optional*):
A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
...
# Bias-free linear layer that can optionally carry clipping bounds, shared by the vision and audio towers.
class Gemma4ClippableLinear(nn.Module):
def __init__(
self,
config: Gemma4VisionConfig | Gemma4AudioConfig,
in_features: int,
out_features: int,
) -> None:
super().__init__()
# Clipping is opt-in per tower config; the bound buffers below exist only when it is enabled.
self.use_clipped_linears = config.use_clipped_linears
self.linear = nn.Linear(in_features, out_features, bias=False)
if self.use_clipped_linears:
# Input bounds start at -inf / +inf, so clamping to them would change nothing until real
# values replace them (presumably loaded from a checkpoint — TODO confirm).
self.register_buffer("input_min", torch.tensor(-float("inf")))
self.register_buffer("input_max", torch.tensor(float("inf")))
...
# Gemma 4 name for the Gemma3n RMSNorm; nothing is added or overridden here.
class Gemma4RMSNorm(Gemma3nRMSNorm):
pass
class Gemma4AudioRelPositionalEncoding(nn.Module):
"""Sinusoidal relative positional encoding for the audio encoder.
Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
concatenated [sin..., cos...] layout matching the original Gemma4 convention.
"""
# Class-level annotation only (no value is assigned here); presumably the tensor is
# registered as a buffer inside __init__ — confirm against the constructor body.
inv_timescales: torch.Tensor
def __init__(self, config: Gemma4AudioConfig):
...
class Gemma4AudioAttention(nn.Module):
"""Chunked local attention with relative position bias"""
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_logits_soft_cap = config.attention_logit_cap
# Floor division: assumes hidden_size is a multiple of num_attention_heads — TODO confirm the config guarantees it.
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
# Query scale is 1/sqrt(head_dim) further divided by ln(2).
self.q_scale = (self.head_dim**-0.5) / math.log(2)
# Key scale is ln(1 + e) / ln(2); ln(1 + e) equals softplus(1).
self.k_scale = math.log(1 + math.e) / math.log(2)
...
# One subsampling stage: a strided 2D convolution plus a LayerNorm and a ReLU module.
# How the three modules are chained is defined in the part of the class not shown here.
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
def __init__(self, in_channels, out_channels, norm_eps):
super().__init__()
# 3x3 kernel, stride 2, padding 1: each spatial dimension of size n becomes ceil(n / 2).
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(2, 2),
padding=1,
bias=False,
)
# LayerNorm over a trailing dimension of size `out_channels`, with a learnable scale but no bias term.
# NOTE(review): the `bias` keyword needs a PyTorch release whose nn.LayerNorm accepts it.
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
self.act = nn.ReLU()
...
# Two stacked subsampling stages; channel widths come from `config.subsampling_conv_channels`.
class Gemma4AudioSubSampleConvProjection(nn.Module):
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
# First stage reads a single input channel.
self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
in_channels=1,
out_channels=config.subsampling_conv_channels[0],
norm_eps=config.rms_norm_eps,
)
# Second stage consumes the first stage's output channels.
self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
in_channels=config.subsampling_conv_channels[0],
out_channels=config.subsampling_conv_channels[1],
norm_eps=config.rms_norm_eps,
)
# NOTE(review): two stride-2 stages shrink a spatial axis by roughly 4x, so the `// 4` looks like a
# post-conv feature size — but it is applied to subsampling_conv_channels[0], a channel count.
# Confirm that value is meant to equal the number of input feature bins.
proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
...
# Feed-forward sub-block of the audio encoder: two clippable linears with norms before and after.
class Gemma4AudioFeedForward(nn.Module):
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
self.config = config
# Expand to 4x the hidden size, then project back down.
self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
self.act_fn = ACT2FN[config.hidden_act]
self.gradient_clipping = config.gradient_clipping
# Stored from `config.residual_weight`; presumably scales this block's output before the
# residual add — confirm in forward().
self.post_layer_scale = config.residual_weight
...
# nn.Conv1d subclass used for the audio encoder's depthwise convolution.
class Gemma4AudioCausalConv1d(nn.Conv1d):
# NOTE(review): the constructor below is commented out, so nn.Conv1d.__init__ is what actually runs.
# As written it could not simply be re-enabled: `cache_key` is assigned on the last line while its
# parameter is itself commented out. Remove this dead code or restore it consistently.
# def __init__(
# self,
# in_channels: int,
# out_channels: int,
# kernel_size: int,
# # cache_key: str,
# stride: int = 1,
# dilation: int = 1,
# bias: bool = True,
# ):
# super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
# self.cache_key = cache_key
...
# Lightweight convolution sub-block of the audio encoder: linear in, depthwise conv, linear out.
class Gemma4AudioLightConv1d(nn.Module):
def __init__(self, config: Gemma4AudioConfig):
super().__init__()
self.config = config
# Input projection doubles the width (hidden_size -> 2 * hidden_size); presumably the two halves
# feed a gating step in forward() — TODO confirm.
self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
# groups == in_channels == out_channels: every channel is convolved independently (depthwise).
self.depthwise_conv1d = Gemma4AudioCausalConv1d(
in_channels=config.hidden_size,
out_channels=config.hidden_size,
kernel_size=config.conv_kernel_size,
groups=config.hidden_size,
bias=False,
)
...
# One audio encoder layer: two feed-forward blocks, self-attention, a light convolution block,
# and three RMSNorms. The order in which they are applied is defined in forward() (not shown here).
class Gemma4AudioLayer(nn.Module):
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
super().__init__()
self.config = config
self.feed_forward1 = Gemma4AudioFeedForward(config)
self.feed_forward2 = Gemma4AudioFeedForward(config)
# Only the attention module receives the layer index.
self.self_attn = Gemma4AudioAttention(config, layer_idx)
self.lconv1d = Gemma4AudioLightConv1d(config)
self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
self.norm_out = Gemma4RMSNorm(config.hidden_size)
...
# Turns flattened image patches into hidden states and owns the learned position-embedding table.
class Gemma4VisionPatchEmbedder(nn.Module):
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.patch_size = config.patch_size
self.position_embedding_size = config.position_embedding_size
# Input width is 3 * patch_size**2, i.e. one flattened patch (presumably 3 colour channels — confirm).
self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
# Leading dimension of 2: presumably one table per spatial axis — confirm. Initialised to ones.
self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
"""Prepare patch positions map for matmul with position embedding table."""
# Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
...
class Gemma4VisionPooler(nn.Module):
"""Scaling and optional spatial pooling for vision encodings"""
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.hidden_size = config.hidden_size
# sqrt(hidden_size), computed once here; presumably the scale factor the class docstring
# refers to — confirm in forward().
self.root_hidden_size = self.hidden_size**0.5
# Returns a pair of tensors per the annotation; what each one holds is described in the docstring.
def _avg_pool_by_positions(
self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
2D spatial pooling according to patch positions.
Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
...
class Gemma4VisionMLP(Gemma3MLP):
    """MLP for the vision encoder whose three projections are clippable linears.

    The `Gemma3MLP` constructor runs first; its `gate_proj`, `up_proj` and
    `down_proj` are then replaced by `Gemma4ClippableLinear` modules.

    Args:
        config: Vision tower configuration, forwarded to the parent constructor
            and to every clippable linear.
    """

    def __init__(self, config: Gemma4VisionConfig):
        # `super()` already binds `self`. Passing it again would shift every argument
        # by one position and raise a TypeError when this constructor is executed.
        super().__init__(config)
        # `hidden_size` / `intermediate_size` are read from `self`, so they must have
        # been set by the parent constructor above.
        self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
        self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
        self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
# Rotary-embedding helper for multi-dimensional positions, judging by its name and the
# cos / sin / position_ids parameters — confirm against the body.
# `unsqueeze_dim` defaults to 2.
def apply_multidimensional_rope(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
position_ids: torch.Tensor,
unsqueeze_dim: int = 2,
...
# Rotary embedding for the vision tower, derived from the Llama implementation.
class Gemma4VisionRotaryEmbedding(LlamaRotaryEmbedding):
# Static: callable without an instance. Every parameter defaults to None in the signature.
# NOTE(review): the docstring describes `config` and `device` as things that are used, so a
# None `config` is presumably not actually supported — confirm.
@staticmethod
def compute_default_rope_parameters(
config: Gemma4VisionConfig | None = None,
device: torch.device | None = None,
seq_len: int | None = None,
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
config ([`~transformers.PreTrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
...
# Vision-tower attention: starts from the Gemma3 attention constructor, drops the text-only
# attributes and swaps every projection for a clippable linear.
class Gemma4VisionAttention(Gemma3Attention):
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
# NOTE(review): `super()` already binds `self`; passing it again gives the parent one positional
# argument too many if this is executed as plain Python. Other classes in this file call
# `super().__init__(config, layer_idx)` — confirm whether the explicit `self` is intentional.
super().__init__(self, config, layer_idx)
# Remove attributes the parent sets that have no meaning for this encoder.
del self.attn_logit_softcapping
del self.sliding_window
del self.is_sliding
# Fixed scaling of 1.0 and non-causal attention.
self.scaling = 1.0
self.is_causal = False
self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
# Value norm is created with `with_scale=False`.
self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
...
# Vision encoder layer: the Gemma3 decoder layer with vision-specific attention and MLP swapped in.
class Gemma4VisionEncoderLayer(Gemma3DecoderLayer):
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
# NOTE(review): `super()` already binds `self`; the explicit `self` here is an extra positional
# argument if this runs as plain Python. Compare `super().__init__(config, layer_idx)` used by the
# text decoder layer in this file — confirm whether this is intentional.
super().__init__(self, config, layer_idx)
self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
self.mlp = Gemma4VisionMLP(config)
# NOTE(review): `position_embeddings` defaults to None but is annotated as a plain `torch.Tensor`,
# unlike the `| None` annotations on the parameters after it.
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
...
# Stack of vision encoder layers plus the rotary embedding module they share.
class Gemma4VisionEncoder(nn.Module):
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.config = config
self.num_layers = config.num_hidden_layers
# A single rotary-embedding module is created at the encoder level, not one per layer.
self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
# One layer per configured hidden layer; each receives its own index.
self.layers = nn.ModuleList(
[Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
)
def forward(
self,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor,
...
class Gemma4TextMLP(Gemma3MLP):
    # Text MLP whose intermediate width is doubled on KV-shared layers when
    # `config.use_double_wide_mlp` is set.
    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        # KV-shared layers are the last `num_kv_shared_layers` layers of the stack.
        shared_kv_start = config.num_hidden_layers - config.num_kv_shared_layers
        # A start index of 0 or below means no layer counts as KV-shared.
        in_shared_kv_range = shared_kv_start > 0 and layer_idx >= shared_kv_start
        double_wide = config.use_double_wide_mlp and in_shared_kv_range
        width_factor = 2 if double_wide else 1
        # NOTE(review): no arguments are forwarded to the parent constructor here, unlike the
        # `super().__init__(config...)` calls elsewhere in this file — confirm this is intended.
        super().__init__()
        self.intermediate_size = width_factor * config.intermediate_size
# Rotary embedding for the text model.
class Gemma4TextRotaryEmbedding(Gemma3RotaryEmbedding):
def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
# Calls nn.Module.__init__ directly, so the Gemma3RotaryEmbedding constructor is not run;
# the state it would have set up is assigned explicitly below.
nn.Module.__init__(self)
# Both the cached and the original maximum length start at the configured maximum.
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
...
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
super().__init__()
# Per-layer attention type; None when the config has no `layer_types` attribute.
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
self.config = config
self.layer_idx = layer_idx
self.is_sliding = self.layer_type == "sliding_attention"
# A sliding window is only recorded for sliding-attention layers.
self.sliding_window = config.sliding_window if self.is_sliding else None
# Non-sliding layers use `global_head_dim` when it is set (truthy); all others use `head_dim`.
self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
# The alternative attention path (`attention_k_eq_v`) applies to non-sliding layers only.
self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
...
class Gemma4TextExperts(MixtralExperts):
    # Expert bank built on the Mixtral implementation, configured from Gemma 4's config fields.
    def __init__(self, config: Gemma4TextConfig):
        # NOTE(review): no arguments are forwarded to the parent constructor — confirm this is intended.
        super().__init__()
        expert_count = config.num_experts
        expert_width = config.moe_intermediate_size
        self.num_experts = expert_count
        self.intermediate_dim = expert_width
        # Activation is looked up by the name stored in `config.hidden_activation`.
        activation_name = config.hidden_activation
        self.act_fn = ACT2FN[activation_name]
# Router module of the text model; its outputs are what `Gemma4TextModel` records as "router_logits".
class Gemma4TextRouter(nn.Module):
def __init__(self, config: Gemma4TextConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
# 1 / sqrt(hidden_size), computed once at construction.
self.scalar_root_size = self.hidden_size**-0.5
...
# Text decoder layer: the Gemma3 decoder layer with Gemma 4 attention and MLP, a per-layer scalar
# buffer, and an optional per-layer-input branch.
class Gemma4TextDecoderLayer(Gemma3DecoderLayer):
def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
self.mlp = Gemma4TextMLP(config, layer_idx)
# Single-element buffer initialised to 1.
self.register_buffer("layer_scalar", torch.ones(1))
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
# The per-layer-input modules are only created when the configured size is non-zero / set.
if self.hidden_size_per_layer_input:
self.act_fn = ACT2FN[config.hidden_activation]
# Projects hidden_size -> per-layer size, then back to hidden_size.
self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
...
# Gemma 4 name for the Gemma3 scaled word embedding; nothing is added or overridden here.
class Gemma4TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
pass
# ---- Model Classes ----
# Shared base class for the Gemma 4 models; holds class-level loading and placement settings.
class Gemma4PreTrainedModel(Gemma3nPreTrainedModel):
# Names of layer classes listed as not to be split; presumably consumed by device-map placement — confirm.
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
input_modalities = ("image", "text", "video", "audio")
# Cleared at this level; concrete models in this file define their own mapping.
_can_record_outputs = None # override
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
@torch.no_grad()
...
@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
class Gemma4TextModel(Gemma3TextModel):
config: Gemma4TextConfig
# Maps each recordable output name to the module class it is captured from.
# Router logits are taken from index 0 of the router's output.
_can_record_outputs = {
"router_logits": OutputRecorder(Gemma4TextRouter, index=0),
"hidden_states": Gemma4TextDecoderLayer,
"attentions": Gemma4TextAttention,
}
def __init__(self, config: Gemma4TextConfig):
super().__init__(config)
# Replace the layer stack built by the parent constructor with Gemma 4 decoder layers.
self.layers = nn.ModuleList(
[Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
...
@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
class Gemma4ForCausalLM(Gemma3ForCausalLM):
    base_model_prefix = "model"

    def __init__(self, config: Gemma4TextConfig):
        super().__init__(config)
        # Take the wrapped model's ignore patterns and re-root each one under "model."
        # (presumably so they match keys as seen from this wrapper — confirm).
        inner_patterns = self.model._keys_to_ignore_on_load_unexpected
        self._keys_to_ignore_on_load_unexpected = [f"model.{pattern}" for pattern in inner_patterns]
# Audio tower entry point; the constructor and forward pass are defined in the class body below.
class Gemma4AudioModel(Gemma4PreTrainedModel):
"""An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
...
class Gemma4VisionModel(Gemma4PreTrainedModel):
"""The Gemma 4 Vision Encoder."""
# NOTE(review): this assigns the config class as a class attribute (`config = ...`), whereas
# Gemma4TextModel in this file only annotates it (`config: Gemma4TextConfig`). Confirm which form is intended.
config = Gemma4VisionConfig
# Maps each recordable output name to the module class it is captured from.
_can_record_outputs = {
"hidden_states": Gemma4VisionEncoderLayer,
"attentions": Gemma4VisionAttention,
}
def __init__(self, config: Gemma4VisionConfig):
super().__init__(config)
# Three sub-modules: patch embedder, encoder and pooler. How they are chained is
# defined in forward() (not shown here).
self.patch_embedder = Gemma4VisionPatchEmbedder(config)
self.encoder = Gemma4VisionEncoder(config)
self.pooler = Gemma4VisionPooler(config)
...
# Embedder for a non-text modality, built from the Gemma3n embedder with several of its
# attributes removed after construction.
class Gemma4MultimodalEmbedder(Gemma3nMultimodalEmbedder):
def __init__(
self,
multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
text_config: Gemma4TextConfig,
):
# Audio tower may use a different output dimension (output_proj_dims) than the
# internal hidden_size. Use the tower-specific dimension if specified.
super().__init__(multimodal_config, text_config)
# Drop the parent's embedding table, its two embedding norms and the vocab bookkeeping.
del self.embedding
del self.hard_embedding_norm
del self.soft_embedding_norm
del self.vocab_offset
del self.vocab_size
...
# Builds an extra mask callable from token-type information, or returns None when there is none.
def token_type_ids_mask_function(
token_type_ids: torch.Tensor | None,
image_group_ids: torch.Tensor | None,
) -> Callable | None:
"""
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
not start and end indices.
"""
# Do not return an additional mask in this case
if token_type_ids is None:
return None
# NOTE(review): only `token_type_ids` is checked above, but the closure reads `image_group_ids.shape`.
# If `image_group_ids` can be None while `token_type_ids` is not, this raises AttributeError — confirm callers.
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
seq_length = image_group_ids.shape[-1]
...
# Returns a dict (per the annotation) of attention masks; per the docstring it replaces the base
# `create_masks_for_generate` with a variant that accounts for token type ids.
# Only the first five parameters are required; extra keyword arguments are accepted via **kwargs.
def create_causal_mask_mapping(
config: PreTrainedConfig,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor | None,
past_key_values: Cache | None,
position_ids: torch.Tensor | None,
mm_token_type_ids: torch.Tensor | None = None,
pixel_values: torch.FloatTensor | None = None,
is_training: bool = False,
is_first_iteration: bool | None = None,
**kwargs,
) -> dict:
"""
Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
...
@auto_docstring(
custom_intro="""
The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
language modeling head.
"""
)
class Gemma4Model(Gemma3nModel):
def __init__(self, config: Gemma4Config):
super().__init__(config)
# Discard the vision tower and vision embedder built by the parent constructor, then rebuild them below.
del self.vision_tower
del self.embed_vision
# The vision tower is optional: it is None when the config carries no vision_config.
self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None
self.embed_vision = (
Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
...
@auto_docstring(
custom_intro="""
The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
head.
"""
)
class Gemma4ForConditionalGeneration(Gemma3nForConditionalGeneration):
base_model_prefix = "model"
def __init__(self, config: Gemma4Config):
super().__init__(config)
# Grab the ones from the child
# Each of the wrapped model's ignore patterns is re-rooted under the "model." prefix.
self._keys_to_ignore_on_load_unexpected = [
f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
...