eecebe7ef5
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
724 lines
28 KiB
Python
724 lines
28 KiB
Python
# === HEADER (license + imports) ===
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/gemma4/modular_gemma4.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_gemma4.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import math
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from functools import cached_property
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from ... import initialization as init
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache
|
|
from ...configuration_utils import PreTrainedConfig
|
|
from ...generation import GenerationMixin
|
|
from ...integrations import use_experts_implementation, use_kernelized_func
|
|
from ...masking_utils import (
|
|
create_bidirectional_mask,
|
|
create_causal_mask,
|
|
create_masks_for_generate,
|
|
create_sliding_window_causal_mask,
|
|
)
|
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
|
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...utils import (
|
|
ModelOutput,
|
|
TransformersKwargs,
|
|
auto_docstring,
|
|
can_return_tuple,
|
|
is_accelerate_available,
|
|
torch_compilable_check,
|
|
)
|
|
from ...utils.generic import maybe_autocast, merge_with_config_defaults
|
|
from ...utils.output_capturing import OutputRecorder, capture_outputs
|
|
from ..auto.modeling_auto import AutoModel
|
|
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
|
|
|
|
|
|
if is_accelerate_available():
|
|
from accelerate.hooks import add_hook_to_module
|
|
|
|
|
|
@dataclass
|
|
@auto_docstring(
|
|
custom_intro="""
|
|
Base class for Gemma4 outputs, with hidden states and attentions.
|
|
"""
|
|
)
|
|
class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
|
|
r"""
|
|
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
|
|
|
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
|
`past_key_values` input) to speed up sequential decoding.
|
|
image_hidden_states (`torch.FloatTensor`, *optional*):
|
|
|
|
# === CLASS/FUNCTION OUTLINE (signatures + short body) ===
|
|
@dataclass
|
|
@auto_docstring(
|
|
custom_intro="""
|
|
Base class for Gemma4 outputs, with hidden states and attentions.
|
|
"""
|
|
)
|
|
class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
|
|
r"""
|
|
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
|
|
|
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
|
`past_key_values` input) to speed up sequential decoding.
|
|
image_hidden_states (`torch.FloatTensor`, *optional*):
|
|
...
|
|
|
|
@dataclass
|
|
@auto_docstring(
|
|
custom_intro="""
|
|
Base class for Gemma4 causal language model (or autoregressive) outputs.
|
|
"""
|
|
)
|
|
class Gemma4CausalLMOutputWithPast(ModelOutput):
|
|
r"""
|
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
|
Language modeling loss (for next-token prediction).
|
|
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
|
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
|
...
|
|
|
|
@dataclass
@auto_docstring
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
    r"""
    attention_mask (`torch.BoolTensor`, *optional*):
        A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
    """

    # Extra field carried alongside the fields inherited from `BaseModelOutputWithPooling`; defaults to None.
    attention_mask: torch.BoolTensor | None = None
|
|
|
|
|
|
class Gemma4ClippableLinear(nn.Module):
|
|
def __init__(
|
|
self,
|
|
...
|
|
|
|
class Gemma4RMSNorm(nn.Module):
|
|
def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
|
|
super().__init__()
|
|
self.eps = eps
|
|
self.with_scale = with_scale
|
|
|
|
if self.with_scale:
|
|
self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)
|
|
|
|
def _norm(self, hidden_states: torch.Tensor):
|
|
mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps
|
|
# Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to address compiler differences between Torch and JAX
|
|
return hidden_states * torch.pow(mean_squared, -0.5)
|
|
|
|
...
|
|
|
|
class Gemma4AudioRelPositionalEncoding(nn.Module):
|
|
"""Sinusoidal relative positional encoding for the audio encoder.
|
|
|
|
Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
|
|
concatenated [sin..., cos...] layout matching the original Gemma4 convention.
|
|
"""
|
|
|
|
inv_timescales: torch.Tensor
|
|
|
|
def __init__(self, config: Gemma4AudioConfig):
|
|
super().__init__()
|
|
self.hidden_size = config.hidden_size
|
|
self.context_size = (
|
|
config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
|
|
...
|
|
|
|
class Gemma4AudioAttention(nn.Module):
|
|
"""Chunked local attention with relative position bias"""
|
|
|
|
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
|
|
super().__init__()
|
|
self.config = config
|
|
self.layer_idx = layer_idx
|
|
self.attention_logits_soft_cap = config.attention_logit_cap
|
|
self.head_dim = config.hidden_size // config.num_attention_heads
|
|
self.num_heads = config.num_attention_heads
|
|
|
|
self.q_scale = (self.head_dim**-0.5) / math.log(2)
|
|
self.k_scale = math.log(1 + math.e) / math.log(2)
|
|
|
|
...
|
|
|
|
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
|
|
def __init__(self, in_channels, out_channels, norm_eps):
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(
|
|
in_channels=in_channels,
|
|
out_channels=out_channels,
|
|
kernel_size=(3, 3),
|
|
stride=(2, 2),
|
|
padding=1,
|
|
bias=False,
|
|
)
|
|
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
|
|
self.act = nn.ReLU()
|
|
|
|
...
|
|
|
|
class Gemma4AudioSubSampleConvProjection(nn.Module):
|
|
def __init__(self, config: Gemma4AudioConfig):
|
|
super().__init__()
|
|
self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
|
|
in_channels=1,
|
|
out_channels=config.subsampling_conv_channels[0],
|
|
norm_eps=config.rms_norm_eps,
|
|
)
|
|
self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
|
|
in_channels=config.subsampling_conv_channels[0],
|
|
out_channels=config.subsampling_conv_channels[1],
|
|
norm_eps=config.rms_norm_eps,
|
|
)
|
|
proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
|
|
...
|
|
|
|
class Gemma4AudioFeedForward(nn.Module):
|
|
def __init__(self, config: Gemma4AudioConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
|
|
self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
|
|
self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
|
|
|
|
self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
|
self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
|
self.act_fn = ACT2FN[config.hidden_act]
|
|
|
|
self.gradient_clipping = config.gradient_clipping
|
|
self.post_layer_scale = config.residual_weight
|
|
...
|
|
|
|
class Gemma4AudioCausalConv1d(nn.Conv1d):
|
|
# def __init__(
|
|
# self,
|
|
# in_channels: int,
|
|
# out_channels: int,
|
|
# kernel_size: int,
|
|
# # cache_key: str,
|
|
# stride: int = 1,
|
|
# dilation: int = 1,
|
|
# bias: bool = True,
|
|
# ):
|
|
# super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
|
|
# self.cache_key = cache_key
|
|
|
|
...
|
|
|
|
class Gemma4AudioLightConv1d(nn.Module):
|
|
def __init__(self, config: Gemma4AudioConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
|
|
self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
|
|
self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
|
|
self.depthwise_conv1d = Gemma4AudioCausalConv1d(
|
|
in_channels=config.hidden_size,
|
|
out_channels=config.hidden_size,
|
|
kernel_size=config.conv_kernel_size,
|
|
groups=config.hidden_size,
|
|
bias=False,
|
|
)
|
|
...
|
|
|
|
class Gemma4AudioLayer(nn.Module):
|
|
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
|
|
super().__init__()
|
|
self.config = config
|
|
|
|
self.feed_forward1 = Gemma4AudioFeedForward(config)
|
|
self.feed_forward2 = Gemma4AudioFeedForward(config)
|
|
self.self_attn = Gemma4AudioAttention(config, layer_idx)
|
|
self.lconv1d = Gemma4AudioLightConv1d(config)
|
|
|
|
self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
|
|
self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
|
|
self.norm_out = Gemma4RMSNorm(config.hidden_size)
|
|
|
|
...
|
|
|
|
class Gemma4VisionPatchEmbedder(nn.Module):
|
|
def __init__(self, config: Gemma4VisionConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
self.hidden_size = config.hidden_size
|
|
self.patch_size = config.patch_size
|
|
self.position_embedding_size = config.position_embedding_size
|
|
|
|
self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
|
|
self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
|
|
|
|
def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
|
|
"""Prepare patch positions map for matmul with position embedding table."""
|
|
# Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
|
|
...
|
|
|
|
class Gemma4VisionPooler(nn.Module):
|
|
"""Scaling and optional spatial pooling for vision encodings"""
|
|
|
|
def __init__(self, config: Gemma4VisionConfig):
|
|
super().__init__()
|
|
self.hidden_size = config.hidden_size
|
|
self.root_hidden_size = self.hidden_size**0.5
|
|
|
|
def _avg_pool_by_positions(
|
|
self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
2D spatial pooling according to patch positions.
|
|
Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
|
|
...
|
|
|
|
class Gemma4VisionMLP(nn.Module):
|
|
def __init__(self, config: Gemma4VisionConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
self.hidden_size = config.hidden_size
|
|
self.intermediate_size = config.intermediate_size
|
|
self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
|
self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
|
self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
|
|
self.act_fn = ACT2FN[config.hidden_activation]
|
|
|
|
def forward(self, x):
|
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
|
return down_proj
|
|
...
|
|
|
|
class Gemma4VisionRotaryEmbedding(nn.Module):
|
|
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
|
|
|
def __init__(self, config: Gemma4VisionConfig, device=None):
|
|
super().__init__()
|
|
self.max_seq_len_cached = config.max_position_embeddings
|
|
self.original_max_seq_len = config.max_position_embeddings
|
|
|
|
self.config = config
|
|
|
|
self.rope_type = self.config.rope_parameters["rope_type"]
|
|
rope_init_fn: Callable = self.compute_default_rope_parameters
|
|
if self.rope_type != "default":
|
|
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
|
...
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotates half the hidden dims of the input.

    The last dimension is split at its midpoint into a front part `x1` and a back part `x2`, and the result is
    the concatenation `(-x2, x1)` along that same dimension, so the output has the same shape as `x`.

    Args:
        x (`torch.Tensor`): The tensor whose last dimension is rotated.

    Returns:
        `torch.Tensor`: A tensor shaped like `x` with the two halves swapped and the (new) front half negated.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
|
|
|
|
|
|
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
|
|
"""Applies Rotary Position Embedding to the query and key tensors.
|
|
|
|
Args:
|
|
x (`torch.Tensor`): The tensor to embed.
|
|
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
|
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
|
...
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    Args:
        hidden_states (`torch.Tensor`): A 4-D tensor; it is unpacked as
            (batch, num_key_value_heads, seqlen, head_dim).
        n_rep (`int`): How many times each head along dim 1 is repeated.

    Returns:
        `torch.Tensor`: The input tensor itself when `n_rep == 1`, otherwise a tensor of shape
        (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat: hand back the very same tensor object.
        return hidden_states
    # Add a singleton axis right after the head axis, broadcast it to `n_rep`, then fold it into the head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
|
|
|
|
|
def eager_attention_forward(
|
|
module: nn.Module,
|
|
...
|
|
|
|
def apply_multidimensional_rope(
|
|
x: torch.Tensor,
|
|
cos: torch.Tensor,
|
|
sin: torch.Tensor,
|
|
position_ids: torch.Tensor,
|
|
unsqueeze_dim: int = 2,
|
|
) -> torch.Tensor:
|
|
"""Applies multidimensional RoPE to inputs.
|
|
|
|
Args:
|
|
x (`torch.Tensor`): The tensor to embed.
|
|
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
|
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
|
position_ids (`torch.Tensor`, *optional*):
|
|
...
|
|
|
|
@use_kernelized_func(apply_rotary_pos_emb)
|
|
class Gemma4VisionAttention(nn.Module):
|
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
|
|
|
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
|
super().__init__()
|
|
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
|
|
self.config = config
|
|
self.layer_idx = layer_idx
|
|
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
|
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
|
self.scaling = 1.0
|
|
self.attention_dropout = self.config.attention_dropout
|
|
self.is_causal = False
|
|
...
|
|
|
|
class Gemma4VisionEncoderLayer(GradientCheckpointingLayer):
|
|
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
|
super().__init__()
|
|
self.config = config
|
|
self.hidden_size = config.hidden_size
|
|
self.layer_idx = layer_idx
|
|
self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
|
|
self.mlp = Gemma4VisionMLP(config)
|
|
self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
|
self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
|
self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
|
self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
def forward(
|
|
...
|
|
|
|
class Gemma4VisionEncoder(nn.Module):
|
|
def __init__(self, config: Gemma4VisionConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
self.num_layers = config.num_hidden_layers
|
|
self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
|
|
self.layers = nn.ModuleList(
|
|
[Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
inputs_embeds: torch.Tensor,
|
|
attention_mask: torch.Tensor,
|
|
...
|
|
|
|
class Gemma4TextMLP(nn.Module):
|
|
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
|
|
super().__init__()
|
|
first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
|
|
is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
|
|
use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
|
|
self.config = config
|
|
self.hidden_size = config.hidden_size
|
|
self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
|
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
|
self.act_fn = ACT2FN[config.hidden_activation]
|
|
|
|
...
|
|
|
|
class Gemma4TextRotaryEmbedding(nn.Module):
|
|
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
|
|
|
def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
|
|
super().__init__()
|
|
self.max_seq_len_cached = config.max_position_embeddings
|
|
self.original_max_seq_len = config.max_position_embeddings
|
|
|
|
self.config = config
|
|
self.layer_types = set(config.layer_types)
|
|
self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
|
|
self.rope_type: dict[str, str] = {}
|
|
|
|
for layer_type in self.layer_types:
|
|
...
|
|
|
|
@use_kernelized_func(apply_rotary_pos_emb)
|
|
class Gemma4TextAttention(nn.Module):
|
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
|
|
|
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
|
|
super().__init__()
|
|
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
|
|
self.config = config
|
|
self.layer_idx = layer_idx
|
|
self.is_sliding = self.layer_type == "sliding_attention"
|
|
self.sliding_window = config.sliding_window if self.is_sliding else None
|
|
|
|
self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
|
|
self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
|
|
...
|
|
|
|
@use_experts_implementation
|
|
class Gemma4TextExperts(nn.Module):
|
|
"""Collection of expert weights stored as 3D tensors."""
|
|
|
|
def __init__(self, config: Gemma4TextConfig):
|
|
super().__init__()
|
|
self.num_experts = config.num_experts
|
|
self.hidden_dim = config.hidden_size
|
|
self.intermediate_dim = config.moe_intermediate_size
|
|
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
|
|
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
|
|
self.act_fn = ACT2FN[config.hidden_activation]
|
|
|
|
def forward(
|
|
...
|
|
|
|
class Gemma4TextRouter(nn.Module):
|
|
def __init__(self, config: Gemma4TextConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
self.hidden_size = config.hidden_size
|
|
self.scalar_root_size = self.hidden_size**-0.5
|
|
self.eps = config.rms_norm_eps
|
|
|
|
self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False)
|
|
self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
|
self.scale = nn.Parameter(torch.ones(self.hidden_size))
|
|
self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
...
|
|
|
|
class Gemma4TextDecoderLayer(GradientCheckpointingLayer):
|
|
def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
|
|
super().__init__()
|
|
self.config = config
|
|
self.hidden_size = config.hidden_size
|
|
self.layer_idx = layer_idx
|
|
self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
|
|
self.mlp = Gemma4TextMLP(config, layer_idx)
|
|
self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
|
self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
|
self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
|
self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
|
self.register_buffer("layer_scalar", torch.ones(1))
|
|
|
|
...
|
|
|
|
class Gemma4TextScaledWordEmbedding(nn.Embedding):
|
|
"""
|
|
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
|
"""
|
|
|
|
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
|
|
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
|
self.scalar_embed_scale = embed_scale
|
|
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
|
|
|
|
def forward(self, input_ids: torch.Tensor):
|
|
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
|
|
|
|
|
...
|
|
|
|
@auto_docstring
|
|
class Gemma4PreTrainedModel(PreTrainedModel):
|
|
config: Gemma4Config
|
|
base_model_prefix = "model"
|
|
supports_gradient_checkpointing = True
|
|
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
|
|
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
|
|
_supports_flash_attn = True
|
|
_supports_sdpa = True
|
|
_supports_flex_attn = True
|
|
|
|
_can_compile_fullgraph = True
|
|
_supports_attention_backend = True
|
|
_can_record_outputs = None # override
|
|
...
|
|
|
|
@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
|
|
class Gemma4TextModel(Gemma4PreTrainedModel):
|
|
config: Gemma4TextConfig
|
|
input_modalities = ("text",)
|
|
_can_record_outputs = {
|
|
"router_logits": OutputRecorder(Gemma4TextRouter, index=0),
|
|
"hidden_states": Gemma4TextDecoderLayer,
|
|
"attentions": Gemma4TextAttention,
|
|
}
|
|
|
|
def __init__(self, config: Gemma4TextConfig):
|
|
super().__init__(config)
|
|
self.padding_idx = config.pad_token_id
|
|
self.vocab_size = config.vocab_size
|
|
...
|
|
|
|
@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
|
|
class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
|
|
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
|
_tp_plan = {"lm_head": "colwise_gather_output"}
|
|
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
|
config: Gemma4TextConfig
|
|
base_model_prefix = "model"
|
|
|
|
def __init__(self, config: Gemma4TextConfig):
|
|
super().__init__(config)
|
|
self.model = Gemma4TextModel(config)
|
|
self.vocab_size = config.vocab_size
|
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
# Grab the ones from the child
|
|
...
|
|
|
|
def sliding_window_mask_function(sliding_window: tuple[int, int]) -> Callable:
    """
    This creates uni/bidirectional attention mask with sliding window.

    Args:
        sliding_window (`tuple[int, int]`): The pair `(left_window_size, right_window_size)`.

    Returns:
        `Callable`: A function `(batch_idx, head_idx, q_idx, kv_idx)` that is true where `kv_idx` lies inside the
        window around `q_idx`: `0 <= q_idx - kv_idx < left_window_size` or
        `0 < kv_idx - q_idx < right_window_size`. `batch_idx` and `head_idx` are not used.
    """
    # The window sizes never change between calls, so unpack them once instead of on every mask evaluation.
    left_window_size, right_window_size = sliding_window

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        dist = q_idx - kv_idx
        # Keys at or before the query, within the left window.
        left_mask = (dist >= 0) & (dist < left_window_size)
        # Keys strictly after the query, within the right window.
        right_mask = (dist < 0) & (-dist < right_window_size)
        return left_mask | right_mask

    return inner_mask
|
|
...
|
|
|
|
class Gemma4AudioModel(Gemma4PreTrainedModel):
|
|
"""An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
|
|
|
|
config: Gemma4AudioConfig
|
|
main_input_name = "input_features"
|
|
base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
|
|
_can_record_outputs = {
|
|
"hidden_states": Gemma4AudioLayer,
|
|
"attentions": Gemma4AudioAttention,
|
|
}
|
|
|
|
def __init__(self, config: Gemma4AudioConfig):
|
|
super().__init__(config)
|
|
self.config = config
|
|
...
|
|
|
|
class Gemma4VisionModel(Gemma4PreTrainedModel):
|
|
"""The Gemma 4 Vision Encoder."""
|
|
|
|
config = Gemma4VisionConfig
|
|
_can_record_outputs = {
|
|
"hidden_states": Gemma4VisionEncoderLayer,
|
|
"attentions": Gemma4VisionAttention,
|
|
}
|
|
|
|
def __init__(self, config: Gemma4VisionConfig):
|
|
super().__init__(config)
|
|
self.patch_embedder = Gemma4VisionPatchEmbedder(config)
|
|
self.encoder = Gemma4VisionEncoder(config)
|
|
self.pooler = Gemma4VisionPooler(config)
|
|
...
|
|
|
|
class Gemma4MultimodalEmbedder(nn.Module):
|
|
"""Embeds token ids or soft tokens for multimodal content into language model space."""
|
|
|
|
def __init__(
|
|
self,
|
|
multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
|
|
text_config: Gemma4TextConfig,
|
|
):
|
|
super().__init__()
|
|
|
|
self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size)
|
|
self.eps = multimodal_config.rms_norm_eps
|
|
self.text_hidden_size = text_config.hidden_size
|
|
self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
|
|
...
|
|
|
|
def token_type_ids_mask_function(
|
|
token_type_ids: torch.Tensor | None,
|
|
image_group_ids: torch.Tensor | None,
|
|
) -> Callable | None:
|
|
"""
|
|
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
|
not start and end indices.
|
|
"""
|
|
# Do not return an additional mask in this case
|
|
if token_type_ids is None:
|
|
return None
|
|
|
|
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
|
seq_length = image_group_ids.shape[-1]
|
|
...
|
|
|
|
def create_causal_mask_mapping(
|
|
config: PreTrainedConfig,
|
|
inputs_embeds: torch.Tensor,
|
|
attention_mask: torch.Tensor | None,
|
|
past_key_values: Cache | None,
|
|
position_ids: torch.Tensor | None,
|
|
mm_token_type_ids: torch.Tensor | None = None,
|
|
pixel_values: torch.FloatTensor | None = None,
|
|
is_training: bool = False,
|
|
is_first_iteration: bool | None = None,
|
|
**kwargs,
|
|
) -> dict:
|
|
"""
|
|
Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
|
|
...
|
|
|
|
@auto_docstring(
|
|
custom_intro="""
|
|
The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
|
|
language modeling head.
|
|
"""
|
|
)
|
|
class Gemma4Model(Gemma4PreTrainedModel):
|
|
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
|
accepts_loss_kwargs = False
|
|
|
|
def __init__(self, config: Gemma4Config):
|
|
super().__init__(config)
|
|
self.vocab_size = config.text_config.vocab_size
|
|
|
|
...
|
|
|
|
@auto_docstring(
|
|
custom_intro="""
|
|
The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
|
|
head.
|
|
"""
|
|
)
|
|
class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin):
|
|
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
|
accepts_loss_kwargs = False
|
|
base_model_prefix = "model"
|
|
|
|
def __init__(self, config: Gemma4Config):
|
|
super().__init__(config)
|
|
self.model = Gemma4Model(config)
|
|
...
|
|
|