#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/chmv2/modular_chmv2.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change is needed, please apply it to the
#                          modular_chmv2.py file directly. Our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import nn

from ... import initialization as init
from ...backbone_utils import load_backbone
from ...modeling_outputs import DepthEstimatorOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from .configuration_chmv2 import CHMv2Config


def _get_backbone_hidden_size(config):
    if config.backbone_config is not None and hasattr(config.backbone_config, "hidden_size"):
        return config.backbone_config.hidden_size
    else:
        return config.hidden_size


class CHMv2ReassembleLayer(nn.Module):
    def __init__(self, config: CHMv2Config, channels: int, factor: float):
        super().__init__()
        # projection
        hidden_size = _get_backbone_hidden_size(config)
        self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)

        # up/down sampling depending on factor (factor > 1 upsamples, factor < 1 downsamples)
        if factor > 1:
            self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
        elif factor == 1:
            self.resize = nn.Identity()
        elif factor < 1:
            # fractional factor: downsample with a strided convolution (stride = 1 / factor)
            self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)

    def forward(self, hidden_state):
        hidden_state = self.projection(hidden_state)
        hidden_state = self.resize(hidden_state)
        return hidden_state


class CHMv2ReassembleStage(nn.Module):
    """
    Reassemble stage that processes hidden states from the backbone into image-like feature
    representations at various resolutions.
    """

    def __init__(self, config: CHMv2Config):
        super().__init__()
        self.config = config
        self.readout_type = config.readout_type

        self.layers = nn.ModuleList()
        for out_channels, factor in zip(config.post_process_channels, config.reassemble_factors):
            self.layers.append(
                CHMv2ReassembleLayer(
                    config=config,
                    channels=out_channels,
                    factor=factor,
                )
            )

        hidden_size = _get_backbone_hidden_size(config)
        if self.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.layers)):
                self.readout_projects.append(nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), nn.GELU()))

    def forward(self, hidden_states: list[torch.Tensor], patch_height=None, patch_width=None) -> list[torch.Tensor]:
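        # Each backbone stage may return either a (feature_map, cls_token) pair,
        # handled via the configured readout, or a flat (batch, seq_len, hidden)
        # token sequence that is folded back into a 2D feature map below.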
        out = []

        for layer_idx, hidden_state in enumerate(hidden_states):
            if isinstance(hidden_state, (tuple, list)) and len(hidden_state) == 2:
                hidden_state, cls_token = hidden_state[0], hidden_state[1]
                feature_shape = hidden_state.shape

                if self.readout_type == "project":
                    hidden_state = hidden_state.flatten(2).transpose(1, 2)
                    readout = cls_token.unsqueeze(1).expand_as(hidden_state)
                    hidden_state = self.readout_projects[layer_idx](torch.cat((hidden_state, readout), -1))
                    hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
                elif self.readout_type == "add":
                    hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
                    hidden_state = hidden_state.reshape(feature_shape)
            else:
                if hidden_state.dim() == 3:
                    hidden_state = hidden_state[:, 1:]
                    batch_size, _, num_channels = hidden_state.shape
                    hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()

            hidden_state = self.layers[layer_idx](hidden_state)
            out.append(hidden_state)

        return out


class CHMv2PreActResidualLayer(nn.Module):
    """
    ResidualConvUnit, pre-activate residual unit.

    Args:
        config ([`CHMv2Config`]):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.activation1 = nn.ReLU()
        self.convolution1 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )

        self.activation2 = nn.ReLU()
        self.convolution2 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        residual = hidden_state
        hidden_state = self.activation1(hidden_state)
        hidden_state = self.convolution1(hidden_state)
        hidden_state = self.activation2(hidden_state)
        hidden_state = self.convolution2(hidden_state)

        return hidden_state + residual


class CHMv2FeatureFusionLayer(nn.Module):
    def __init__(self, config: CHMv2Config, is_first_layer: bool = False):
        super().__init__()
        self.is_first_layer = is_first_layer

        self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)

        if not is_first_layer:
            self.residual_layer1 = CHMv2PreActResidualLayer(config)

        self.residual_layer2 = CHMv2PreActResidualLayer(config)

    def forward(self, hidden_state, residual=None, size=None):
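        # Add the skip connection (resized if shapes differ), refine with a
        # residual unit, upsample (2x by default, or to `size`), then project.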
        if residual is not None and not self.is_first_layer:
            if hidden_state.shape != residual.shape:
                _, _, height, width = hidden_state.shape
                residual = nn.functional.interpolate(
                    residual, size=(height, width), mode="bilinear", align_corners=False
                )
            hidden_state = hidden_state + self.residual_layer1(residual)

        hidden_state = self.residual_layer2(hidden_state)

        modifier = {"scale_factor": 2} if size is None else {"size": size}

        hidden_state = nn.functional.interpolate(
            hidden_state,
            **modifier,
            mode="bilinear",
            align_corners=True,
        )

        hidden_state = self.projection(hidden_state)

        return hidden_state


class CHMv2UpsampleConvHead(nn.Module):
    """
    Convolutional head with intermediate upsampling.

    Architecture: Conv3x3 -> 2x bilinear upsample -> Conv3x3 -> ReLU -> Conv1x1.
    """

    def __init__(self, features, number_output_channels, n_hidden_channels=128):
        super().__init__()
        self.head = nn.ModuleList(
            [
                nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(features // 2, n_hidden_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(n_hidden_channels, number_output_channels, kernel_size=1, stride=1, padding=0),
            ]
        )

    def forward(self, hidden_states):
        for layer in self.head:
            hidden_states = layer(hidden_states)
        return hidden_states


class CHMv2Head(nn.Module):
    """
    CHMv2 dense-prediction head adapted from DPT.

    Integrates reassemble, projection convs, feature fusion, and UpConv depth head.
    """

    def __init__(self, config: CHMv2Config):
        super().__init__()
        self.config = config

        self.reassemble_stage = CHMv2ReassembleStage(config)

        self.convs = nn.ModuleList()
        for channel in config.post_process_channels:
            self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))

        self.fusion_layers = nn.ModuleList()
        for idx in range(len(config.post_process_channels)):
            self.fusion_layers.append(CHMv2FeatureFusionLayer(config, is_first_layer=(idx == 0)))

        self.conv_depth = CHMv2UpsampleConvHead(
            features=config.fusion_hidden_size,
            number_output_channels=config.number_output_channels,
            n_hidden_channels=config.head_hidden_size,
        )

    def forward_features(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
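        # Reassemble tokens into 2D maps, project each to the fusion width, then
        # fuse from the coarsest to the finest resolution (hence the reverse below).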
        hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)

        features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
        features.reverse()

        fused_hidden_state = self.fusion_layers[0](features[0])
        for i in range(1, len(self.fusion_layers)):
            fused_hidden_state = self.fusion_layers[i](fused_hidden_state, features[i])

        return fused_hidden_state

    def forward(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
        out = self.forward_features(hidden_states, patch_height, patch_width)
        out = self.conv_depth(out)
        return out


class CHMv2FeaturesToDepth(nn.Module):
    """Converts raw logits from the CHMv2 head into a depth map using depth bins."""

    def __init__(self, config: CHMv2Config):
        super().__init__()
        self.min_depth = config.min_depth
        self.max_depth = config.max_depth
        self.bins_strategy = config.bins_strategy
        self.norm_strategy = config.norm_strategy
        self._mixlog_max_clamp_value = 1e-4
        self._mixlog_eps_shift = 1e-8
        self._mixlog_eps = 1e-12

    def _create_mixlog_bins(self, n_bins: int, device: torch.device) -> torch.Tensor:
        """
        Creates mixed log bins interpolated between linear and log distributions.

        The max_depth is divided by 8.0 internally; this scaling is reversed in
        `_create_outputs_with_mixlog_norm` by multiplying by 8.0.
        """
        scaled_max_depth = self.max_depth / 8.0
        linear = torch.linspace(self.min_depth, scaled_max_depth, n_bins, device=device)
        log = torch.exp(
            torch.linspace(
                torch.log(torch.tensor(self.min_depth, device=device)),
                torch.log(torch.tensor(scaled_max_depth, device=device)),
                n_bins,
                device=device,
            )
        )
        interp_weight = torch.linspace(1.0, 0.0, n_bins, device=device)
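        # Blend the two schedules: the first bins follow the log spacing (finer
        # resolution near min_depth) and the last bins approach the linear spacing.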
        bins = interp_weight * log + (1.0 - interp_weight) * linear
        return bins

    def _create_outputs_with_mixlog_norm(self, input: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
        """Converts depth bin logits to depth values using mixlog normalization."""
        logits = torch.relu(input)
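        # Shift so every bin weight is strictly positive (the shift is capped to
        # stay tiny), then normalize the weights to sum to one over the bin axis.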

        min_per_sample = logits.amin(dim=1, keepdim=True)
        shift = (-min_per_sample).clamp_min(0.0).clamp_max(self._mixlog_max_clamp_value) + self._mixlog_eps_shift
        logits_pos = logits + shift

        denom = logits_pos.sum(dim=1, keepdim=True)
        denom = torch.nan_to_num(denom, nan=1.0, posinf=1.0, neginf=1.0).clamp_min(self._mixlog_eps)
        weights = logits_pos / denom

        bins_broadcast = bins.view(1, -1, 1, 1).clamp_min(self._mixlog_eps)
        output = (weights * bins_broadcast).sum(dim=1, keepdim=True).clamp_min(self._mixlog_eps)

        output = output * 8.0

        return output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_bins = x.shape[1]
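        # With multiple channels, treat them as weights over a set of depth bin
        # centers; with a single channel, use the (rectified) logit as depth directly.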

        if n_bins > 1:
            if self.bins_strategy == "linear":
                bins = torch.linspace(self.min_depth, self.max_depth, n_bins, device=x.device)
            elif self.bins_strategy == "log":
                bins = torch.linspace(
                    torch.log(torch.tensor(self.min_depth, device=x.device)),
                    torch.log(torch.tensor(self.max_depth, device=x.device)),
                    n_bins,
                    device=x.device,
                )
                bins = torch.exp(bins)
            else:
                bins = self._create_mixlog_bins(n_bins, x.device)

            if self.norm_strategy in ["linear", "softmax", "sigmoid"]:
                if self.norm_strategy == "linear":
                    logit = torch.relu(x)
                    eps = 0.1
                    logit = logit + eps
                    logit = logit / logit.sum(dim=1, keepdim=True)
                elif self.norm_strategy == "softmax":
                    logit = torch.softmax(x, dim=1)
                else:
                    logit = torch.sigmoid(x)
                    logit = logit / logit.sum(dim=1, keepdim=True)
                output = torch.einsum("ikmn,k->imn", logit, bins).unsqueeze(dim=1)
            else:
                output = self._create_outputs_with_mixlog_norm(x, bins)
        else:
            output = torch.relu(x) + self.min_depth

        return output


@auto_docstring
class CHMv2PreTrainedModel(PreTrainedModel):
    config: CHMv2Config
    base_model_prefix = "chmv2"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module) -> None:
        super()._init_weights(module)
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)


@auto_docstring(
    custom_intro="""
    CHMv2 Model with a depth estimation head on top (consisting of convolutional layers) e.g. for canopy height
    estimation.
    """
)
class CHMv2ForDepthEstimation(CHMv2PreTrainedModel):
    def __init__(self, config: CHMv2Config):
        super().__init__(config)

        self.backbone = load_backbone(config)
        self.head = CHMv2Head(config)
        self.features_to_depth = CHMv2FeaturesToDepth(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> DepthEstimatorOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.
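
        Example (illustrative; the checkpoint name below is a placeholder, not a
        published model):

        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoImageProcessor, CHMv2ForDepthEstimation

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # NOTE: hypothetical checkpoint id, shown only to sketch the API
        >>> image_processor = AutoImageProcessor.from_pretrained("meta/chmv2-base")
        >>> model = CHMv2ForDepthEstimation.from_pretrained("meta/chmv2-base")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> predicted_depth = outputs.predicted_depth  # shape (batch_size, height, width)
        ```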
        """
        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        _, _, height, width = pixel_values.shape
        patch_size = self.config.patch_size
        patch_height = height // patch_size
        patch_width = width // patch_size

        backbone_output = self.backbone(pixel_values, **kwargs)
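        # Pair each stage's feature map with its CLS token for the readout step.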
        intermediate_features = list(zip(backbone_output.feature_maps, backbone_output.cls_tokens))

        head_output = self.head(intermediate_features, patch_height, patch_width)

        predicted_depth = self.features_to_depth(head_output)
        predicted_depth = predicted_depth.squeeze(dim=1)

        return DepthEstimatorOutput(
            loss=loss,
            predicted_depth=predicted_depth,
            hidden_states=backbone_output.hidden_states,
            attentions=backbone_output.attentions,
        )


__all__ = ["CHMv2ForDepthEstimation", "CHMv2PreTrainedModel"]
