# === HEADER (license + imports) === # Copyright 2026 the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from collections.abc import Callable from dataclasses import dataclass from functools import cached_property import torch from torch import nn from torch.nn import functional as F from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...integrations import use_kernelized_func from ...masking_utils import ( create_bidirectional_mask, create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, auto_docstring, can_return_tuple, is_accelerate_available, logging, torch_compilable_check, ) from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ..auto.modeling_auto import AutoModel from ..gemma3.modeling_gemma3 import ( Gemma3Attention, Gemma3DecoderLayer, Gemma3ForCausalLM, Gemma3MLP, Gemma3RotaryEmbedding, Gemma3TextModel, Gemma3TextScaledWordEmbedding, ) from ..gemma3n.modeling_gemma3n import ( Gemma3nCausalLMOutputWithPast, Gemma3nForConditionalGeneration, Gemma3nModel, Gemma3nModelOutputWithPast, Gemma3nMultimodalEmbedder, Gemma3nPreTrainedModel, Gemma3nRMSNorm, apply_rotary_pos_emb, eager_attention_forward, ) from ..llama.modeling_llama import LlamaRotaryEmbedding from ..mixtral.modeling_mixtral import MixtralExperts from ..moonshine_streaming.modeling_moonshine_streaming import sliding_window_mask_function from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig if is_accelerate_available(): pass # === CLASS/FUNCTION OUTLINE (signatures + short body) === class Gemma4ModelOutputWithPast(Gemma3nModelOutputWithPast): pass class Gemma4CausalLMOutputWithPast(Gemma3nCausalLMOutputWithPast): pass @dataclass @auto_docstring class Gemma4AudioModelOutput(BaseModelOutputWithPooling): r""" attention_mask (`torch.BoolTensor`, *optional*): A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding. ... class Gemma4ClippableLinear(nn.Module): def __init__( self, config: Gemma4VisionConfig | Gemma4AudioConfig, in_features: int, out_features: int, ) -> None: super().__init__() self.use_clipped_linears = config.use_clipped_linears self.linear = nn.Linear(in_features, out_features, bias=False) if self.use_clipped_linears: self.register_buffer("input_min", torch.tensor(-float("inf"))) self.register_buffer("input_max", torch.tensor(float("inf"))) ... class Gemma4RMSNorm(Gemma3nRMSNorm): pass class Gemma4AudioRelPositionalEncoding(nn.Module): """Sinusoidal relative positional encoding for the audio encoder. Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with concatenated [sin..., cos...] layout matching the original Gemma4 convention. """ inv_timescales: torch.Tensor def __init__(self, config: Gemma4AudioConfig): ... class Gemma4AudioAttention(nn.Module): """Chunked local attention with relative position bias""" def __init__(self, config: Gemma4AudioConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.attention_logits_soft_cap = config.attention_logit_cap self.head_dim = config.hidden_size // config.num_attention_heads self.num_heads = config.num_attention_heads self.q_scale = (self.head_dim**-0.5) / math.log(2) self.k_scale = math.log(1 + math.e) / math.log(2) ... class Gemma4AudioSubSampleConvProjectionLayer(nn.Module): def __init__(self, in_channels, out_channels, norm_eps): super().__init__() self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(2, 2), padding=1, bias=False, ) self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False) self.act = nn.ReLU() ... class Gemma4AudioSubSampleConvProjection(nn.Module): def __init__(self, config: Gemma4AudioConfig): super().__init__() self.layer0 = Gemma4AudioSubSampleConvProjectionLayer( in_channels=1, out_channels=config.subsampling_conv_channels[0], norm_eps=config.rms_norm_eps, ) self.layer1 = Gemma4AudioSubSampleConvProjectionLayer( in_channels=config.subsampling_conv_channels[0], out_channels=config.subsampling_conv_channels[1], norm_eps=config.rms_norm_eps, ) proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1] ... class Gemma4AudioFeedForward(nn.Module): def __init__(self, config: Gemma4AudioConfig): super().__init__() self.config = config self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4) self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size) self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size) self.post_layer_norm = Gemma4RMSNorm(config.hidden_size) self.act_fn = ACT2FN[config.hidden_act] self.gradient_clipping = config.gradient_clipping self.post_layer_scale = config.residual_weight ... class Gemma4AudioCausalConv1d(nn.Conv1d): # def __init__( # self, # in_channels: int, # out_channels: int, # kernel_size: int, # # cache_key: str, # stride: int = 1, # dilation: int = 1, # bias: bool = True, # ): # super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias) # self.cache_key = cache_key ... class Gemma4AudioLightConv1d(nn.Module): def __init__(self, config: Gemma4AudioConfig): super().__init__() self.config = config self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2) self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size) self.depthwise_conv1d = Gemma4AudioCausalConv1d( in_channels=config.hidden_size, out_channels=config.hidden_size, kernel_size=config.conv_kernel_size, groups=config.hidden_size, bias=False, ) ... class Gemma4AudioLayer(nn.Module): def __init__(self, config: Gemma4AudioConfig, layer_idx: int): super().__init__() self.config = config self.feed_forward1 = Gemma4AudioFeedForward(config) self.feed_forward2 = Gemma4AudioFeedForward(config) self.self_attn = Gemma4AudioAttention(config, layer_idx) self.lconv1d = Gemma4AudioLightConv1d(config) self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size) self.norm_post_attn = Gemma4RMSNorm(config.hidden_size) self.norm_out = Gemma4RMSNorm(config.hidden_size) ... class Gemma4VisionPatchEmbedder(nn.Module): def __init__(self, config: Gemma4VisionConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size self.patch_size = config.patch_size self.position_embedding_size = config.position_embedding_size self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False) self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size)) def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor: """Prepare patch positions map for matmul with positon embedding table.""" # Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul. ... class Gemma4VisionPooler(nn.Module): """Scaling and optional spatial pooling for vision encodings""" def __init__(self, config: Gemma4VisionConfig): super().__init__() self.hidden_size = config.hidden_size self.root_hidden_size = self.hidden_size**0.5 def _avg_pool_by_positions( self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int ) -> tuple[torch.Tensor, torch.Tensor]: """ 2D spatial pooling according to patch positions. Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between ... class Gemma4VisionMLP(Gemma3MLP): def __init__(self, config: Gemma4VisionConfig): super().__init__(self, config) self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size) self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size) self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size) def apply_multidimensional_rope( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor, unsqueeze_dim: int = 2, ... class Gemma4VisionRotaryEmbedding(LlamaRotaryEmbedding): @staticmethod def compute_default_rope_parameters( config: Gemma4VisionConfig | None = None, device: torch.device | None = None, seq_len: int | None = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation Args: config ([`~transformers.PreTrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. ... class Gemma4VisionAttention(Gemma3Attention): def __init__(self, config: Gemma4VisionConfig, layer_idx: int): super().__init__(self, config, layer_idx) del self.attn_logit_softcapping del self.sliding_window del self.is_sliding self.scaling = 1.0 self.is_causal = False self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim) self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim) self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim) self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size) self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False) ... class Gemma4VisionEncoderLayer(Gemma3DecoderLayer): def __init__(self, config: Gemma4VisionConfig, layer_idx: int): super().__init__(self, config, layer_idx) self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx) self.mlp = Gemma4VisionMLP(config) def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: ... class Gemma4VisionEncoder(nn.Module): def __init__(self, config: Gemma4VisionConfig): super().__init__() self.config = config self.num_layers = config.num_hidden_layers self.rotary_emb = Gemma4VisionRotaryEmbedding(config) self.layers = nn.ModuleList( [Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)] ) def forward( self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, ... class Gemma4TextMLP(Gemma3MLP): def __init__(self, config: Gemma4TextConfig, layer_idx: int): first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer super().__init__() self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1) class Gemma4TextRotaryEmbedding(Gemma3RotaryEmbedding): def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None): nn.Module.__init__(self) self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings ... @use_kernelized_func(apply_rotary_pos_emb) class Gemma4TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Gemma4TextConfig, layer_idx: int): super().__init__() self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.config = config self.layer_idx = layer_idx self.is_sliding = self.layer_type == "sliding_attention" self.sliding_window = config.sliding_window if self.is_sliding else None self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding ... class Gemma4TextExperts(MixtralExperts): def __init__(self, config: Gemma4TextConfig): super().__init__() self.num_experts = config.num_experts self.intermediate_dim = config.moe_intermediate_size self.act_fn = ACT2FN[config.hidden_activation] class Gemma4TextRouter(nn.Module): def __init__(self, config: Gemma4TextConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size self.scalar_root_size = self.hidden_size**-0.5 ... class Gemma4TextDecoderLayer(Gemma3DecoderLayer): def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int): super().__init__(config, layer_idx) self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx) self.mlp = Gemma4TextMLP(config, layer_idx) self.register_buffer("layer_scalar", torch.ones(1)) self.hidden_size_per_layer_input = config.hidden_size_per_layer_input if self.hidden_size_per_layer_input: self.act_fn = ACT2FN[config.hidden_activation] self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False) self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False) self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) ... class Gemma4TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding): pass # ---- Model Classes ---- class Gemma4PreTrainedModel(Gemma3nPreTrainedModel): _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"] input_modalities = ("image", "text", "video", "audio") _can_record_outputs = None # override _skip_keys_device_placement = ["past_key_values", "shared_kv_states"] @torch.no_grad() ... @auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.") class Gemma4TextModel(Gemma3TextModel): config: Gemma4TextConfig _can_record_outputs = { "router_logits": OutputRecorder(Gemma4TextRouter, index=0), "hidden_states": Gemma4TextDecoderLayer, "attentions": Gemma4TextAttention, } def __init__(self, config: Gemma4TextConfig): super().__init__(config) self.layers = nn.ModuleList( [Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) ... @auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.") class Gemma4ForCausalLM(Gemma3ForCausalLM): base_model_prefix = "model" def __init__(self, config: Gemma4TextConfig): super().__init__(config) # Grab the ones from the child self._keys_to_ignore_on_load_unexpected = [ f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected ] class Gemma4AudioModel(Gemma4PreTrainedModel): """An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.""" ... class Gemma4VisionModel(Gemma4PreTrainedModel): """The Gemma 4 Vision Encoder.""" config = Gemma4VisionConfig _can_record_outputs = { "hidden_states": Gemma4VisionEncoderLayer, "attentions": Gemma4VisionAttention, } def __init__(self, config: Gemma4VisionConfig): super().__init__(config) self.patch_embedder = Gemma4VisionPatchEmbedder(config) self.encoder = Gemma4VisionEncoder(config) self.pooler = Gemma4VisionPooler(config) ... class Gemma4MultimodalEmbedder(Gemma3nMultimodalEmbedder): def __init__( self, multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig, text_config: Gemma4TextConfig, ): # Audio tower may use a different output dimension (output_proj_dims) than the # internal hidden_size. Use the tower-specific dimension if specified. super().__init__(multimodal_config, text_config) del self.embedding del self.hard_embedding_norm del self.soft_embedding_norm del self.vocab_offset del self.vocab_size ... def token_type_ids_mask_function( token_type_ids: torch.Tensor | None, image_group_ids: torch.Tensor | None, ) -> Callable | None: """ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, not start and end indices. """ # Do not return an additional mask in this case if token_type_ids is None: return None def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: seq_length = image_group_ids.shape[-1] ... def create_causal_mask_mapping( config: PreTrainedConfig, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor | None, past_key_values: Cache | None, position_ids: torch.Tensor | None, mm_token_type_ids: torch.Tensor | None = None, pixel_values: torch.FloatTensor | None = None, is_training: bool = False, is_first_iteration: bool | None = None, **kwargs, ) -> dict: """ Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping ... @auto_docstring( custom_intro=""" The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a language modeling head. """ ) class Gemma4Model(Gemma3nModel): def __init__(self, config: Gemma4Config): super().__init__(config) del self.vision_tower del self.embed_vision self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None self.embed_vision = ( Gemma4MultimodalEmbedder(config.vision_config, config.text_config) ... @auto_docstring( custom_intro=""" The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling head. """ ) class Gemma4ForConditionalGeneration(Gemma3nForConditionalGeneration): base_model_prefix = "model" def __init__(self, config: Gemma4Config): super().__init__(config) # Grab the ones from the child self._keys_to_ignore_on_load_unexpected = [ f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected ...