# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""ColModernVBERT: multimodal late-interaction retrieval model.

Combines SigLIP vision encoder + ModernBERT text encoder with a pixel
shuffle connector and ColBERT-style 128-dim per-token embeddings.

Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
"""

from collections.abc import Iterable, Mapping, Sequence

import torch
from torch import nn
from transformers import BatchFeature

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptReplacement,
    PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLateInteraction,
    SupportsMultiModal,
)
from .interfaces_base import default_pooling_type
from .modernbert import ModernBertEmbeddings, ModernBertLayer
from .siglip import SiglipVisionModel
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

# ---------------------------------------------------------------------------
# Connector: pixel shuffle + simple linear projection
# ---------------------------------------------------------------------------


class ColModernVBertConnector(nn.Module):
    """Bridge from vision-encoder tokens to the text encoder's hidden size.

    A pixel shuffle first folds every ``factor x factor`` neighbourhood of
    the token grid into the channel dimension, shrinking the token count by
    ``factor^2``.  The widened tokens are then mapped to ``config.hidden_size``
    by one linear layer that has no bias.
    """

    def __init__(self, config: ColModernVBertConfig):
        super().__init__()
        shuffle_factor = config.pixel_shuffle_factor
        self.pixel_shuffle_factor = shuffle_factor
        # After the shuffle each token carries factor^2 vision tokens' channels.
        self.proj = nn.Linear(
            config.vision_config.hidden_size * (shuffle_factor**2),
            config.hidden_size,
            bias=False,
        )

    def pixel_shuffle(self, features: torch.Tensor) -> torch.Tensor:
        """Fold ``factor x factor`` token neighbourhoods into the channels.

        Args:
            features: Tensor of shape ``(B, S, C)``; the ``S`` tokens are
                interpreted as a square ``sqrt(S) x sqrt(S)`` grid.

        Returns:
            Tensor of shape ``(B, H/f, W/f, C * f^2)`` where ``H = W`` is the
            grid side and ``f`` is ``self.pixel_shuffle_factor``.
        """
        num_items, num_tokens, channels = features.shape
        side = int(num_tokens**0.5)
        f = self.pixel_shuffle_factor
        cells = side // f

        # (B, S, C) -> (B, H, W, C) -> (B, H/f, f, W/f, f, C)
        blocks = features.view(num_items, side, side, channels).view(
            num_items, cells, f, cells, f, channels
        )

        # Bring both intra-cell axes next to the channels:
        # (B, H/f, W/f, f, f, C)
        blocks = blocks.permute(0, 1, 3, 2, 4, 5)

        # Merge the intra-cell axes into the channels: (B, H/f, W/f, C * f^2)
        return blocks.reshape(num_items, cells, cells, channels * (f**2))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Pixel-shuffle ``features``, flatten the grid, and project."""
        shuffled = self.pixel_shuffle(features)
        # Collapse the reduced spatial grid back into a token sequence.
        tokens = shuffled.reshape(shuffled.shape[0], -1, shuffled.shape[-1])
        return self.proj(tokens)


# ---------------------------------------------------------------------------
# Multimodal processing
# ---------------------------------------------------------------------------


