docs: add canonical tooling corpus (147 files) from Google/HF/frameworks
Five-lane parallel research pass. Each subdir under tooling/ has its own README indexing downloaded files with verified upstream sources. - google-official/: deepmind-gemma JAX examples, gemma_pytorch scripts, gemma.cpp API server docs, google-gemma/cookbook notebooks, ai.google.dev HTML snapshots, Gemma 3 tech report - huggingface/: 8 gemma-4-* model cards, chat-template .jinja files, tokenizer_config.json, transformers gemma4/ source, launch blog posts, official HF Spaces app.py - inference-frameworks/: vLLM/llama.cpp/MLX/Keras-hub/TGI/Gemini API/Vertex AI comparison, run_commands.sh with 8 working launches, 9 code snippets - gemma-family/: 12 per-variant briefs (ShieldGemma 2, CodeGemma, PaliGemma 2, Recurrent/Data/Med/TxGemma, Embedding/Translate/Function/Dolphin/SignGemma) - fine-tuning/: Unsloth Gemma 4 notebooks, Axolotl YAMLs (incl 26B-A4B MoE), TRL scripts, Google cookbook fine-tune notebooks, recipe-recommendation.md Findings that update earlier CORPUS_* docs are flagged in tooling/README.md (not applied) — notably the new <|turn>/<turn|> prompt format, gemma_pytorch abandonment, gemma.cpp Gemini-API server, transformers AutoModelForMultimodalLM, FA2 head_dim=512 break, 26B-A4B MoE quantization rules, no Gemma 4 tech report PDF yet, no Gemma-4-generation specialized siblings yet. Pre-commit secrets hook bypassed per user authorization — flagged "secrets" are base64 notebook cell outputs and example Ed25519 keys in the HDP agentic-security demo, not real credentials. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gemma4 import *
|
||||
from .feature_extraction_gemma4 import *
|
||||
from .image_processing_gemma4 import *
|
||||
from .image_processing_pil_gemma4 import *
|
||||
from .modeling_gemma4 import *
|
||||
from .processing_gemma4 import *
|
||||
from .video_processing_gemma4 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||
@@ -0,0 +1,352 @@
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from huggingface_hub.dataclasses import strict
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils.type_validators import interval
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="google/gemma-4-e2b-it")
|
||||
@strict
|
||||
class Gemma4AudioConfig(PreTrainedConfig):
|
||||
r"""
|
||||
subsampling_conv_channels (`list[int]`, defaults to `[128, 32]`):
|
||||
Channel sizes for the convolutional layers in the Sub-sample Convolution Projection.
|
||||
residual_weight (`float`, defaults to `0.5`):
|
||||
Scaling applied to hidden_states prior to combining with the residual in the feedforward.
|
||||
attention_chunk_size (`int`, defaults to `12`):
|
||||
The sub-sequence size for attention processing.
|
||||
attention_context_left (`int`, defaults to `13`):
|
||||
The leftward context size for the attention chunk.
|
||||
attention_context_right (`int`, defaults to `0`):
|
||||
The rightward context size for the attention chunk.
|
||||
attention_logit_cap (`float`, defaults to `50.0`):
|
||||
Cap applied to attention weights.
|
||||
attention_invalid_logits_value (`float`, defaults to `1e-9`):
|
||||
Value to use for invalid logits in attention.
|
||||
use_clipped_linears (`bool`, defaults to `True`):
|
||||
If true, apply clipping to the Linear layers, drawing bounds from the model checkpoint.
|
||||
gradient_clipping (`float`, defaults to `1e10`):
|
||||
Clipping value used to stabilize extremely large gradient values.
|
||||
output_proj_dims (`int`, defaults to `1536`):
|
||||
Dimension of the final linear projection from `hidden_size` to the model's output.
|
||||
"""
|
||||
|
||||
model_type = "gemma4_audio"
|
||||
|
||||
hidden_size: int = 1024
|
||||
num_hidden_layers: int = 12
|
||||
num_attention_heads: int = 8
|
||||
hidden_act: str = "silu"
|
||||
|
||||
# subsampling parameters
|
||||
subsampling_conv_channels: list[int] | tuple[int, int] = (128, 32)
|
||||
|
||||
# conformer parameters
|
||||
conv_kernel_size: int = 5
|
||||
residual_weight: float = 0.5
|
||||
attention_chunk_size: int = 12
|
||||
attention_context_left: int = 13
|
||||
attention_context_right: int = 0
|
||||
attention_logit_cap: float = 50.0
|
||||
attention_invalid_logits_value: float = -1.0e9
|
||||
|
||||
use_clipped_linears: bool = True
|
||||
rms_norm_eps: float = 1e-6
|
||||
gradient_clipping: float = 1e10
|
||||
output_proj_dims: int = 1536
|
||||
initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
# JSON serialization converts tuples to lists, convert back
|
||||
if isinstance(self.subsampling_conv_channels, tuple):
|
||||
self.subsampling_conv_channels = list(self.subsampling_conv_channels)
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="google/gemma-4-e2b-it")
|
||||
@strict
|
||||
class Gemma4TextConfig(PreTrainedConfig):
|
||||
r"""
|
||||
use_bidirectional_attention (`str`, *optional*):
|
||||
Controls bidirectional attention behavior. When set to `"vision"`, vision tokens
|
||||
attend bidirectionally while text tokens use causal attention. When set to `"all"`,
|
||||
all tokens use bidirectional attention.
|
||||
vocab_size_per_layer_input (`int`, defaults to 262144):
|
||||
Vocabulary size for the per-layer input embeddings (PLE). Used by models with
|
||||
per-layer residual streams where a smaller embedding is added at each decoder layer.
|
||||
hidden_size_per_layer_input (`int`, defaults to 256):
|
||||
Per-layer hidden dimension for the PLE system. The actual embedding weight has shape
|
||||
`[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
|
||||
because all layers are packed into a single table. See the [Gemma4](https://huggingface.co/docs/transformers/main/en/model_doc/gemma4#per-layer-embeddings-ple) docs
|
||||
for a description of the full PLE pipeline.
|
||||
num_global_key_value_heads (`int`, *optional*):
|
||||
Number of key-value heads for global (full) attention layers. If `None`, defaults
|
||||
to `num_key_value_heads`.
|
||||
global_head_dim (`int`, defaults to 512):
|
||||
Dimension of each attention head in global (full) attention layers.
|
||||
attention_k_eq_v (`bool`, defaults to `False`):
|
||||
Whether keys and values share the same projection weights. When `True`, the key
|
||||
projection output is reused as the value projection.
|
||||
num_kv_shared_layers (`int`, defaults to 0):
|
||||
Number of consecutive decoder layers that share the same key-value projections.
|
||||
A value of 0 means no sharing (each layer has independent KV projections).
|
||||
enable_moe_block (`bool`, defaults to `False`):
|
||||
Whether to enable Mixture-of-Experts (MoE) blocks in the decoder layers. When
|
||||
`True`, eligible layers will use a sparse MoE feed-forward network.
|
||||
use_double_wide_mlp (`bool`, defaults to `False`):
|
||||
Whether to use a double-width MLP with fused gate and up projections.
|
||||
top_k_experts (`int`, *optional*):
|
||||
Number of experts activated per token in MoE layers. Only used when
|
||||
`enable_moe_block=True`.
|
||||
moe_intermediate_size (`int`, *optional*):
|
||||
Intermediate (hidden) size of each expert's feed-forward network in MoE layers.
|
||||
Only used when `enable_moe_block=True`.
|
||||
"""
|
||||
|
||||
model_type = "gemma4_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
|
||||
"layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
"layers.*.experts.gate_up_proj": "packed_colwise",
|
||||
"layers.*.experts.down_proj": "rowwise",
|
||||
"layers.*.experts": "moe_tp_experts",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
vocab_size: int = 262_144
|
||||
hidden_size: int = 2304
|
||||
intermediate_size: int = 9216
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 8
|
||||
num_key_value_heads: int = 4
|
||||
head_dim: int = 256
|
||||
hidden_activation: str = "gelu_pytorch_tanh"
|
||||
max_position_embeddings: int = 131_072
|
||||
initializer_range: float = 0.02
|
||||
rms_norm_eps: float = 1e-6
|
||||
use_cache: bool = True
|
||||
pad_token_id: int | None = 0
|
||||
eos_token_id: int | list[int] | None = 1
|
||||
bos_token_id: int | None = 2
|
||||
tie_word_embeddings: bool = True
|
||||
rope_parameters: dict | None = None
|
||||
attention_bias: bool = False
|
||||
attention_dropout: int | float | None = 0.0
|
||||
sliding_window: int = 512
|
||||
layer_types: list[str] | None = None
|
||||
final_logit_softcapping: float | None = None
|
||||
use_bidirectional_attention: Literal["all", "vision"] | None = None
|
||||
vocab_size_per_layer_input: int = 262_144
|
||||
hidden_size_per_layer_input: int = 256
|
||||
num_global_key_value_heads: int | None = None
|
||||
global_head_dim: int = 512
|
||||
attention_k_eq_v: bool = False
|
||||
num_kv_shared_layers: int = 0
|
||||
enable_moe_block: bool = False
|
||||
use_double_wide_mlp: bool = False
|
||||
num_experts: int | None = None
|
||||
top_k_experts: int | None = None
|
||||
moe_intermediate_size: int | None = None
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
if self.use_bidirectional_attention == "all":
|
||||
self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds
|
||||
|
||||
if self.layer_types is None:
|
||||
sliding_window_pattern = 6 # by default 5:1
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
|
||||
if self.layer_types and (last_layer_type := self.layer_types[-1]) != "full_attention":
|
||||
logger.warning(
|
||||
f"Last layer must use `full_attention`, but got `{last_layer_type}`. Forcing last layer to `full_attention`."
|
||||
)
|
||||
self.layer_types[-1] = "full_attention"
|
||||
|
||||
default_rope_params: dict[Literal["full_attention", "sliding_attention"] : dict[str, Any]] = {
|
||||
"sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
|
||||
"full_attention": {"rope_type": "proportional", "partial_rotary_factor": 0.25, "rope_theta": 1_000_000.0},
|
||||
}
|
||||
if self.rope_parameters is None:
|
||||
self.rope_parameters = default_rope_params
|
||||
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
def convert_rope_params_to_dict(self, **kwargs):
|
||||
# No need to handle BC for new models, because they have no old-format `rope_scaling`
|
||||
return kwargs
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="google/gemma-4-e2b-it")
|
||||
@strict
|
||||
class Gemma4VisionConfig(PreTrainedConfig):
|
||||
r"""
|
||||
pooling_kernel_size (`int`, *optional*):
|
||||
Spatial pooling kernel size applied after patchification.
|
||||
position_embedding_size (`int`, defaults to 10240):
|
||||
Maximum number of position embeddings for the vision encoder. Controls the size of
|
||||
the learned 2D position embedding table used by the patch embedder.
|
||||
use_clipped_linears (`bool`, defaults to `False`):
|
||||
Whether to use weight-clipped linear layers. When enabled, linear layer weights are
|
||||
clamped to a fixed range during the forward pass to improve numerical stability.
|
||||
standardize (`bool`, defaults to `False`):
|
||||
If true, applies a bias and scale to the soft tokens returned from the pooler.
|
||||
"""
|
||||
|
||||
model_type = "gemma4_vision"
|
||||
base_model_tp_plan = {
|
||||
"encoder.layers.*.self_attn.q_proj": "colwise",
|
||||
"encoder.layers.*.self_attn.k_proj": "colwise",
|
||||
"encoder.layers.*.self_attn.v_proj": "colwise",
|
||||
"encoder.layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
|
||||
"encoder.layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
|
||||
"encoder.layers.*.self_attn.o_proj": "rowwise",
|
||||
"encoder.layers.*.mlp.gate_proj": "colwise",
|
||||
"encoder.layers.*.mlp.up_proj": "colwise",
|
||||
"encoder.layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
default_theta = 100.0
|
||||
|
||||
hidden_size: int = 768
|
||||
intermediate_size: int = 3072
|
||||
num_hidden_layers: int = 16
|
||||
num_attention_heads: int = 12
|
||||
num_key_value_heads: int = 12
|
||||
head_dim: int = 64
|
||||
hidden_activation: str = "gelu_pytorch_tanh"
|
||||
rms_norm_eps: float = 1e-6
|
||||
max_position_embeddings: int = 131_072
|
||||
attention_bias: bool | None = False
|
||||
attention_dropout: float | None = 0.0
|
||||
rope_parameters: dict | None = None
|
||||
pooling_kernel_size: int = 3
|
||||
patch_size: int = 16
|
||||
position_embedding_size: int = 10 * 1024
|
||||
use_clipped_linears: bool = False
|
||||
standardize: bool = False
|
||||
initializer_range: float = 0.02
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
if self.rope_parameters is None:
|
||||
self.rope_parameters = {"rope_type": "default", "rope_theta": 100.0}
|
||||
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="google/gemma-4-e2b-it")
|
||||
@strict
|
||||
class Gemma4Config(PreTrainedConfig):
|
||||
r"""
|
||||
boi_token_id (`int`, *optional*, defaults to 255999):
|
||||
The begin-of-image token index to wrap the image prompt.
|
||||
eoi_token_id (`int`, *optional*, defaults to 258882):
|
||||
The end-of-image token index to wrap the image prompt.
|
||||
boa_token_id (`int`, *optional*, defaults to 256000):
|
||||
The begin-of-audio token index to wrap the audio prompt.
|
||||
eoa_token_index (`int`, *optional*, defaults to 258883):
|
||||
The end-of-audio token index to wrap the audio prompt.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
>>> Gemma4AudioConfig,
|
||||
>>> Gemma4Config,
|
||||
>>> Gemma4ForConditionalGeneration,
|
||||
>>> Gemma4TextConfig,
|
||||
>>> Gemma4VisionConfig,
|
||||
>>> )
|
||||
|
||||
>>> # Initializing a Gemma 4 Audio config.
|
||||
>>> audio_config = Gemma4AudioConfig()
|
||||
|
||||
>>> # Initializing a Gemma 4 Text config.
|
||||
>>> text_config = Gemma4TextConfig()
|
||||
|
||||
>>> # Initializing a Gemma 4 vision config.
|
||||
>>> vision_config = Gemma4VisionConfig()
|
||||
|
||||
>>> # Initializing a Gemma 4 config similar to google/gemma-4-e2b-it
|
||||
>>> configuration = Gemma4Config(text_config, vision_config, audio_config)
|
||||
|
||||
>>> # Initializing a model from the google/gemma-4-e2b-it configuration
|
||||
>>> model = Gemma4ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma4"
|
||||
sub_configs = {
|
||||
"text_config": Gemma4TextConfig,
|
||||
"vision_config": Gemma4VisionConfig,
|
||||
"audio_config": Gemma4AudioConfig,
|
||||
}
|
||||
|
||||
text_config: Gemma4TextConfig | dict[str, Any] | None = None
|
||||
vision_config: Gemma4VisionConfig | dict[str, Any] | None = None
|
||||
audio_config: Gemma4AudioConfig | dict[str, Any] | None = None
|
||||
boi_token_id: int | None = 255_999
|
||||
eoi_token_id: int | None = 258_882
|
||||
image_token_id: int | None = 258_880
|
||||
video_token_id: int | None = 258_884
|
||||
boa_token_id: int | None = 256_000
|
||||
eoa_token_index: int | None = 258_883
|
||||
audio_token_id: int | None = 258_881
|
||||
initializer_range: float | None = 0.02
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
if self.text_config is None:
|
||||
self.text_config = Gemma4TextConfig()
|
||||
logger.info("text_config is None. Using default Gemma4TextConfig.")
|
||||
elif isinstance(self.text_config, dict):
|
||||
self.text_config = Gemma4TextConfig(**self.text_config)
|
||||
|
||||
if self.vision_config is None:
|
||||
logger.info("vision_config is None. Gemma4Model.vision_tower will not be initialized.")
|
||||
if isinstance(self.vision_config, dict):
|
||||
self.vision_config = Gemma4VisionConfig(**self.vision_config)
|
||||
|
||||
if self.audio_config is None:
|
||||
logger.info("audio_config is None. Gemma4Model.audio_tower will not be initialized.")
|
||||
if isinstance(self.audio_config, dict):
|
||||
self.audio_config = Gemma4AudioConfig(**self.audio_config)
|
||||
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["Gemma4AudioConfig", "Gemma4Config", "Gemma4TextConfig", "Gemma4VisionConfig"]
|
||||
@@ -0,0 +1,298 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...audio_utils import mel_filter_bank, window_function
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import PaddingStrategy, TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray:
|
||||
"""A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim."""
|
||||
if array.ndim != 2:
|
||||
raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).")
|
||||
if dimension != -1 and dimension != array.ndim - 1:
|
||||
raise ValueError("This unfold implementation only supports unfolding the last dimension.")
|
||||
|
||||
batch_size, original_length = array.shape
|
||||
num_frames = (original_length - size) // step + 1
|
||||
|
||||
if num_frames <= 0:
|
||||
return np.zeros((batch_size, 0, size), dtype=array.dtype)
|
||||
|
||||
output_shape = (batch_size, num_frames, size)
|
||||
output_strides = (array.strides[0], array.strides[1] * step, array.strides[1])
|
||||
|
||||
return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides)
|
||||
|
||||
|
||||
class Gemma4AudioFeatureExtractor(SequenceFeatureExtractor):
|
||||
"""An audio feature extractor Universal Speech Models https://huggingface.co/papers/2303.01037.
|
||||
|
||||
Args:
|
||||
feature_size (`int`, *optional*, defaults to 128):
|
||||
The feature dimension of the extracted features.
|
||||
sampling_rate (`int`, *optional*, defaults to 16000):
|
||||
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
|
||||
padding_value (`float`, *optional*, defaults to 0.0):
|
||||
Padding value used to pad the audio. Should correspond to silences.
|
||||
return_attention_mask (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the attention mask for the generated MEL spectrograms.
|
||||
frame_length_ms (`float`, *optional*, defaults to 20.0):
|
||||
The length of a frame in milliseconds.
|
||||
hop_length_ms (`float`, *optional*, defaults to 10.0):
|
||||
Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
|
||||
min_frequency (`float`, *optional*, defaults to 0.0):
|
||||
The minimum frequency (in Hz) for the Mel filterbank.
|
||||
max_frequency (`float`, *optional*, defaults to 8000.0):
|
||||
The maximum frequency (in Hz) for the Mel filterbank.
|
||||
preemphasis (`float`, *optional*, defaults to 0.0):
|
||||
The preemphasis coefficient.
|
||||
preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use HTK-style preemphasis.
|
||||
fft_overdrive (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use FFT overdrive.
|
||||
dither (`float`, *optional*, defaults to 0.0):
|
||||
Adds dithering. In other words, adds a small Gaussian noise to each frame.
|
||||
E.g. use 0.0001 to add dithering with a normal distribution centered
|
||||
around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
|
||||
The value 0.0 means no dithering.
|
||||
Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
|
||||
the high log_mel_fbank values for signals with hard-zero sections,
|
||||
when VAD cutoff is present in the signal.
|
||||
input_scale_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor applied to the input waveform.
|
||||
mel_floor (`float`, *optional*, defaults to 0.001):
|
||||
Minimum value for Mel spectrograms to avoid log(0).
|
||||
per_bin_mean (`Optional[Sequence[float]]`, *optional*):
|
||||
Mean values for per-bin normalization.
|
||||
per_bin_stddev (`Optional[Sequence[float]]`, *optional*):
|
||||
Standard deviation values for per-bin normalization.
|
||||
"""
|
||||
|
||||
model_input_names = ["input_features", "input_features_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_size: int = 128,
|
||||
sampling_rate: int = 16_000,
|
||||
padding_value: float = 0.0,
|
||||
return_attention_mask: bool = True,
|
||||
frame_length_ms: float = 20.0,
|
||||
hop_length_ms: float = 10.0,
|
||||
min_frequency: float = 0.0,
|
||||
max_frequency: float = 8000.0,
|
||||
preemphasis: float = 0.0,
|
||||
preemphasis_htk_flavor: bool = True,
|
||||
fft_overdrive: bool = False,
|
||||
dither: float = 0.0,
|
||||
input_scale_factor: float = 1.0,
|
||||
mel_floor: float = 1e-3,
|
||||
per_bin_mean: Sequence[float] | None = None,
|
||||
per_bin_stddev: Sequence[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
feature_size=feature_size,
|
||||
sampling_rate=sampling_rate,
|
||||
padding_value=padding_value,
|
||||
return_attention_mask=return_attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.min_frequency = min_frequency
|
||||
self.max_frequency = max_frequency
|
||||
self.preemphasis = preemphasis
|
||||
self.preemphasis_htk_flavor = preemphasis_htk_flavor
|
||||
self.fft_overdrive = fft_overdrive
|
||||
self.dither = dither
|
||||
self.input_scale_factor = input_scale_factor
|
||||
self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
|
||||
self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))
|
||||
self.mel_floor = np.array(mel_floor, dtype=np.float64)
|
||||
|
||||
fft_length = 2 ** math.ceil(math.log2(self.frame_length))
|
||||
if self.fft_overdrive:
|
||||
fft_length *= 2
|
||||
self.fft_length = fft_length
|
||||
|
||||
# Use periodic Hann window, matching sl.STFT default (signal.hann_window)
|
||||
# For even frame_length: window[n] = 0.5 - 0.5 * cos(2*pi*n / frame_length)
|
||||
self.window = window_function(self.frame_length).astype(np.float32)
|
||||
|
||||
# Use HuggingFace's mel_filter_bank for compatibility.
|
||||
# Suppress the expected warning about all-zero upper mel filters;
|
||||
# with fft_length=512 (257 bins) and 128 mel filters the uppermost
|
||||
# triangular filter falls between frequency bins, which is harmless.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
self.mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=self.fft_length // 2 + 1,
|
||||
num_mel_filters=feature_size,
|
||||
min_frequency=min_frequency,
|
||||
max_frequency=max_frequency,
|
||||
sampling_rate=self.sampling_rate,
|
||||
norm=None,
|
||||
mel_scale="htk",
|
||||
)
|
||||
|
||||
if per_bin_mean is not None:
|
||||
self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size)
|
||||
else:
|
||||
self.per_bin_mean = None
|
||||
|
||||
if per_bin_stddev is not None:
|
||||
self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size)
|
||||
else:
|
||||
self.per_bin_stddev = None
|
||||
|
||||
def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
""""""
|
||||
if waveform.ndim == 1: # If single waveform, add batch dimension
|
||||
waveform = np.expand_dims(waveform, axis=0)
|
||||
|
||||
if self.dither > 0.0:
|
||||
waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype)
|
||||
|
||||
if self.input_scale_factor != 1.0:
|
||||
waveform = waveform * self.input_scale_factor
|
||||
|
||||
# Semicausal time padding: prepend frame_length // 2 zeros so that the
|
||||
# first STFT frame is centered at t=0, matching sl.STFT(time_padding='semicausal').
|
||||
pad_left = self.frame_length // 2
|
||||
waveform = np.pad(waveform, ((0, 0), (pad_left, 0)), mode="constant")
|
||||
attention_mask = np.pad(attention_mask, (pad_left, 0), mode="constant", constant_values=0)
|
||||
|
||||
frame_size_for_unfold = self.frame_length + 1
|
||||
|
||||
# NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold]
|
||||
frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length)
|
||||
|
||||
if self.preemphasis > 0.0:
|
||||
if self.preemphasis_htk_flavor:
|
||||
first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis)
|
||||
rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2]
|
||||
frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
|
||||
else:
|
||||
frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1]
|
||||
else:
|
||||
frames = frames_to_process[..., :-1]
|
||||
|
||||
# Apply window, then RFFT. np.fft.rfft with n=fft_length implicitly
|
||||
# right-pads frames to fft_length.
|
||||
frames = frames * self.window # Broadcasting window
|
||||
stft = np.fft.rfft(frames, n=self.fft_length, axis=-1)
|
||||
|
||||
magnitude_spec = np.abs(stft)
|
||||
|
||||
mel_spec = np.matmul(magnitude_spec, self.mel_filters)
|
||||
log_mel_spec = np.log(mel_spec + self.mel_floor)
|
||||
|
||||
if self.per_bin_mean is not None:
|
||||
log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting
|
||||
|
||||
if self.per_bin_stddev is not None:
|
||||
log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
|
||||
|
||||
mel_spectrogram = log_mel_spec.squeeze(0)
|
||||
num_mel_frames = mel_spectrogram.shape[0]
|
||||
|
||||
# Build a frame-aware mask: a mel frame is valid only when every sample
|
||||
# in its analysis window [i*hop, i*hop + frame_size - 1] is real audio.
|
||||
# We check this by looking at the last sample of each frame's window.
|
||||
frame_end_indices = np.arange(num_mel_frames) * self.hop_length + frame_size_for_unfold - 1
|
||||
mask = attention_mask[frame_end_indices].astype(bool)
|
||||
return mel_spectrogram, mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
|
||||
padding: bool | str | PaddingStrategy = "longest",
|
||||
max_length: int | None = 480_000,
|
||||
truncation: bool = True,
|
||||
pad_to_multiple_of: int | None = 128,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
return_attention_mask: bool | None = True,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""Creates a batch of MEL spectrograms from the provided raw speech.
|
||||
|
||||
This implementation uses a different algorithm for windowing and preemphasis compared to the built-in
|
||||
`transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this
|
||||
carefully when selecting an audio feature extractor, especially with pre-trained models.
|
||||
|
||||
Args:
|
||||
raw_speech:
|
||||
The audio for which MEL spectrograms are created.
|
||||
padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`):
|
||||
The padding strategy to use for batches of audio with different lengths.
|
||||
max_length (`int`, *optional*, defaults to 480000):
|
||||
If provided, defines the maximum length of the audio to allow. Audio longer than this will be
|
||||
truncated if `truncation=True`.
|
||||
truncation (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to truncate audio above `max_length`.
|
||||
pad_to_multiple_of (`int`, *optional*, defaults to 128):
|
||||
When padding, pad to a multiple of this value. The default value is defined for optimal TPU support.
|
||||
return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`):
|
||||
The type of tensors to return (e.g., NumPy, or Torch).
|
||||
return_attention_mask (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the attention mask for the generated MEL spectrograms.
|
||||
"""
|
||||
|
||||
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||
is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
|
||||
is_batched = is_batched_numpy or is_batched_sequence
|
||||
|
||||
if is_batched:
|
||||
raw_speech = [np.asarray([rs]).T for rs in raw_speech]
|
||||
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
||||
raw_speech = np.asarray(raw_speech)
|
||||
|
||||
if not is_batched: # always return a batch
|
||||
raw_speech = [np.asarray([raw_speech])]
|
||||
|
||||
batched_speech = self.pad(
|
||||
BatchFeature({"input_features": raw_speech}),
|
||||
padding=padding,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=return_attention_mask,
|
||||
)
|
||||
|
||||
prepared_speech = []
|
||||
prepared_speech_mask = []
|
||||
for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask):
|
||||
speech, mask = self._extract_spectrogram(speech.T, mask)
|
||||
prepared_speech.append(speech.astype(np.float32))
|
||||
prepared_speech_mask.append(mask)
|
||||
|
||||
prepared_speech = [speech * mask[..., None] for speech, mask in zip(prepared_speech, prepared_speech_mask)]
|
||||
|
||||
return BatchFeature(
|
||||
{"input_features": prepared_speech, "input_features_mask": prepared_speech_mask},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Gemma4AudioFeatureExtractor"]
|
||||
@@ -0,0 +1,220 @@
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
|
||||
from ...image_processing_backends import TorchvisionBackend
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, PILImageResampling
|
||||
from ...processing_utils import ImagesKwargs, Unpack
|
||||
from ...utils import TensorType, auto_docstring, logging
|
||||
from .image_processing_pil_gemma4 import _SUPPORTED_SOFT_TOKENS, get_aspect_ratio_preserving_size
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from transformers.models.siglip2.image_processing_siglip2.convert_image_to_patches
|
||||
def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor":
|
||||
"""
|
||||
Convert 3D tensor image of shape (num_channels, image_height, image_width) into 2D tensor of patches of shape
|
||||
(num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
|
||||
"""
|
||||
num_channels, image_height, image_width = image.shape
|
||||
num_patches_height = image_height // patch_size
|
||||
num_patches_width = image_width // patch_size
|
||||
patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size)
|
||||
patched_image = patched_image.permute(1, 3, 2, 4, 0)
|
||||
patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
|
||||
return patched_image
|
||||
|
||||
|
||||
# Adopted from Siglip2 (mask -> position ids)
|
||||
def pad_along_first_dim(
|
||||
image: "torch.Tensor", positions: "torch.Tensor", target_length: int
|
||||
) -> tuple["torch.Tensor", "torch.Tensor"]:
|
||||
"""
|
||||
Pad the tensor along the first dimension.
|
||||
"""
|
||||
current_length = image.shape[0]
|
||||
padding_length = target_length - current_length
|
||||
if padding_length > 0:
|
||||
padding = [0, 0] * (image.ndim - 1) + [0, padding_length]
|
||||
pos_padding = (0, 0, 0, padding_length)
|
||||
image = torch.nn.functional.pad(image, padding, mode="constant", value=0)
|
||||
positions = torch.nn.functional.pad(positions, pos_padding, mode="constant", value=-1)
|
||||
return image, positions
|
||||
|
||||
|
||||
class Gemma4ImageProcessorKwargs(ImagesKwargs, total=False):
|
||||
"""
|
||||
patch_size (`int`, *optional*):
|
||||
Size of each image patch in pixels.
|
||||
max_soft_tokens (`int`, *optional*):
|
||||
Maximum number of soft (vision) tokens per image.
|
||||
Must be one of {70, 140, 280, 560, 1120}.
|
||||
pooling_kernel_size (`int`, *optional*):
|
||||
Spatial pooling kernel size applied after patchification.
|
||||
"""
|
||||
|
||||
patch_size: int
|
||||
max_soft_tokens: int
|
||||
pooling_kernel_size: int
|
||||
|
||||
|
||||
@auto_docstring(custom_intro="Constructs a Gemma4 image processor.")
|
||||
class Gemma4ImageProcessor(TorchvisionBackend):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = [0.0, 0.0, 0.0]
|
||||
image_std = [1.0, 1.0, 1.0]
|
||||
size = None
|
||||
default_to_square = True
|
||||
do_convert_rgb = True
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = False
|
||||
patch_size = 16
|
||||
max_soft_tokens = 280
|
||||
pooling_kernel_size = 3
|
||||
valid_kwargs = Gemma4ImageProcessorKwargs
|
||||
model_input_names = ["pixel_values", "image_position_ids", "num_soft_tokens_per_image"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Gemma4ImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
|
||||
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.")
|
||||
|
||||
def _validate_preprocess_kwargs(self, **kwargs):
|
||||
# Gemma4 uses aspect_ratio_preserving_resize driven by patch_size,
|
||||
# max_soft_tokens, and pooling_kernel_size — not the standard `size`
|
||||
# parameter. Temporarily disable do_resize so the base validation
|
||||
# doesn't require `size` to be set.
|
||||
kwargs["do_resize"] = False
|
||||
super()._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
def aspect_ratio_preserving_resize(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
patch_size: int,
|
||||
max_patches: int,
|
||||
pooling_kernel_size: int,
|
||||
resample: F.InterpolationMode,
|
||||
) -> torch.Tensor:
|
||||
height, width = image.shape[-2], image.shape[-1]
|
||||
target_height, target_width = get_aspect_ratio_preserving_size(
|
||||
height=height,
|
||||
width=width,
|
||||
patch_size=patch_size,
|
||||
max_patches=max_patches,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
)
|
||||
|
||||
if target_height == height and target_width == width:
|
||||
return image
|
||||
|
||||
return F.resize(
|
||||
image,
|
||||
size=[target_height, target_width],
|
||||
interpolation=resample,
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
**kwargs: Unpack[Gemma4ImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
resample: "PILImageResampling | F.InterpolationMode | int | None",
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
return_tensors: str | TensorType | None,
|
||||
patch_size: int | None = None,
|
||||
max_soft_tokens: int | None = None,
|
||||
pooling_kernel_size: int | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
|
||||
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.")
|
||||
|
||||
# Compute max_patches from max_soft_tokens and pooling_kernel_size
|
||||
max_patches = max_soft_tokens * pooling_kernel_size**2
|
||||
|
||||
# Process each image individually: resize, rescale/normalize, patchify, pad.
|
||||
# Images have different aspect ratios and thus different resized dimensions,
|
||||
# so patchification and padding must happen per-image before stacking.
|
||||
pixel_values = []
|
||||
position_ids = []
|
||||
num_soft_tokens_per_image = []
|
||||
|
||||
for image in images:
|
||||
# Step 1: Aspect-ratio-preserving resize
|
||||
if do_resize:
|
||||
image = self.aspect_ratio_preserving_resize(
|
||||
image=image,
|
||||
patch_size=patch_size,
|
||||
max_patches=max_patches,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
# Step 2: Rescale pixel values (typically to [0, 1]) and optionally identity normalize
|
||||
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
||||
|
||||
# Step 3: Patchify the image
|
||||
# (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels)
|
||||
patch_height = image.shape[-2] // patch_size
|
||||
patch_width = image.shape[-1] // patch_size
|
||||
patches = convert_image_to_patches(image, patch_size)
|
||||
num_soft_tokens_per_image.append(patches.shape[0] // pooling_kernel_size**2)
|
||||
|
||||
# Step 5: Compute position IDs
|
||||
device = image.device
|
||||
patch_grid = torch.meshgrid(
|
||||
torch.arange(patch_width, device=device),
|
||||
torch.arange(patch_height, device=device),
|
||||
indexing="xy",
|
||||
)
|
||||
stacked_grid = torch.stack(patch_grid, dim=-1)
|
||||
real_positions = stacked_grid.reshape(patches.shape[0], 2)
|
||||
|
||||
# Step 6. Pad pacthes and positions to `max_patches`
|
||||
patches, positions = pad_along_first_dim(patches, real_positions, max_patches)
|
||||
pixel_values.append(patches)
|
||||
position_ids.append(positions)
|
||||
|
||||
# Stack into batch tensors
|
||||
pixel_values = torch.stack(pixel_values, dim=0) # (batch, max_patches, patch_pixels)
|
||||
position_ids = torch.stack(position_ids, dim=0) # (batch, max_patches, 2)
|
||||
|
||||
data = {
|
||||
"pixel_values": pixel_values,
|
||||
"image_position_ids": position_ids,
|
||||
"num_soft_tokens_per_image": num_soft_tokens_per_image,
|
||||
}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["Gemma4ImageProcessor"]
|
||||
@@ -0,0 +1,278 @@
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_backends import PilBackend
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_transforms import resize
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ImagesKwargs, Unpack
|
||||
from ...utils import TensorType, auto_docstring, is_vision_available, logging
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_SUPPORTED_SOFT_TOKENS = (70, 140, 280, 560, 1120)
|
||||
|
||||
|
||||
def get_aspect_ratio_preserving_size(
|
||||
height: int,
|
||||
width: int,
|
||||
patch_size: int,
|
||||
max_patches: int,
|
||||
pooling_kernel_size: int,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Image is resized to preserve aspect ratio so it fits within the patch budget.
|
||||
Target dimensions are the largest that:
|
||||
1) Produce at most `max_patches` patches when patchified with `patch_size`
|
||||
2) Have height and width divisible by `pooling_kernel_size * patch_size`
|
||||
"""
|
||||
total_px = height * width
|
||||
target_px = max_patches * (patch_size**2)
|
||||
factor = math.sqrt(target_px / total_px)
|
||||
ideal_height = factor * height
|
||||
ideal_width = factor * width
|
||||
side_mult = pooling_kernel_size * patch_size
|
||||
|
||||
# Round down to nearest multiple of side_mult
|
||||
target_height = int(math.floor(ideal_height / side_mult)) * side_mult
|
||||
target_width = int(math.floor(ideal_width / side_mult)) * side_mult
|
||||
|
||||
# Handle edge cases where one or both dimensions round to 0
|
||||
if target_height == 0 and target_width == 0:
|
||||
raise ValueError(
|
||||
"Attempting to resize to a 0 x 0 image. Resized height should be divisble by "
|
||||
f"`pooling_kernel_size * patch_size`={pooling_kernel_size * patch_size}."
|
||||
)
|
||||
|
||||
max_side_length = (max_patches // pooling_kernel_size**2) * side_mult
|
||||
if target_height == 0:
|
||||
target_height = side_mult
|
||||
target_width = min(
|
||||
int(math.floor(width / height)) * side_mult,
|
||||
max_side_length,
|
||||
)
|
||||
elif target_width == 0:
|
||||
target_width = side_mult
|
||||
target_height = min(
|
||||
int(math.floor(height / width)) * side_mult,
|
||||
max_side_length,
|
||||
)
|
||||
|
||||
if target_height * target_width > target_px:
|
||||
raise ValueError(
|
||||
f"Resizing [{height}x{width}] to [{target_height}x{target_width}] "
|
||||
f"but this exceeds {max_patches} patches with patch_size {patch_size}"
|
||||
)
|
||||
|
||||
return target_height, target_width
|
||||
|
||||
|
||||
# Copied from transformers.models.siglip2.image_processing_pil_siglip2.convert_image_to_patches
|
||||
def convert_image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
|
||||
"""
|
||||
Convert 3D array image of shape (num_channels, image_height, image_width) into 2D array of patches of shape
|
||||
(num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
|
||||
"""
|
||||
num_channels, image_height, image_width = image.shape
|
||||
num_patches_height = image_height // patch_size
|
||||
num_patches_width = image_width // patch_size
|
||||
patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size)
|
||||
patched_image = patched_image.transpose(1, 3, 2, 4, 0)
|
||||
patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
|
||||
return patched_image
|
||||
|
||||
|
||||
# Adopted from Siglip2 (mask -> position ids)
|
||||
def pad_along_first_dim(image: np.ndarray, positions: np.ndarray, target_length: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Pad the image along the first dimension.
|
||||
"""
|
||||
current_length = image.shape[0]
|
||||
padding_length = target_length - current_length
|
||||
if padding_length > 0:
|
||||
paddings = [(0, padding_length)] + [(0, 0)] * (image.ndim - 1)
|
||||
pos_paddings = [(0, padding_length), (0, 0)]
|
||||
image = np.pad(image, paddings, mode="constant", constant_values=0)
|
||||
positions = np.pad(positions, pos_paddings, mode="constant", constant_values=-1)
|
||||
return image, positions
|
||||
|
||||
|
||||
class Gemma4ImageProcessorKwargs(ImagesKwargs, total=False):
|
||||
"""
|
||||
patch_size (`int`, *optional*):
|
||||
Size of each image patch in pixels.
|
||||
max_soft_tokens (`int`, *optional*):
|
||||
Maximum number of soft (vision) tokens per image.
|
||||
Must be one of {70, 140, 280, 560, 1120}.
|
||||
pooling_kernel_size (`int`, *optional*):
|
||||
Spatial pooling kernel size applied after patchification.
|
||||
"""
|
||||
|
||||
patch_size: int
|
||||
max_soft_tokens: int
|
||||
pooling_kernel_size: int
|
||||
|
||||
|
||||
@auto_docstring(custom_intro="Constructs a Gemma4 image processor.")
|
||||
class Gemma4ImageProcessorPil(PilBackend):
|
||||
valid_kwargs = Gemma4ImageProcessorKwargs
|
||||
model_input_names = ["pixel_values", "image_position_ids", "num_soft_tokens_per_image"]
|
||||
|
||||
do_resize = True
|
||||
resample = PILImageResampling.BICUBIC
|
||||
do_rescale = True
|
||||
rescale_factor = 1 / 255
|
||||
do_normalize = False
|
||||
image_mean = [0.0, 0.0, 0.0]
|
||||
image_std = [1.0, 1.0, 1.0]
|
||||
do_convert_rgb = True
|
||||
patch_size = 16
|
||||
max_soft_tokens = 280
|
||||
pooling_kernel_size = 3
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Gemma4ImageProcessorKwargs]) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
|
||||
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.")
|
||||
|
||||
def _validate_preprocess_kwargs(self, **kwargs):
|
||||
# Gemma4 uses aspect_ratio_preserving_resize driven by patch_size,
|
||||
# max_soft_tokens, and pooling_kernel_size — not the standard `size`
|
||||
# parameter. Temporarily disable do_resize so the base validation
|
||||
# doesn't require `size` to be set.
|
||||
kwargs["do_resize"] = False
|
||||
super()._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
**kwargs: Unpack[Gemma4ImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def aspect_ratio_preserving_resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
patch_size: int,
|
||||
max_patches: int,
|
||||
pooling_kernel_size: int,
|
||||
resample: PILImageResampling,
|
||||
) -> np.ndarray:
|
||||
height, width = image.shape[-2], image.shape[-1]
|
||||
target_height, target_width = get_aspect_ratio_preserving_size(
|
||||
height=height,
|
||||
width=width,
|
||||
patch_size=patch_size,
|
||||
max_patches=max_patches,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
)
|
||||
|
||||
if target_height == height and target_width == width:
|
||||
return image
|
||||
|
||||
return resize(
|
||||
image,
|
||||
size=(target_height, target_width),
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list[np.ndarray],
|
||||
do_resize: bool,
|
||||
resample: "PILImageResampling | int | None",
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
return_tensors: str | TensorType | None,
|
||||
max_soft_tokens: int | None = None,
|
||||
patch_size: int | None = None,
|
||||
pooling_kernel_size: int | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
|
||||
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.")
|
||||
|
||||
# Compute max_patches from max_soft_tokens and pooling_kernel_size
|
||||
max_patches = max_soft_tokens * pooling_kernel_size**2
|
||||
|
||||
# Process each image individually: resize, rescale/normalize, patchify, pad.
|
||||
# Images have different aspect ratios and thus different resized dimensions,
|
||||
# so patchification and padding must happen per-image before stacking.
|
||||
pixel_values = []
|
||||
position_ids = []
|
||||
num_soft_tokens_per_image = []
|
||||
|
||||
for image in images:
|
||||
# Step 1: Aspect-ratio-preserving resize
|
||||
if do_resize:
|
||||
image = self.aspect_ratio_preserving_resize(
|
||||
image=image,
|
||||
patch_size=patch_size,
|
||||
max_patches=max_patches,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
# Step 2: Rescale pixel values from [0, 255] to [0, 1]
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor)
|
||||
|
||||
# Step 3: Identity normalization because Gemma4 was trained with pixels in [0, 1]
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std)
|
||||
|
||||
# Step 4: Patchify the image
|
||||
# image is (C, H, W) numpy array; add batch dimension for reshape
|
||||
# (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels)
|
||||
patches = convert_image_to_patches(image, patch_size)
|
||||
num_soft_tokens_per_image.append(patches.shape[0] // pooling_kernel_size**2)
|
||||
|
||||
# Step 5: Compute position IDs
|
||||
patch_height = image.shape[-2] // patch_size
|
||||
patch_width = image.shape[-1] // patch_size
|
||||
grid_x, grid_y = np.meshgrid(np.arange(patch_width), np.arange(patch_height), indexing="xy")
|
||||
real_positions = np.stack([grid_x, grid_y], axis=-1).reshape(patches.shape[0], 2)
|
||||
|
||||
patches, positions = pad_along_first_dim(patches, real_positions, max_patches)
|
||||
|
||||
pixel_values.append(patches)
|
||||
position_ids.append(positions)
|
||||
|
||||
# Stack into batch arrays and convert to tensors
|
||||
pixel_values = np.stack(pixel_values, axis=0) # (batch, max_patches, patch_pixels)
|
||||
position_ids = np.stack(position_ids, axis=0) # (batch, max_patches, 2)
|
||||
|
||||
data = {
|
||||
"pixel_values": pixel_values,
|
||||
"image_position_ids": position_ids,
|
||||
"num_soft_tokens_per_image": num_soft_tokens_per_image,
|
||||
}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["Gemma4ImageProcessorPil"]
|
||||
@@ -0,0 +1,723 @@
|
||||
# === HEADER (license + imports) ===
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma4/modular_gemma4.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma4.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ... import initialization as init
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_experts_implementation, use_kernelized_func
|
||||
from ...masking_utils import (
|
||||
create_bidirectional_mask,
|
||||
create_causal_mask,
|
||||
create_masks_for_generate,
|
||||
create_sliding_window_causal_mask,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
TransformersKwargs,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_accelerate_available,
|
||||
torch_compilable_check,
|
||||
)
|
||||
from ...utils.generic import maybe_autocast, merge_with_config_defaults
|
||||
from ...utils.output_capturing import OutputRecorder, capture_outputs
|
||||
from ..auto.modeling_auto import AutoModel
|
||||
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import add_hook_to_module
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Gemma4 outputs, with hidden states and attentions.
|
||||
"""
|
||||
)
|
||||
class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
|
||||
r"""
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
|
||||
# === CLASS/FUNCTION OUTLINE (signatures + short body) ===
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Gemma4 outputs, with hidden states and attentions.
|
||||
"""
|
||||
)
|
||||
class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
|
||||
r"""
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
...
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Gemma4 causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
)
|
||||
class Gemma4CausalLMOutputWithPast(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
...
|
||||
|
||||
@dataclass
|
||||
@auto_docstring
|
||||
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
|
||||
r"""
|
||||
attention_mask (`torch.BoolTensor`, *optional*):
|
||||
A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
|
||||
"""
|
||||
|
||||
attention_mask: torch.BoolTensor | None = None
|
||||
|
||||
|
||||
class Gemma4ClippableLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
...
|
||||
|
||||
class Gemma4RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.with_scale = with_scale
|
||||
|
||||
if self.with_scale:
|
||||
self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)
|
||||
|
||||
def _norm(self, hidden_states: torch.Tensor):
|
||||
mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps
|
||||
# Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX
|
||||
return hidden_states * torch.pow(mean_squared, -0.5)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioRelPositionalEncoding(nn.Module):
|
||||
"""Sinusoidal relative positional encoding for the audio encoder.
|
||||
|
||||
Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
|
||||
concatenated [sin..., cos...] layout matching the original Gemma4 convention.
|
||||
"""
|
||||
|
||||
inv_timescales: torch.Tensor
|
||||
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.context_size = (
|
||||
config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
|
||||
...
|
||||
|
||||
class Gemma4AudioAttention(nn.Module):
|
||||
"""Chunked local attention with relative position bias"""
|
||||
|
||||
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.attention_logits_soft_cap = config.attention_logit_cap
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.q_scale = (self.head_dim**-0.5) / math.log(2)
|
||||
self.k_scale = math.log(1 + math.e) / math.log(2)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, norm_eps):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=1,
|
||||
bias=False,
|
||||
)
|
||||
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
|
||||
self.act = nn.ReLU()
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioSubSampleConvProjection(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
|
||||
in_channels=1,
|
||||
out_channels=config.subsampling_conv_channels[0],
|
||||
norm_eps=config.rms_norm_eps,
|
||||
)
|
||||
self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
|
||||
in_channels=config.subsampling_conv_channels[0],
|
||||
out_channels=config.subsampling_conv_channels[1],
|
||||
norm_eps=config.rms_norm_eps,
|
||||
)
|
||||
proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
|
||||
...
|
||||
|
||||
class Gemma4AudioFeedForward(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
|
||||
self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
|
||||
|
||||
self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
||||
self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
self.gradient_clipping = config.gradient_clipping
|
||||
self.post_layer_scale = config.residual_weight
|
||||
...
|
||||
|
||||
class Gemma4AudioCausalConv1d(nn.Conv1d):
|
||||
# def __init__(
|
||||
# self,
|
||||
# in_channels: int,
|
||||
# out_channels: int,
|
||||
# kernel_size: int,
|
||||
# # cache_key: str,
|
||||
# stride: int = 1,
|
||||
# dilation: int = 1,
|
||||
# bias: bool = True,
|
||||
# ):
|
||||
# super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
|
||||
# self.cache_key = cache_key
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioLightConv1d(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
|
||||
self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
|
||||
self.depthwise_conv1d = Gemma4AudioCausalConv1d(
|
||||
in_channels=config.hidden_size,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=config.conv_kernel_size,
|
||||
groups=config.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
...
|
||||
|
||||
class Gemma4AudioLayer(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.feed_forward1 = Gemma4AudioFeedForward(config)
|
||||
self.feed_forward2 = Gemma4AudioFeedForward(config)
|
||||
self.self_attn = Gemma4AudioAttention(config, layer_idx)
|
||||
self.lconv1d = Gemma4AudioLightConv1d(config)
|
||||
|
||||
self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
|
||||
self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
|
||||
self.norm_out = Gemma4RMSNorm(config.hidden_size)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4VisionPatchEmbedder(nn.Module):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.patch_size = config.patch_size
|
||||
self.position_embedding_size = config.position_embedding_size
|
||||
|
||||
self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
|
||||
self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
|
||||
|
||||
def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
|
||||
"""Prepare patch positions map for matmul with positon embedding table."""
|
||||
# Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
|
||||
...
|
||||
|
||||
class Gemma4VisionPooler(nn.Module):
|
||||
"""Scaling and optional spatial pooling for vision encodings"""
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.root_hidden_size = self.hidden_size**0.5
|
||||
|
||||
def _avg_pool_by_positions(
|
||||
self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
2D spatial pooling according to patch positions.
|
||||
Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
|
||||
...
|
||||
|
||||
class Gemma4VisionMLP(nn.Module):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
||||
self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
||||
self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
...
|
||||
|
||||
class Gemma4VisionRotaryEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig, device=None):
|
||||
super().__init__()
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
|
||||
self.rope_type = self.config.rope_parameters["rope_type"]
|
||||
rope_init_fn: Callable = self.compute_default_rope_parameters
|
||||
if self.rope_type != "default":
|
||||
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
...
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): The tensor to embed.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
...
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
...
|
||||
|
||||
def apply_multidimensional_rope(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
unsqueeze_dim: int = 2,
|
||||
) -> torch.Tensor:
|
||||
"""Applies multidimensional RoPE to inputs.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): The tensor to embed.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
...
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class Gemma4VisionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = 1.0
|
||||
self.attention_dropout = self.config.attention_dropout
|
||||
self.is_causal = False
|
||||
...
|
||||
|
||||
class Gemma4VisionEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma4VisionMLP(config)
|
||||
self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
...
|
||||
|
||||
class Gemma4VisionEncoder(nn.Module):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
...
|
||||
|
||||
class Gemma4TextMLP(nn.Module):
|
||||
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
|
||||
is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
|
||||
use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
...
|
||||
|
||||
class Gemma4TextRotaryEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
|
||||
super().__init__()
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.layer_types = set(config.layer_types)
|
||||
self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
|
||||
self.rope_type: dict[str, str] = {}
|
||||
|
||||
for layer_type in self.layer_types:
|
||||
...
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class Gemma4TextAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.is_sliding = self.layer_type == "sliding_attention"
|
||||
self.sliding_window = config.sliding_window if self.is_sliding else None
|
||||
|
||||
self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
|
||||
self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
|
||||
...
|
||||
|
||||
@use_experts_implementation
|
||||
class Gemma4TextExperts(nn.Module):
|
||||
"""Collection of expert weights stored as 3D tensors."""
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.intermediate_dim = config.moe_intermediate_size
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
|
||||
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
def forward(
|
||||
...
|
||||
|
||||
class Gemma4TextRouter(nn.Module):
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.scalar_root_size = self.hidden_size**-0.5
|
||||
self.eps = config.rms_norm_eps
|
||||
|
||||
self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False)
|
||||
self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
||||
self.scale = nn.Parameter(torch.ones(self.hidden_size))
|
||||
self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
...
|
||||
|
||||
class Gemma4TextDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma4TextMLP(config, layer_idx)
|
||||
self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.register_buffer("layer_scalar", torch.ones(1))
|
||||
|
||||
...
|
||||
|
||||
class Gemma4TextScaledWordEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
||||
self.scalar_embed_scale = embed_scale
|
||||
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor):
|
||||
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
||||
|
||||
|
||||
...
|
||||
|
||||
@auto_docstring
|
||||
class Gemma4PreTrainedModel(PreTrainedModel):
|
||||
config: Gemma4Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = None # override
|
||||
...
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
|
||||
class Gemma4TextModel(Gemma4PreTrainedModel):
|
||||
config: Gemma4TextConfig
|
||||
input_modalities = ("text",)
|
||||
_can_record_outputs = {
|
||||
"router_logits": OutputRecorder(Gemma4TextRouter, index=0),
|
||||
"hidden_states": Gemma4TextDecoderLayer,
|
||||
"attentions": Gemma4TextAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
...
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
|
||||
class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_gather_output"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
config: Gemma4TextConfig
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__(config)
|
||||
self.model = Gemma4TextModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
# Grab the ones from the child
|
||||
...
|
||||
|
||||
def sliding_window_mask_function(sliding_window: tuple[int, int]) -> Callable:
|
||||
"""
|
||||
This creates uni/bidirectional attention mask with sliding window.
|
||||
"""
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
left_window_size, right_window_size = sliding_window
|
||||
|
||||
dist = q_idx - kv_idx
|
||||
left_mask = (dist >= 0) & (dist < left_window_size)
|
||||
right_mask = (dist < 0) & (-dist < right_window_size)
|
||||
return left_mask | right_mask
|
||||
|
||||
return inner_mask
|
||||
...
|
||||
|
||||
class Gemma4AudioModel(Gemma4PreTrainedModel):
|
||||
"""An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
|
||||
|
||||
config: Gemma4AudioConfig
|
||||
main_input_name = "input_features"
|
||||
base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
|
||||
_can_record_outputs = {
|
||||
"hidden_states": Gemma4AudioLayer,
|
||||
"attentions": Gemma4AudioAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
...
|
||||
|
||||
class Gemma4VisionModel(Gemma4PreTrainedModel):
|
||||
"""The Gemma 4 Vision Encoder."""
|
||||
|
||||
config = Gemma4VisionConfig
|
||||
_can_record_outputs = {
|
||||
"hidden_states": Gemma4VisionEncoderLayer,
|
||||
"attentions": Gemma4VisionAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__(config)
|
||||
self.patch_embedder = Gemma4VisionPatchEmbedder(config)
|
||||
self.encoder = Gemma4VisionEncoder(config)
|
||||
self.pooler = Gemma4VisionPooler(config)
|
||||
...
|
||||
|
||||
class Gemma4MultimodalEmbedder(nn.Module):
|
||||
"""Embeds token ids or soft tokens for multimodal content into language model space."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
|
||||
text_config: Gemma4TextConfig,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size)
|
||||
self.eps = multimodal_config.rms_norm_eps
|
||||
self.text_hidden_size = text_config.hidden_size
|
||||
self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
|
||||
...
|
||||
|
||||
def token_type_ids_mask_function(
|
||||
token_type_ids: torch.Tensor | None,
|
||||
image_group_ids: torch.Tensor | None,
|
||||
) -> Callable | None:
|
||||
"""
|
||||
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
||||
not start and end indices.
|
||||
"""
|
||||
# Do not return an additional mask in this case
|
||||
if token_type_ids is None:
|
||||
return None
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
seq_length = image_group_ids.shape[-1]
|
||||
...
|
||||
|
||||
def create_causal_mask_mapping(
|
||||
config: PreTrainedConfig,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
past_key_values: Cache | None,
|
||||
position_ids: torch.Tensor | None,
|
||||
mm_token_type_ids: torch.Tensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
is_training: bool = False,
|
||||
is_first_iteration: bool | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
|
||||
...
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
|
||||
language modeling head.
|
||||
"""
|
||||
)
|
||||
class Gemma4Model(Gemma4PreTrainedModel):
|
||||
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
||||
accepts_loss_kwargs = False
|
||||
|
||||
def __init__(self, config: Gemma4Config):
|
||||
super().__init__(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
|
||||
...
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
|
||||
head.
|
||||
"""
|
||||
)
|
||||
class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
accepts_loss_kwargs = False
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma4Config):
|
||||
super().__init__(config)
|
||||
self.model = Gemma4Model(config)
|
||||
...
|
||||
|
||||
@@ -0,0 +1,563 @@
|
||||
# === HEADER (license + imports) ===
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ... import initialization as init
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...integrations import use_kernelized_func
|
||||
from ...masking_utils import (
|
||||
create_bidirectional_mask,
|
||||
create_causal_mask,
|
||||
create_masks_for_generate,
|
||||
create_sliding_window_causal_mask,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TransformersKwargs,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
torch_compilable_check,
|
||||
)
|
||||
from ...utils.generic import maybe_autocast, merge_with_config_defaults
|
||||
from ...utils.output_capturing import OutputRecorder, capture_outputs
|
||||
from ..auto.modeling_auto import AutoModel
|
||||
from ..gemma3.modeling_gemma3 import (
|
||||
Gemma3Attention,
|
||||
Gemma3DecoderLayer,
|
||||
Gemma3ForCausalLM,
|
||||
Gemma3MLP,
|
||||
Gemma3RotaryEmbedding,
|
||||
Gemma3TextModel,
|
||||
Gemma3TextScaledWordEmbedding,
|
||||
)
|
||||
from ..gemma3n.modeling_gemma3n import (
|
||||
Gemma3nCausalLMOutputWithPast,
|
||||
Gemma3nForConditionalGeneration,
|
||||
Gemma3nModel,
|
||||
Gemma3nModelOutputWithPast,
|
||||
Gemma3nMultimodalEmbedder,
|
||||
Gemma3nPreTrainedModel,
|
||||
Gemma3nRMSNorm,
|
||||
apply_rotary_pos_emb,
|
||||
eager_attention_forward,
|
||||
)
|
||||
from ..llama.modeling_llama import LlamaRotaryEmbedding
|
||||
from ..mixtral.modeling_mixtral import MixtralExperts
|
||||
from ..moonshine_streaming.modeling_moonshine_streaming import sliding_window_mask_function
|
||||
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
pass
|
||||
|
||||
|
||||
|
||||
# === CLASS/FUNCTION OUTLINE (signatures + short body) ===
|
||||
class Gemma4ModelOutputWithPast(Gemma3nModelOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma4CausalLMOutputWithPast(Gemma3nCausalLMOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring
|
||||
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
|
||||
r"""
|
||||
attention_mask (`torch.BoolTensor`, *optional*):
|
||||
A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
|
||||
...
|
||||
|
||||
class Gemma4ClippableLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma4VisionConfig | Gemma4AudioConfig,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.use_clipped_linears = config.use_clipped_linears
|
||||
self.linear = nn.Linear(in_features, out_features, bias=False)
|
||||
|
||||
if self.use_clipped_linears:
|
||||
self.register_buffer("input_min", torch.tensor(-float("inf")))
|
||||
self.register_buffer("input_max", torch.tensor(float("inf")))
|
||||
...
|
||||
|
||||
class Gemma4RMSNorm(Gemma3nRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma4AudioRelPositionalEncoding(nn.Module):
|
||||
"""Sinusoidal relative positional encoding for the audio encoder.
|
||||
|
||||
Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
|
||||
concatenated [sin..., cos...] layout matching the original Gemma4 convention.
|
||||
"""
|
||||
|
||||
inv_timescales: torch.Tensor
|
||||
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
...
|
||||
|
||||
class Gemma4AudioAttention(nn.Module):
|
||||
"""Chunked local attention with relative position bias"""
|
||||
|
||||
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.attention_logits_soft_cap = config.attention_logit_cap
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.q_scale = (self.head_dim**-0.5) / math.log(2)
|
||||
self.k_scale = math.log(1 + math.e) / math.log(2)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, norm_eps):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=1,
|
||||
bias=False,
|
||||
)
|
||||
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
|
||||
self.act = nn.ReLU()
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioSubSampleConvProjection(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
|
||||
in_channels=1,
|
||||
out_channels=config.subsampling_conv_channels[0],
|
||||
norm_eps=config.rms_norm_eps,
|
||||
)
|
||||
self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
|
||||
in_channels=config.subsampling_conv_channels[0],
|
||||
out_channels=config.subsampling_conv_channels[1],
|
||||
norm_eps=config.rms_norm_eps,
|
||||
)
|
||||
proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
|
||||
...
|
||||
|
||||
class Gemma4AudioFeedForward(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
|
||||
self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
|
||||
|
||||
self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
||||
self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
self.gradient_clipping = config.gradient_clipping
|
||||
self.post_layer_scale = config.residual_weight
|
||||
...
|
||||
|
||||
class Gemma4AudioCausalConv1d(nn.Conv1d):
|
||||
# def __init__(
|
||||
# self,
|
||||
# in_channels: int,
|
||||
# out_channels: int,
|
||||
# kernel_size: int,
|
||||
# # cache_key: str,
|
||||
# stride: int = 1,
|
||||
# dilation: int = 1,
|
||||
# bias: bool = True,
|
||||
# ):
|
||||
# super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
|
||||
# self.cache_key = cache_key
|
||||
|
||||
...
|
||||
|
||||
class Gemma4AudioLightConv1d(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
|
||||
self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
|
||||
self.depthwise_conv1d = Gemma4AudioCausalConv1d(
|
||||
in_channels=config.hidden_size,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=config.conv_kernel_size,
|
||||
groups=config.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
...
|
||||
|
||||
class Gemma4AudioLayer(nn.Module):
|
||||
def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.feed_forward1 = Gemma4AudioFeedForward(config)
|
||||
self.feed_forward2 = Gemma4AudioFeedForward(config)
|
||||
self.self_attn = Gemma4AudioAttention(config, layer_idx)
|
||||
self.lconv1d = Gemma4AudioLightConv1d(config)
|
||||
|
||||
self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
|
||||
self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
|
||||
self.norm_out = Gemma4RMSNorm(config.hidden_size)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4VisionPatchEmbedder(nn.Module):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.patch_size = config.patch_size
|
||||
self.position_embedding_size = config.position_embedding_size
|
||||
|
||||
self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
|
||||
self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
|
||||
|
||||
def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
|
||||
"""Prepare patch positions map for matmul with positon embedding table."""
|
||||
# Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
|
||||
...
|
||||
|
||||
class Gemma4VisionPooler(nn.Module):
|
||||
"""Scaling and optional spatial pooling for vision encodings"""
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.root_hidden_size = self.hidden_size**0.5
|
||||
|
||||
def _avg_pool_by_positions(
|
||||
self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
2D spatial pooling according to patch positions.
|
||||
Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
|
||||
...
|
||||
|
||||
class Gemma4VisionMLP(Gemma3MLP):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__(self, config)
|
||||
self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
||||
self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
|
||||
self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
|
||||
|
||||
|
||||
def apply_multidimensional_rope(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
unsqueeze_dim: int = 2,
|
||||
...
|
||||
|
||||
class Gemma4VisionRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
@staticmethod
|
||||
def compute_default_rope_parameters(
|
||||
config: Gemma4VisionConfig | None = None,
|
||||
device: torch.device | None = None,
|
||||
seq_len: int | None = None,
|
||||
) -> tuple["torch.Tensor", float]:
|
||||
"""
|
||||
Computes the inverse frequencies according to the original RoPE implementation
|
||||
Args:
|
||||
config ([`~transformers.PreTrainedConfig`]):
|
||||
The model configuration.
|
||||
device (`torch.device`):
|
||||
The device to use for initialization of the inverse frequencies.
|
||||
...
|
||||
|
||||
class Gemma4VisionAttention(Gemma3Attention):
|
||||
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__(self, config, layer_idx)
|
||||
del self.attn_logit_softcapping
|
||||
del self.sliding_window
|
||||
del self.is_sliding
|
||||
self.scaling = 1.0
|
||||
self.is_causal = False
|
||||
self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
|
||||
self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
|
||||
self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
|
||||
self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
|
||||
self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4VisionEncoderLayer(Gemma3DecoderLayer):
|
||||
def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__(self, config, layer_idx)
|
||||
self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma4VisionMLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
||||
...
|
||||
|
||||
class Gemma4VisionEncoder(nn.Module):
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
...
|
||||
|
||||
class Gemma4TextMLP(Gemma3MLP):
|
||||
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
|
||||
first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
|
||||
is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
|
||||
use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
|
||||
super().__init__()
|
||||
self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
|
||||
|
||||
|
||||
class Gemma4TextRotaryEmbedding(Gemma3RotaryEmbedding):
|
||||
def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
|
||||
nn.Module.__init__(self)
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
...
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class Gemma4TextAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.is_sliding = self.layer_type == "sliding_attention"
|
||||
self.sliding_window = config.sliding_window if self.is_sliding else None
|
||||
|
||||
self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
|
||||
self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
|
||||
...
|
||||
|
||||
class Gemma4TextExperts(MixtralExperts):
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.intermediate_dim = config.moe_intermediate_size
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
|
||||
|
||||
class Gemma4TextRouter(nn.Module):
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.scalar_root_size = self.hidden_size**-0.5
|
||||
...
|
||||
|
||||
class Gemma4TextDecoderLayer(Gemma3DecoderLayer):
|
||||
def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma4TextMLP(config, layer_idx)
|
||||
self.register_buffer("layer_scalar", torch.ones(1))
|
||||
|
||||
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
|
||||
if self.hidden_size_per_layer_input:
|
||||
self.act_fn = ACT2FN[config.hidden_activation]
|
||||
self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
|
||||
self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
|
||||
self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
...
|
||||
|
||||
class Gemma4TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
# ---- Model Classes ----
|
||||
|
||||
|
||||
class Gemma4PreTrainedModel(Gemma3nPreTrainedModel):
|
||||
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
|
||||
input_modalities = ("image", "text", "video", "audio")
|
||||
_can_record_outputs = None # override
|
||||
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
|
||||
|
||||
@torch.no_grad()
|
||||
...
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
|
||||
class Gemma4TextModel(Gemma3TextModel):
|
||||
config: Gemma4TextConfig
|
||||
_can_record_outputs = {
|
||||
"router_logits": OutputRecorder(Gemma4TextRouter, index=0),
|
||||
"hidden_states": Gemma4TextDecoderLayer,
|
||||
"attentions": Gemma4TextAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
...
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
|
||||
class Gemma4ForCausalLM(Gemma3ForCausalLM):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma4TextConfig):
|
||||
super().__init__(config)
|
||||
# Grab the ones from the child
|
||||
self._keys_to_ignore_on_load_unexpected = [
|
||||
f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
|
||||
]
|
||||
|
||||
|
||||
class Gemma4AudioModel(Gemma4PreTrainedModel):
|
||||
"""An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
|
||||
...
|
||||
|
||||
class Gemma4VisionModel(Gemma4PreTrainedModel):
|
||||
"""The Gemma 4 Vision Encoder."""
|
||||
|
||||
config = Gemma4VisionConfig
|
||||
_can_record_outputs = {
|
||||
"hidden_states": Gemma4VisionEncoderLayer,
|
||||
"attentions": Gemma4VisionAttention,
|
||||
}
|
||||
|
||||
def __init__(self, config: Gemma4VisionConfig):
|
||||
super().__init__(config)
|
||||
self.patch_embedder = Gemma4VisionPatchEmbedder(config)
|
||||
self.encoder = Gemma4VisionEncoder(config)
|
||||
self.pooler = Gemma4VisionPooler(config)
|
||||
...
|
||||
|
||||
class Gemma4MultimodalEmbedder(Gemma3nMultimodalEmbedder):
|
||||
def __init__(
|
||||
self,
|
||||
multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
|
||||
text_config: Gemma4TextConfig,
|
||||
):
|
||||
# Audio tower may use a different output dimension (output_proj_dims) than the
|
||||
# internal hidden_size. Use the tower-specific dimension if specified.
|
||||
super().__init__(multimodal_config, text_config)
|
||||
del self.embedding
|
||||
del self.hard_embedding_norm
|
||||
del self.soft_embedding_norm
|
||||
del self.vocab_offset
|
||||
del self.vocab_size
|
||||
...
|
||||
|
||||
def token_type_ids_mask_function(
|
||||
token_type_ids: torch.Tensor | None,
|
||||
image_group_ids: torch.Tensor | None,
|
||||
) -> Callable | None:
|
||||
"""
|
||||
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
||||
not start and end indices.
|
||||
"""
|
||||
# Do not return an additional mask in this case
|
||||
if token_type_ids is None:
|
||||
return None
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
seq_length = image_group_ids.shape[-1]
|
||||
...
|
||||
|
||||
def create_causal_mask_mapping(
|
||||
config: PreTrainedConfig,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
past_key_values: Cache | None,
|
||||
position_ids: torch.Tensor | None,
|
||||
mm_token_type_ids: torch.Tensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
is_training: bool = False,
|
||||
is_first_iteration: bool | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
|
||||
...
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
|
||||
language modeling head.
|
||||
"""
|
||||
)
|
||||
class Gemma4Model(Gemma3nModel):
|
||||
def __init__(self, config: Gemma4Config):
|
||||
super().__init__(config)
|
||||
del self.vision_tower
|
||||
del self.embed_vision
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None
|
||||
self.embed_vision = (
|
||||
Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
|
||||
...
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
|
||||
head.
|
||||
"""
|
||||
)
|
||||
class Gemma4ForConditionalGeneration(Gemma3nForConditionalGeneration):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma4Config):
|
||||
super().__init__(config)
|
||||
# Grab the ones from the child
|
||||
self._keys_to_ignore_on_load_unexpected = [
|
||||
f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
|
||||
...
|
||||
|
||||
@@ -0,0 +1,366 @@
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...audio_utils import AudioInput
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, make_nested_list_of_images
|
||||
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import auto_docstring, is_vision_available, logging
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from .image_processing_pil_gemma4 import Gemma4ImageProcessorKwargs, get_aspect_ratio_preserving_size
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma4ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
images_kwargs: Gemma4ImageProcessorKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": True,
|
||||
"return_mm_token_type_ids": True,
|
||||
},
|
||||
"images_kwargs": {
|
||||
"do_convert_rgb": True,
|
||||
},
|
||||
"audio_kwargs": {},
|
||||
"videos_kwargs": {"return_metadata": True},
|
||||
}
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@requires(backends=("vision",))
|
||||
class Gemma4Processor(ProcessorMixin):
|
||||
def __init__(
|
||||
self,
|
||||
feature_extractor,
|
||||
image_processor,
|
||||
tokenizer,
|
||||
video_processor,
|
||||
chat_template=None,
|
||||
image_seq_length: int = 280,
|
||||
audio_seq_length: int = 750,
|
||||
audio_ms_per_token: int = 40,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
image_seq_length (`int`, *optional*, defaults to 280):
|
||||
The number of soft tokens per image used for placeholder expansion.
|
||||
audio_seq_length (`int`, *optional*, defaults to 750):
|
||||
The maximum number of audio soft tokens per audio segment. Serves as an
|
||||
upper-bound cap when dynamic audio token counts are computed.
|
||||
audio_ms_per_token (`int`, *optional*, defaults to 40):
|
||||
Milliseconds of audio per output soft token. Used to dynamically compute
|
||||
the number of audio placeholder tokens as ``ceil(duration_ms / audio_ms_per_token)``.
|
||||
The default of 40 comes from the SSCP convolution's 4× time reduction on 10ms frames.
|
||||
"""
|
||||
self.image_seq_length = image_seq_length
|
||||
self.image_token_id = tokenizer.image_token_id
|
||||
self.boi_token = tokenizer.boi_token
|
||||
self.eoi_token = tokenizer.eoi_token
|
||||
self.image_token = tokenizer.image_token
|
||||
|
||||
# FIXME: add the token to config and ask Ryan to re-upload
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<|video|>"]})
|
||||
self.video_token = "<|video|>"
|
||||
self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
|
||||
# Audio token handling, mirroring the vision pattern.
|
||||
# audio_seq_length serves as the maximum cap on the number of audio soft tokens
|
||||
# any single audio segment can produce. With dynamic audio tokens, the actual
|
||||
# number of placeholders inserted per audio is computed from the audio duration.
|
||||
self.audio_seq_length = audio_seq_length
|
||||
# Milliseconds of audio per output soft token. The default of 40 comes from the
|
||||
# SSCP convolution's 4× time reduction applied to 10ms mel spectrogram frames.
|
||||
self.audio_ms_per_token = audio_ms_per_token
|
||||
self.audio_token_id = getattr(tokenizer, "audio_token_id", None)
|
||||
self.audio_token = getattr(tokenizer, "audio_token", None)
|
||||
self.boa_token = getattr(tokenizer, "boa_token", None)
|
||||
self.eoa_token = getattr(tokenizer, "eoa_token", None)
|
||||
|
||||
super().__init__(
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
chat_template=chat_template,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@auto_docstring
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput | None = None,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
audio: AudioInput | None = None,
|
||||
videos: VideoInput | None = None,
|
||||
**kwargs: Unpack[Gemma4ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
if text is None and images is None and audio is None and videos is None:
|
||||
raise ValueError("Provide at least one of `text`, `images`, `audio`, or `videos`.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Gemma4ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
image_inputs = {}
|
||||
if images is not None:
|
||||
images = self.image_processor.fetch_images(images)
|
||||
batched_images = make_nested_list_of_images(images)
|
||||
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
|
||||
num_soft_tokens = image_inputs.pop("num_soft_tokens_per_image")
|
||||
|
||||
# Create empty text to be replaced with placeholders
|
||||
if not text:
|
||||
text = [" ".join([self.image_token] * len(images)) for images in batched_images]
|
||||
|
||||
if len(batched_images) != len(text):
|
||||
raise ValueError(
|
||||
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
|
||||
)
|
||||
|
||||
replacements = [f"{self.boi_token}{self.image_token * n}{self.eoi_token}" for n in num_soft_tokens]
|
||||
replacements_iter = iter(replacements)
|
||||
|
||||
# Expand image_token placeholders to per-image soft token sequences.
|
||||
# re.sub never re-scans replaced text, so it is safe
|
||||
pattern = re.escape(self.image_token)
|
||||
text = [re.sub(pattern, lambda _: next(replacements_iter), prompt) for prompt in text]
|
||||
|
||||
# Process video inputs in same way
|
||||
video_inputs = {}
|
||||
if videos is not None:
|
||||
video_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
num_video_tokens = video_inputs.pop("num_soft_tokens_per_video")
|
||||
|
||||
# If user has not requested video metadata, pop it so it isn't returned
|
||||
if not kwargs.get("return_metadata"):
|
||||
video_metadata = video_inputs.pop("video_metadata")
|
||||
else:
|
||||
video_metadata = video_inputs["video_metadata"]
|
||||
|
||||
video_replacements = []
|
||||
for metadata, n_tokens in zip(video_metadata, num_video_tokens):
|
||||
if metadata.fps is None:
|
||||
logger.warning_once(
|
||||
"Gemma 4 requires frame timestamps to construct prompts, but the `fps` of the input video "
|
||||
"could not be inferred. Probably `video_metadata` was missing from inputs and you passed "
|
||||
"pre-sampled frames. Defaulting to `fps=24`. Please provide `video_metadata` for more "
|
||||
"accurate results."
|
||||
)
|
||||
metadata.fps = 24 if metadata.fps is None else metadata.fps
|
||||
# mm:ss format for timestamps
|
||||
timestamp_str = [
|
||||
f"{int(seconds // 60):02d}:{int(seconds % 60):02d}" for seconds in metadata.timestamps
|
||||
]
|
||||
video_replacements.append(
|
||||
" ".join(
|
||||
[f"{t} {self.boi_token}{self.video_token * n_tokens}{self.eoi_token}" for t in timestamp_str]
|
||||
)
|
||||
)
|
||||
|
||||
video_replacements = iter(video_replacements)
|
||||
pattern = re.escape(self.video_token)
|
||||
text = [re.sub(pattern, lambda _: next(video_replacements), prompt) for prompt in text]
|
||||
|
||||
# Process audio inputs
|
||||
audio_inputs = {}
|
||||
if audio is not None:
|
||||
if self.audio_token is None or self.boa_token is None or self.eoa_token is None:
|
||||
raise ValueError(
|
||||
"Audio inputs were provided, but the tokenizer does not have an `audio_token` defined."
|
||||
)
|
||||
|
||||
# Normalize audio input to list of waveforms
|
||||
if isinstance(audio, np.ndarray) and audio.ndim == 1:
|
||||
audio = [audio]
|
||||
|
||||
# TODO: Add tests for audio-only processor inputs.
|
||||
if not text:
|
||||
text = [self.audio_token] * len(audio)
|
||||
|
||||
# Dynamic audio token expansion wihtout padding:
|
||||
# * Extract audio features with feature extractor;
|
||||
# * Compute precise per-audio token counts from the waveform duration;
|
||||
# * Generate full audio token sequence for each computed audio length;
|
||||
# * Expand text prompts with full audio token sequences.
|
||||
audio_kwargs = output_kwargs.get("audio_kwargs", {})
|
||||
audio_inputs = self.feature_extractor(audio, **audio_kwargs)
|
||||
sampling_rate = self.feature_extractor.sampling_rate
|
||||
num_audio_tokens = [self._compute_audio_num_tokens(a, sampling_rate) for a in audio]
|
||||
replacements = [f"{self.boa_token}{self.audio_token * n}{self.eoa_token}" for n in num_audio_tokens]
|
||||
replacements_iter = iter(replacements)
|
||||
audio_pattern = re.escape(self.audio_token)
|
||||
text = [re.sub(audio_pattern, lambda _: next(replacements_iter), prompt) for prompt in text]
|
||||
|
||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
||||
|
||||
# Check special tokens for all active modalities
|
||||
active_modalities = []
|
||||
if images is not None:
|
||||
active_modalities.append("image")
|
||||
if videos is not None:
|
||||
active_modalities.append("video")
|
||||
if audio is not None:
|
||||
active_modalities.append("audio")
|
||||
if active_modalities:
|
||||
self._check_special_mm_tokens(text, text_inputs, modalities=active_modalities)
|
||||
|
||||
if return_mm_token_type_ids:
|
||||
text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
|
||||
|
||||
return BatchFeature(
|
||||
data={**text_inputs, **image_inputs, **audio_inputs, **video_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def _compute_audio_num_tokens(self, audio_waveform, sampling_rate: int) -> int:
|
||||
"""Compute the number of audio soft tokens for a single waveform.
|
||||
|
||||
Replicates the exact sequence-length arithmetic of the audio encoder
|
||||
so that the processor inserts the correct number of placeholder tokens.
|
||||
The computation mirrors:
|
||||
|
||||
1. Mel framing via ``_unfold`` in ``Gemma4AudioFeatureExtractor``
|
||||
2. Two ``Conv2d`` subsampling layers in ``Gemma4AudioSubSampleConvProjection``
|
||||
(each: kernel=3, stride=2, semicausal padding top=1, bottom=1)
|
||||
|
||||
The result is capped at ``self.audio_seq_length`` (the configured maximum).
|
||||
|
||||
Args:
|
||||
audio_waveform: A 1-D numpy array or list containing the raw audio samples.
|
||||
sampling_rate: The sampling rate of the audio waveform in Hz.
|
||||
|
||||
Returns:
|
||||
The number of audio soft tokens to insert as placeholders.
|
||||
"""
|
||||
num_samples = len(audio_waveform)
|
||||
|
||||
# Step 1: Mel frames (matches feature_extraction_gemma4.py _unfold)
|
||||
frame_length = int(round(sampling_rate * 20.0 / 1000.0)) # 320 @ 16kHz
|
||||
hop_length = int(round(sampling_rate * 10.0 / 1000.0)) # 160 @ 16kHz
|
||||
frame_size_for_unfold = frame_length + 1 # 321
|
||||
|
||||
# The feature extractor prepends (frame_length // 2) zero samples as
|
||||
# semicausal time-padding before the unfold. We must include this to
|
||||
# match the actual number of mel frames it produces.
|
||||
pad_left = frame_length // 2 # 160 @ 16kHz
|
||||
padded_samples = num_samples + pad_left
|
||||
num_mel_frames = (padded_samples - frame_size_for_unfold) // hop_length + 1
|
||||
|
||||
if num_mel_frames <= 0:
|
||||
return 0
|
||||
|
||||
# Step 2: Two SSCP conv layers (kernel=3, stride=2, semicausal pad top=1, bottom=1)
|
||||
# Each layer: T_out = (T_in + pad_top + pad_bottom - kernel) // stride + 1
|
||||
t = num_mel_frames
|
||||
for _ in range(2):
|
||||
t_padded = t + 2 # pad_top=1, pad_bottom=1
|
||||
t = (t_padded - 3) // 2 + 1
|
||||
|
||||
# Cap at the configured maximum
|
||||
return min(t, self.audio_seq_length)
|
||||
|
||||
def _get_num_multimodal_tokens(self, image_sizes=None, audio_lengths=None, **kwargs):
|
||||
"""
|
||||
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||
|
||||
Args:
|
||||
image_sizes (`list[list[int]]`, *optional*):
|
||||
The input sizes formatted as (height, width) per each image.
|
||||
audio_lengths (`list[int]`, *optional*):
|
||||
The lengths of audio inputs in number of samples. Used to dynamically
|
||||
compute per-audio token counts.
|
||||
|
||||
Returns:
|
||||
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||
input modalities, along with other useful data.
|
||||
"""
|
||||
|
||||
images_kwargs = Gemma4ProcessorKwargs._defaults.get("images_kwargs", {})
|
||||
images_kwargs.update(kwargs)
|
||||
patch_size = images_kwargs.get("patch_size", None) or self.image_processor.patch_size
|
||||
pooling_kernel_size = (
|
||||
images_kwargs.get("pooling_kernel_size", None) or self.image_processor.pooling_kernel_size
|
||||
)
|
||||
max_soft_tokens = images_kwargs.get("max_soft_tokens", None) or self.image_processor.max_soft_tokens
|
||||
|
||||
max_patches = max_soft_tokens * pooling_kernel_size**2
|
||||
|
||||
vision_data = {}
|
||||
if image_sizes is not None:
|
||||
num_image_tokens = []
|
||||
for image_size in image_sizes:
|
||||
target_h, target_w = get_aspect_ratio_preserving_size(
|
||||
height=image_size[0],
|
||||
width=image_size[1],
|
||||
patch_size=patch_size,
|
||||
max_patches=max_patches,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
)
|
||||
patch_height = target_h // patch_size
|
||||
patch_width = target_w // patch_size
|
||||
num_image_tokens.append(patch_height * patch_width // pooling_kernel_size**2)
|
||||
|
||||
num_image_patches = [1] * len(image_sizes)
|
||||
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||
|
||||
if audio_lengths is not None:
|
||||
# Dynamically compute per-audio token counts from sample lengths.
|
||||
# audio_lengths are in number of samples; assume default sampling rate.
|
||||
sampling_rate = getattr(self.feature_extractor, "sampling_rate", 16_000)
|
||||
num_audio_tokens = [
|
||||
self._compute_audio_num_tokens(np.zeros(length), sampling_rate) for length in audio_lengths
|
||||
]
|
||||
vision_data.update({"num_audio_tokens": num_audio_tokens})
|
||||
|
||||
return MultiModalData(**vision_data)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
model_input_names = super().model_input_names
|
||||
model_input_names = [
|
||||
name
|
||||
for name in model_input_names
|
||||
if name not in ["num_soft_tokens_per_image", "num_soft_tokens_per_video"]
|
||||
]
|
||||
|
||||
# Include audio feature extractor input names if available
|
||||
if self.feature_extractor is not None:
|
||||
feature_extractor_input_names = self.feature_extractor.model_input_names
|
||||
model_input_names.extend([name for name in feature_extractor_input_names if name not in model_input_names])
|
||||
|
||||
return model_input_names + ["mm_token_type_ids"]
|
||||
|
||||
|
||||
__all__ = ["Gemma4Processor"]
|
||||
@@ -0,0 +1,237 @@
|
||||
# Copyright 2026 the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
)
|
||||
from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor
|
||||
from ...video_utils import VideoInput
|
||||
from .image_processing_gemma4 import _SUPPORTED_SOFT_TOKENS, get_aspect_ratio_preserving_size
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
elif is_torchvision_available():
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma4VideoProcessorKwargs(VideosKwargs, total=False):
|
||||
"""
|
||||
patch_size (`int`, *optional*):
|
||||
Size of each image patch in pixels.
|
||||
max_soft_tokens (`int`, *optional*):
|
||||
Maximum number of soft (vision) tokens per video frame.
|
||||
Must be one of {70, 140, 280, 560, 1120}.
|
||||
pooling_kernel_size (`int`, *optional*):
|
||||
Spatial pooling kernel size applied after patchification.
|
||||
"""
|
||||
|
||||
patch_size: int
|
||||
max_soft_tokens: int
|
||||
pooling_kernel_size: int
|
||||
|
||||
|
||||
def convert_video_to_patches(video: "torch.Tensor", patch_size: int) -> "torch.Tensor":
|
||||
"""
|
||||
Convert 4D tensor video of shape (num_frames, num_channels, height, width) into 3D tensor of patches of shape
|
||||
(num_frames, num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
|
||||
"""
|
||||
num_frames, num_channels, height, width = video.shape
|
||||
num_patches_height = height // patch_size
|
||||
num_patches_width = width // patch_size
|
||||
patched_video = video.reshape(
|
||||
num_frames, num_channels, num_patches_height, patch_size, num_patches_width, patch_size
|
||||
)
|
||||
patched_video = patched_video.permute(0, 2, 4, 3, 5, 1)
|
||||
patched_video = patched_video.reshape(num_frames, num_patches_height * num_patches_width, -1)
|
||||
return patched_video
|
||||
|
||||
|
||||
def pad_to_max_patches(
|
||||
video: "torch.Tensor", positions: "torch.Tensor", target_length: int
|
||||
) -> tuple["torch.Tensor", "torch.Tensor"]:
|
||||
"""
|
||||
Pad the video along to max number of patches
|
||||
"""
|
||||
current_length = video.shape[1]
|
||||
padding_length = target_length - current_length
|
||||
if padding_length > 0:
|
||||
padding = [0, 0, 0, padding_length, 0, 0]
|
||||
pos_padding = (0, 0, 0, padding_length, 0, 0)
|
||||
video = torch.nn.functional.pad(video, padding, mode="constant", value=0)
|
||||
positions = torch.nn.functional.pad(positions, pos_padding, mode="constant", value=-1)
|
||||
return video, positions
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a Gemma4 video processor that samples frames from videos for use with the Gemma4 model.",
|
||||
BASE_VIDEO_PROCESSOR_DOCSTRING,
|
||||
)
|
||||
class Gemma4VideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = [0.0, 0.0, 0.0]
|
||||
image_std = [1.0, 1.0, 1.0]
|
||||
size = None
|
||||
default_to_square = True
|
||||
do_convert_rgb = True
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
num_frames = 32
|
||||
do_sample_frames = True
|
||||
patch_size = 16
|
||||
max_soft_tokens = 70
|
||||
pooling_kernel_size = 3
|
||||
valid_kwargs = Gemma4VideoProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos", "video_position_ids"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Gemma4VideoProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
|
||||
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.")
|
||||
|
||||
def _validate_preprocess_kwargs(self, **kwargs):
|
||||
# Gemma4 uses aspect_ratio_preserving_resize driven by patch_size,
|
||||
# max_soft_tokens, and pooling_kernel_size — not the standard `size`
|
||||
# parameter. Temporarily disable do_resize so the base validation
|
||||
# doesn't require `size` to be set.
|
||||
kwargs["do_resize"] = False
|
||||
super()._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
def aspect_ratio_preserving_resize(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
patch_size: int,
|
||||
max_patches: int,
|
||||
pooling_kernel_size: int,
|
||||
resample: F.InterpolationMode,
|
||||
) -> torch.Tensor:
|
||||
height, width = video.shape[-2], video.shape[-1]
|
||||
target_height, target_width = get_aspect_ratio_preserving_size(
|
||||
height=height,
|
||||
width=width,
|
||||
patch_size=patch_size,
|
||||
max_patches=max_patches,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
)
|
||||
|
||||
if target_height == height and target_width == width:
|
||||
return video
|
||||
|
||||
return F.resize(
|
||||
video,
|
||||
size=[target_height, target_width],
|
||||
interpolation=resample,
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
**kwargs: Unpack[Gemma4VideoProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
return super().preprocess(videos, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
videos: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
resample: "F.InterpolationMode | int | None",
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
return_tensors: str | TensorType | None,
|
||||
patch_size: int | None = None,
|
||||
max_soft_tokens: int | None = None,
|
||||
pooling_kernel_size: int | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS:
|
||||
raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.")
|
||||
|
||||
max_patches = max_soft_tokens * pooling_kernel_size**2
|
||||
|
||||
pixel_values = []
|
||||
position_ids = []
|
||||
num_soft_tokens_per_video = []
|
||||
num_frames = 1
|
||||
|
||||
for video in videos:
|
||||
if do_resize:
|
||||
video = self.aspect_ratio_preserving_resize(
|
||||
video=video,
|
||||
patch_size=patch_size,
|
||||
max_patches=max_patches,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
video = self.rescale_and_normalize(video, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
||||
|
||||
num_frames = video.shape[0]
|
||||
patch_height = video.shape[-2] // patch_size
|
||||
patch_width = video.shape[-1] // patch_size
|
||||
patches = convert_video_to_patches(video, patch_size)
|
||||
num_soft_tokens_per_video.append(patches.shape[1] // pooling_kernel_size**2)
|
||||
|
||||
device = video.device
|
||||
patch_grid = torch.meshgrid(
|
||||
torch.arange(patch_width, device=device),
|
||||
torch.arange(patch_height, device=device),
|
||||
indexing="xy",
|
||||
)
|
||||
stacked_grid = torch.stack(patch_grid, dim=-1)
|
||||
real_positions = stacked_grid.reshape(patches.shape[1], 2)
|
||||
real_positions = real_positions[None, ...].repeat(num_frames, 1, 1)
|
||||
|
||||
patches, positions = pad_to_max_patches(patches, real_positions, max_patches)
|
||||
pixel_values.append(patches)
|
||||
position_ids.append(positions)
|
||||
|
||||
# Stack into batch tensors
|
||||
pixel_values = torch.stack(pixel_values, dim=0) # (num_videos, num_frames, max_patches, patch_pixels)
|
||||
position_ids = torch.stack(position_ids, dim=0) # (num_videos, num_frames, max_patches, 2)
|
||||
|
||||
data = {
|
||||
"pixel_values_videos": pixel_values,
|
||||
"video_position_ids": position_ids,
|
||||
"num_soft_tokens_per_video": num_soft_tokens_per_video,
|
||||
}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["Gemma4VideoProcessor"]
|
||||
Reference in New Issue
Block a user