# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright 2025 The vLLM team. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. # # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Gemma 4 model implementation for vLLM.""" from collections.abc import Iterable from dataclasses import replace from itertools import islice import regex as re import torch from torch import nn from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) from vllm.sequence import IntermediateTensors from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from .interfaces import ( EagleModelMixin, MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP, ) from .utils import ( AutoWeightsLoader, WeightsMapper, extract_layer_index, is_pp_missing_parameter, make_layers, maybe_prefix, ) logger = init_logger(__name__) def _get_text_config(config): """Dereference text_config if config is a nested Gemma4Config. Gemma4 checkpoints use architectures=["Gemma4ForConditionalGeneration"] which yields a Gemma4Config with nested text_config. This function transparently returns the text config regardless of nesting. """ if hasattr(config, "text_config"): return config.text_config