class ColModernVBertProcessingInfo(BaseProcessingInfo):
    """Processing metadata for ColModernVBERT, read from its HF config."""

    def get_hf_config(self) -> ColModernVBertConfig:
        """Fetch the HF config from the context as ``ColModernVBertConfig``."""
        ctx = self.ctx
        return ctx.get_hf_config(ColModernVBertConfig)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Report ``"image"`` as the only modality, with a ``None`` limit."""
        limits: dict[str, int | None] = {"image": None}
        return limits

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square size whose side is the vision ``image_size``."""
        side = self.get_hf_config().vision_config.image_size
        return ImageSize(width=side, height=side)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return ``image_seq_len`` from the config.

        The image dimensions are accepted but not consulted.
        """
        hf_config = self.get_hf_config()
        return hf_config.image_seq_len


class ColModernVBertDummyInputsBuilder(
    BaseDummyInputsBuilder[ColModernVBertProcessingInfo],
):
    """Builds placeholder inputs: an empty prompt plus dummy images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty prompt regardless of ``mm_counts``."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Create dummy images at the size reported by the processing info.

        Args:
            seq_len: Not used by this builder.
            mm_counts: Per-modality item counts; a missing ``"image"`` entry
                is treated as zero images.
            mm_options: Per-modality options; the ``"image"`` entry (if any)
                is forwarded as ``overrides``.

        Returns:
            A dict with a single ``"image"`` key holding the dummy images.
        """
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


class ColModernVBertMultiModalProcessor(
    BaseMultiModalProcessor[ColModernVBertProcessingInfo],
):
    """Tokenizes text and preprocesses images for ColModernVBERT.

    Text goes through the model tokenizer; images go through an
    ``Idefics3ImageProcessor`` with image splitting disabled.  The image
    processor is loaded once per processor instance and reused afterwards.
    """

    def _get_image_processor(self):
        """Return the Idefics3 image processor, loading it on first use.

        ``from_pretrained`` is comparatively expensive, so the instance is
        cached on ``self`` instead of being re-created for every call.
        """
        image_processor = getattr(self, "_colmodernvbert_image_processor", None)
        if image_processor is None:
            from transformers import Idefics3ImageProcessor

            model_config = self.info.ctx.model_config
            image_processor = Idefics3ImageProcessor.from_pretrained(
                model_config.model,
                revision=model_config.revision,
            )
            self._colmodernvbert_image_processor = image_processor
        return image_processor

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize ``prompt`` and, when images are present, preprocess them.

        Args:
            prompt: Text to tokenize.
            mm_data: Multimodal inputs; only the ``"images"`` entry is read.
            mm_kwargs: Accepted for interface compatibility; not used here.
            tok_kwargs: Extra keyword arguments forwarded to the tokenizer.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs, updated with the
            image processor outputs when ``mm_data`` contains images.
        """
        tokenizer = self.info.get_tokenizer()
        text_encoding = tokenizer(
            prompt,
            return_tensors="pt",
            **tok_kwargs,
        )
        result = BatchFeature(data=dict(text_encoding))

        images = mm_data.get("images")
        if images:
            image_processor = self._get_image_processor()
            image_outputs = image_processor(
                images=images,
                do_image_splitting=False,
                return_tensors="pt",
            )
            result.update(image_outputs)

        return result

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always return ``False``.

        ``_call_hf_processor`` only tokenizes the prompt as given and never
        inserts image tokens itself.
        """
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as a batched field of the image modality."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Describe the image-token update applied to the prompt.

        Returns a single ``PromptReplacement`` for the image modality that
        targets ``PromptIndexTargets.start()`` and whose replacement is
        ``image_seq_len`` copies of ``image_token_id`` (both read from the HF
        config).
        """
        config = self.info.get_hf_config()
        image_token_id = config.image_token_id
        num_tokens = config.image_seq_len

        def get_replacement(item_idx: int):
            # Every image maps to the same fixed-length run of image tokens,
            # so the item index does not affect the result.
            return [image_token_id] * num_tokens

        return [
            PromptReplacement(
                modality="image",
                target=PromptIndexTargets.start(),
                replacement=get_replacement,
            ),
        ]


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------


