#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/slanet/modular_slanet.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_slanet.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from ... import initialization as init
from ...activations import ACT2CLS, ACT2FN
from ...backbone_utils import filter_output_hidden_states, load_backbone
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithNoAttention
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_slanet import SLANetConfig


class SLANetPreTrainedModel(PreTrainedModel):
    # Shared base class for all SLANet models: config typing, weight
    # initialization, and common PreTrainedModel settings.
    config: SLANetConfig
    base_model_prefix = "backbone"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True
    _keep_in_fp32_modules_strict = []

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights of `module` (invoked per-submodule via `post_init`)."""
        super()._init_weights(module)

        # Initialize GRUCell (replicates PyTorch's default `reset_parameters`:
        # uniform in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)])
        if isinstance(module, nn.GRUCell):
            std = 1.0 / math.sqrt(module.hidden_size) if module.hidden_size > 0 else 0
            init.uniform_(module.weight_ih, -std, std)
            init.uniform_(module.weight_hh, -std, std)
            if module.bias_ih is not None:
                init.uniform_(module.bias_ih, -std, std)
            if module.bias_hh is not None:
                init.uniform_(module.bias_hh, -std, std)

        # Initialize SLAHead layers
        if isinstance(module, SLANetSLAHead):
            std = 1.0 / math.sqrt(self.config.hidden_size * 1.0)
            # Uniformly initialize the linear layers of the structure generator MLP.
            # (The head only has `structure_generator`; the tuple form is kept so
            # additional generators can be appended here.)
            for generator in (module.structure_generator,):
                for layer in generator.children():
                    if isinstance(layer, nn.Linear):
                        init.uniform_(layer.weight, -std, std)
                        if layer.bias is not None:
                            init.uniform_(layer.bias, -std, std)


@dataclass
@auto_docstring
class SLANetForTableRecognitionOutput(BaseModelOutputWithNoAttention):
    r"""
    head_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Hidden-states of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` states (depending on early exits).
    head_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Attentions of the SLANetSLAHead at each prediction step, varies up to max `self.config.max_text_length` attentions (depending on early exits).
    """

    # `last_hidden_state` and `hidden_states` (backbone outputs) are inherited
    # from BaseModelOutputWithNoAttention; the fields below carry head outputs.
    head_hidden_states: torch.FloatTensor | None = None
    head_attentions: torch.FloatTensor | None = None


class SLANetAttentionGRUCell(nn.Module):
    """Single decoding step: additive attention over encoder features feeding a GRU cell."""

    def __init__(self, input_size, hidden_size, num_embeddings):
        super().__init__()

        # Projections used to compute the additive (Bahdanau-style) attention scores.
        self.input_to_hidden = nn.Linear(input_size, hidden_size, bias=False)
        self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

        # Recurrent cell consuming [attended context ; one-hot token] as input.
        self.rnn = nn.GRUCell(input_size + num_embeddings, hidden_size)

    def forward(
        self,
        prev_hidden: torch.FloatTensor,
        batch_hidden: torch.FloatTensor,
        char_onehots: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        # Additive attention: score(tanh(W_i @ features + W_h @ hidden)).
        scores = torch.tanh(self.input_to_hidden(batch_hidden) + self.hidden_to_hidden(prev_hidden).unsqueeze(1))
        scores = self.score(scores)

        # Softmax in float32 for numerical stability, then cast back to the input dtype.
        weights = F.softmax(scores, dim=1, dtype=torch.float32).to(scores.dtype)
        weights = weights.transpose(1, 2)

        # Context vector: attention-weighted sum of the encoder features.
        context = torch.matmul(weights, batch_hidden).squeeze(1)
        next_hidden = self.rnn(torch.cat([context, char_onehots], 1), prev_hidden)

        return next_hidden, weights


class SLANetMLP(nn.Module):
    """Two stacked linear layers with an optional activation applied to the final output."""

    def __init__(self, hidden_size, out_channels, activation=None):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out_channels)
        # When no activation is requested, use a pass-through.
        self.act_fn = ACT2CLS[activation]() if activation is not None else nn.Identity()

    def forward(self, hidden_states):
        # NOTE: the activation is applied *after* the second projection.
        return self.act_fn(self.fc2(self.fc1(hidden_states)))


class SLANetSLAHead(SLANetPreTrainedModel):
    """Autoregressive structure-decoding head.

    At every step, attends over the encoder feature sequence with a GRU attention
    cell and predicts the next table-structure token, stopping early once every
    sample in the batch has emitted the end-of-structure token.
    """

    _can_record_outputs = {
        "attentions": SLANetAttentionGRUCell,
    }

    def __init__(
        self,
        config: dict | None = None,
        **kwargs,
    ):
        super().__init__(config)

        self.structure_attention_cell = SLANetAttentionGRUCell(
            config.post_conv_out_channels, config.hidden_size, config.out_channels
        )
        self.structure_generator = SLANetMLP(config.hidden_size, config.out_channels)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @filter_output_hidden_states
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        targets: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        batch_size = hidden_states.shape[0]
        # GRU hidden state (zeros) and the initial token id (0) for step 0.
        features = torch.zeros((batch_size, self.config.hidden_size), dtype=torch.float32, device=hidden_states.device)
        predicted_chars = torch.zeros(size=[batch_size], dtype=torch.long, device=hidden_states.device)

        # Per-sample flag: has the end-of-structure token (id `out_channels - 1`)
        # been produced at any step so far? Maintaining this running mask replaces
        # the original re-stacking and re-scanning of the whole id history at every
        # step (which was O(T^2) over the decode length) with an O(T) update.
        eos_token_id = self.config.out_channels - 1
        finished = torch.zeros(batch_size, dtype=torch.bool, device=hidden_states.device)

        structure_preds_list = []
        for _ in range(self.config.max_text_length + 1):
            embedding_feature = F.one_hot(predicted_chars, self.config.out_channels).float()
            features, _ = self.structure_attention_cell(features, hidden_states.float(), embedding_feature)
            structure_step = self.structure_generator(features)
            predicted_chars = structure_step.argmax(dim=1)

            structure_preds_list.append(structure_step)
            finished |= predicted_chars.eq(eos_token_id)
            # Early exit once every sample in the batch has emitted EOS.
            if finished.all():
                break
        structure_preds = F.softmax(torch.stack(structure_preds_list, dim=1), dim=-1, dtype=torch.float32).to(
            hidden_states.dtype
        )

        return BaseModelOutput(last_hidden_state=structure_preds, hidden_states=structure_preds_list)


class SLANetConvLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> activation, with "same" padding derived from the kernel size."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        activation: str = "hardswish",
        groups: int = 1,
    ):
        super().__init__()
        # Convolution bias is omitted: the batch norm that follows has its own affine shift.
        self.convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=False,
            groups=groups,
        )
        self.normalization = nn.BatchNorm2d(out_channels)
        self.activation = nn.Identity() if activation is None else ACT2FN[activation]

    def forward(self, input: Tensor) -> Tensor:
        return self.activation(self.normalization(self.convolution(input)))


class SLANetDepthwiseSeparableConvLayer(GradientCheckpointingLayer):
    """
    Depthwise Separable Convolution Layer: Depthwise Conv -> Pointwise Conv
    Core component of lightweight models (e.g., MobileNet, PP-LCNet) that significantly reduces
    the number of parameters and computational cost.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        kernel_size,
        config,
    ):
        super().__init__()
        # Depthwise stage: one filter per input channel (`groups=in_channels`).
        self.depthwise_convolution = SLANetConvLayer(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            activation=config.hidden_act,
        )
        # Pass-through placeholder kept for structural parity with SE-enabled blocks.
        self.squeeze_excitation_module = nn.Identity()
        # Pointwise stage: 1x1 convolution mixing channels to `out_channels`.
        self.pointwise_convolution = SLANetConvLayer(
            in_channels=in_channels,
            kernel_size=1,
            out_channels=out_channels,
            stride=1,
            activation=config.hidden_act,
        )

    def forward(self, hidden_state):
        depthwise_out = self.depthwise_convolution(hidden_state)
        recalibrated = self.squeeze_excitation_module(depthwise_out)
        return self.pointwise_convolution(recalibrated)


class SLANetBottleneck(nn.Module):
    """Bottleneck block: 1x1 projection followed by a depthwise-separable convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        activation,
        config,
    ):
        super().__init__()
        # 1x1 channel projection.
        self.conv1 = SLANetConvLayer(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, activation=activation
        )
        # Spatial mixing at stride 1, channel count unchanged.
        self.conv2 = SLANetDepthwiseSeparableConvLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            config=config,
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        return self.conv2(self.conv1(hidden_states))


class SLANetCSPLayer(nn.Module):
    """
    Cross Stage Partial (CSP) network layer. Similar in structure to DFineCSPRepLayer, but with a different forward computation.
    """

    def __init__(
        self,
        config,
        in_channels,
        out_channels,
        kernel_size=3,
        expansion=0.5,
        num_blocks=1,
        activation="hardswish",
    ):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        # Two parallel 1x1 projections: conv1 feeds the shortcut branch, conv2 the main branch.
        self.conv1 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation)
        self.conv2 = SLANetConvLayer(in_channels, hidden_channels, 1, activation=activation)
        # Fuses the concatenation of both branches back to `out_channels`.
        self.conv3 = SLANetConvLayer(2 * hidden_channels, out_channels, 1, activation=activation)
        self.bottlenecks = nn.ModuleList(
            [
                SLANetBottleneck(hidden_channels, hidden_channels, kernel_size, activation, config)
                for _ in range(num_blocks)
            ]
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        # Shortcut branch: projection only, no bottlenecks.
        shortcut = self.conv1(hidden_states)

        # Main branch: projection followed by the bottleneck stack.
        main = self.conv2(hidden_states)
        for block in self.bottlenecks:
            main = block(main)

        # Cross-stage fusion: concatenate (main, shortcut) then mix with a 1x1 conv.
        return self.conv3(torch.cat((main, shortcut), dim=1))


class SLANetCSPPAN(nn.Module):
    """
    CSP-PAN: Path Aggregation Network with CSP layers.

    Aligns all input feature maps to a shared channel count, fuses them along a
    top-down path (upsample + CSP fusion) and a bottom-up path (downsample + CSP
    fusion), then flattens the final map into a `(batch, seq_len, channels)` sequence.
    """

    def __init__(
        self,
        config,
        in_channel_list,
    ):
        super().__init__()
        out_channels = config.post_conv_out_channels
        activation = config.hidden_act
        kernel_size = config.csp_kernel_size
        csp_num_blocks = config.csp_num_blocks

        # 1x1 convolutions projecting every pyramid level to `out_channels`.
        self.channel_projector = nn.ModuleList(
            [
                SLANetConvLayer(
                    in_channels=in_channel_list[i], out_channels=out_channels, kernel_size=1, activation=activation
                )
                for i in range(len(in_channel_list))
            ]
        )

        # build top-down blocks
        # NOTE: upsampling in `forward` is done via `F.interpolate` sized to the target
        # level; the previously registered `nn.Upsample` module was never used and has
        # been removed (it holds no parameters or buffers, so checkpoints are unaffected).
        self.top_down_blocks = nn.ModuleList(
            [
                SLANetCSPLayer(
                    config,
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=csp_num_blocks,
                    activation=activation,
                )
                for _ in range(len(in_channel_list) - 1, 0, -1)
            ]
        )

        # build bottom-up blocks
        self.downsamples = nn.ModuleList(
            [
                SLANetDepthwiseSeparableConvLayer(
                    out_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    config=config,
                )
                for _ in range(len(in_channel_list) - 1)
            ]
        )
        self.bottom_up_blocks = nn.ModuleList(
            [
                SLANetCSPLayer(
                    config,
                    out_channels * 2,
                    out_channels,
                    kernel_size=kernel_size,
                    num_blocks=csp_num_blocks,
                    activation=activation,
                )
                for _ in range(len(in_channel_list) - 1)
            ]
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        # Align every pyramid level to the shared channel count.
        projected_features = []
        for idx in range(len(self.channel_projector)):
            projected_features.append(self.channel_projector[idx](hidden_states[idx]))

        # Top-down pass: upsample the running high-level feature to each lower level's
        # spatial size, concatenate, and fuse with a CSP block.
        top_down_features = [projected_features[-1]]
        for top_down_block, low_level_feature in zip(self.top_down_blocks, reversed(projected_features[:-1])):
            high_level_feature = top_down_features[-1]
            upsampled_feature = F.interpolate(
                high_level_feature,
                size=low_level_feature.shape[-2:],
                mode="nearest",
            )
            fused_feature = top_down_block(torch.cat([upsampled_feature, low_level_feature], dim=1))
            top_down_features.append(fused_feature)

        # Bottom-up pass: downsample the running output and fuse with the next level up.
        pyramid_features = list(reversed(top_down_features))
        output_feature = pyramid_features[0]
        for downsample_layer, bottom_up_block, high_level_feature in zip(
            self.downsamples, self.bottom_up_blocks, pyramid_features[1:]
        ):
            downsampled_feature = downsample_layer(output_feature)
            output_feature = bottom_up_block(torch.cat([downsampled_feature, high_level_feature], dim=1))

        # Flatten spatial dims to a sequence: (batch, C, H, W) -> (batch, H*W, C).
        hidden_states = output_feature.flatten(2).transpose(1, 2)
        return hidden_states


class SLANetBackbone(SLANetPreTrainedModel):
    """Vision backbone followed by the CSP-PAN neck that produces the encoder sequence."""

    def __init__(self, config: SLANetConfig):
        super().__init__(config)
        self.vision_backbone = load_backbone(config)
        # The neck consumes only the deeper backbone feature maps (skips the first two levels).
        self.post_csp_pan = SLANetCSPPAN(config, self.vision_backbone.num_features[2:])

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self, hidden_states: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple[torch.FloatTensor] | BaseModelOutputWithNoAttention:
        backbone_outputs = self.vision_backbone(hidden_states, **kwargs)
        fused_sequence = self.post_csp_pan(backbone_outputs.feature_maps)
        return BaseModelOutputWithNoAttention(
            last_hidden_state=fused_sequence,
            hidden_states=backbone_outputs.hidden_states,
        )


@auto_docstring(
    custom_intro="""
    SLANet Table Recognition model for table recognition tasks. Wraps the core SLANetPreTrainedModel
    and returns outputs compatible with the Transformers table recognition API.
    """
)
class SLANetForTableRecognition(SLANetPreTrainedModel):
    # BatchNorm bookkeeping buffers may legitimately be absent from checkpoints.
    _keys_to_ignore_on_load_missing = ["num_batches_tracked"]

    def __init__(self, config: SLANetConfig):
        super().__init__(config)
        self.backbone = SLANetBackbone(config=config)
        self.head = SLANetSLAHead(config=config)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple[torch.FloatTensor] | SLANetForTableRecognitionOutput:
        # Encode the image into a feature sequence, then decode structure tokens.
        backbone_outputs = self.backbone(pixel_values, **kwargs)
        head_outputs = self.head(backbone_outputs.last_hidden_state, **kwargs)
        # Key difference: no attentions in its vision model
        return SLANetForTableRecognitionOutput(
            last_hidden_state=head_outputs.last_hidden_state,
            hidden_states=backbone_outputs.hidden_states,
            head_hidden_states=head_outputs.hidden_states,
            head_attentions=head_outputs.attentions,
        )


__all__ = ["SLANetForTableRecognition", "SLANetPreTrainedModel", "SLANetSLAHead", "SLANetBackbone"]