@MULTIMODAL_REGISTRY.register_processor(
    ColModernVBertMultiModalProcessor,
    info=ColModernVBertProcessingInfo,
    dummy_inputs=ColModernVBertDummyInputsBuilder,
)
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColModernVBertForRetrieval(
    nn.Module, SupportsMultiModal, SupportsLateInteraction
):
    """ColModernVBERT multimodal late-interaction retrieval model.

    Architecture:
        Image -> SiglipVisionModel -> ColModernVBertConnector
                                                   ↓
        Text  -> ModernBertEmbeddings → [merge] → ModernBertLayers → norm
                                                                      ↓
                                              custom_text_proj → L2 norm
                                                   ↓
                                          per-token 128-d embeddings

    ``forward`` returns the normalized text-encoder hidden states; the
    ``custom_text_proj`` projection is handed to the pooler as its projector.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, connector, text encoder, and pooler.

        Args:
            vllm_config: Engine configuration; the HF config, quantization
                config, head dtype, and pooler config are read from it.
            prefix: Dotted name prefix for submodules; may be empty.
        """
        super().__init__()
        config: ColModernVBertConfig = vllm_config.model_config.hf_config
        self.config = config
        text_config = config.text_config
        quant_config = vllm_config.quant_config

        # --- Vision encoder (reuses SiglipVisionModel from siglip.py) ---
        self.vision_model = SiglipVisionModel(
            config.vision_config,
            quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        # --- Connector (pixel shuffle + linear projection) ---
        self.connector = ColModernVBertConnector(config)

        # --- Text encoder (built from ModernBERT components directly) ---
        # We build the components individually rather than wrapping
        # ``ModernBertModel`` because ``ModernBertEncoderLayer`` reads
        # ``vllm_config.model_config.hf_config`` which would be
        # ``ColModernVBertConfig``, not ``ModernBertConfig``.
        self.text_embeddings = ModernBertEmbeddings(text_config)
        self.text_layers = nn.ModuleList(
            [
                ModernBertLayer(
                    config=text_config,
                    layer_id=i,
                    # ``maybe_prefix`` (as used for the vision tower) avoids a
                    # leading "." in the layer name when ``prefix`` is empty.
                    prefix=maybe_prefix(prefix, f"text_layers.{i}"),
                )
                for i in range(text_config.num_hidden_layers)
            ]
        )
        self.text_final_norm = nn.LayerNorm(
            text_config.hidden_size,
            eps=text_config.norm_eps,
            bias=text_config.norm_bias,
        )

        # --- ColBERT projection (hidden_size -> embedding_dim, with bias) ---
        self.custom_text_proj = nn.Linear(
            text_config.hidden_size,
            config.embedding_dim,
            bias=True,
            dtype=vllm_config.model_config.head_dtype,
        )

        # --- Pooler (receives the projection as its projector) ---
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = pooler_for_token_embed(
            pooler_config,
            projector=self.custom_text_proj,
        )

    # ---- multimodal ---------------------------------------------------------

    def _get_image_features(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode pixel values with the vision tower and the connector.

        Args:
            pixel_values: 4-D ``(N, C, H, W)`` input, or 5-D
                ``(batch, tiles, C, H, W)`` input which is flattened to
                ``(batch * tiles, C, H, W)`` first.

        Returns:
            The connector output for the vision-encoder features.
        """
        # Idefics3ImageProcessor may return (batch, tiles, C, H, W);
        # flatten to (batch*tiles, C, H, W) for SiglipVisionModel.
        if pixel_values.dim() == 5:
            b, t, c, h, w = pixel_values.shape
            pixel_values = pixel_values.reshape(b * t, c, h, w)
        vision_outputs = self.vision_model(
            pixel_values.to(dtype=self.vision_model.dtype),
        )
        return self.connector(vision_outputs)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return one embedding tensor per entry of ``pixel_values``.

        An empty list is returned when no ``pixel_values`` keyword is given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return []
        assert isinstance(pixel_values, torch.Tensor)
        image_features = self._get_image_features(pixel_values)
        # Split along the leading dimension into per-item tensors.
        return list(image_features)

    # ---- forward ------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the text encoder and return the final-normed hidden states.

        Args:
            input_ids: Token ids passed to the embedding module.
            positions: Positions passed to every text layer.
            intermediate_tensors: Accepted for interface compatibility; not
                used here.
            inputs_embeds: Optional embeddings forwarded to the embedding
                module alongside ``input_ids``.
        """
        hidden_states = self.text_embeddings(input_ids, inputs_embeds=inputs_embeds)

        for layer in self.text_layers:
            hidden_states = layer(hidden_states, positions)

        return self.text_final_norm(hidden_states)

    # ---- weight loading -----------------------------------------------------

    # Checkpoint prefix → vLLM param prefix.
    # More-specific prefixes must appear before shorter ones.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.text_model.layers.": "text_layers.",
            "model.text_model.embeddings.": "text_embeddings.",
            "model.text_model.final_norm.": "text_final_norm.",
            "model.connector.modality_projection.": "connector.",
            "model.custom_text_proj.": "custom_text_proj.",
            "model.vision_model.": "vision_model.vision_model.",
            "model.": "",
        },
    )

    # Checkpoint names for DecoupledEmbedding parts
    _BASE_EMB = "model.text_model.embeddings.tok_embeddings.weight"
    _EXTRA_EMB = (
        "model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
    )

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights into this module.

        The base and additional token-embedding tensors are pulled out of the
        stream, concatenated along dim 0, and loaded into
        ``text_embeddings.tok_embeddings.weight``.  All other tensors go
        through ``AutoWeightsLoader`` with ``hf_to_vllm_mapper``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names recorded as loaded.

        Raises:
            ValueError: If the additional embedding tensor is present but the
                base embedding tensor is not.
        """
        # DecoupledEmbedding requires concatenating base + additional
        # embedding tensors before loading, so we extract them first.
        base_embedding_weight: torch.Tensor | None = None
        additional_embedding_weight: torch.Tensor | None = None
        remaining: list[tuple[str, torch.Tensor]] = []

        for name, tensor in weights:
            if name == self._BASE_EMB:
                base_embedding_weight = tensor
            elif name == self._EXTRA_EMB:
                additional_embedding_weight = tensor
            else:
                remaining.append((name, tensor))

        # Load all non-embedding weights via AutoWeightsLoader
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(
            remaining,
            mapper=self.hf_to_vllm_mapper,
        )

        # Concatenate and load DecoupledEmbedding weights
        if base_embedding_weight is not None:
            combined = base_embedding_weight
            if additional_embedding_weight is not None:
                combined = torch.cat(
                    [base_embedding_weight, additional_embedding_weight],
                    dim=0,
                )
            param_name = "text_embeddings.tok_embeddings.weight"
            params_dict = dict(self.named_parameters())
            if param_name in params_dict:
                param = params_dict[param_name]
                # Prefer a parameter-specific loader when one is attached.
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, combined)
                loaded_params.add(param_name)
        elif additional_embedding_weight is not None:
            raise ValueError(
                "Found 'text_model.embeddings.tok_embeddings"
                ".additional_embedding.weight' but not "
                "'text_model.embeddings.tok_embeddings.weight'"
            )

        # The pooler wraps ``custom_text_proj`` as its head projector.
        # Mark those params as loaded under the pooler path too.
        if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
            projector = getattr(self.pooler.head, "projector", None)
            # ``isinstance`` already rejects ``None``.
            if isinstance(projector, nn.Module):
                loaded_params.update(
                    f"pooler.head.projector.{pname}"
                    for pname, _ in projector.named_parameters()
                )

        return loaded_params
