from __future__ import annotations

import logging
from collections.abc import Callable
from typing import Any

import torch
from torch import nn
from transformers import (
    EvalPrediction,
    FeatureExtractionMixin,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    TrainerCallback,
)
from transformers.image_processing_utils import BaseImageProcessor

from sentence_transformers.base.evaluation import BaseEvaluator
from sentence_transformers.base.trainer import BaseTrainer
from sentence_transformers.sparse_encoder.callbacks.splade_callbacks import SpladeRegularizerWeightSchedulerCallback
from sentence_transformers.sparse_encoder.data_collator import SparseEncoderDataCollator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
from sentence_transformers.sparse_encoder.model import SparseEncoder
from sentence_transformers.sparse_encoder.model_card import SparseEncoderModelCardCallback, SparseEncoderModelCardData
from sentence_transformers.sparse_encoder.training_args import SparseEncoderTrainingArguments
from sentence_transformers.util import is_datasets_available
from sentence_transformers.util.decorators import deprecated_kwargs

# ``datasets`` is an optional dependency, so its types are imported only when the
# package is installed. Because of ``from __future__ import annotations`` they are
# referenced in this module only inside (lazily evaluated) type annotations.
if is_datasets_available():
    from datasets import Dataset, DatasetDict, IterableDataset

# Module-level logger: emits the default-loss notice and the missing SPLADE scheduler warning.
logger = logging.getLogger(__name__)


class SparseEncoderTrainer(BaseTrainer):
    """
    SparseEncoderTrainer is a simple but feature-complete training and eval loop for PyTorch. It is based on the
    SentenceTransformerTrainer, which in turn is based on the 🤗 Transformers :class:`~transformers.Trainer`.

    This trainer integrates support for various :class:`transformers.TrainerCallback` subclasses, such as:

    - :class:`~transformers.integrations.WandbCallback` to automatically log training metrics to W&B if `wandb` is installed
    - :class:`~transformers.integrations.TensorBoardCallback` to log training metrics to TensorBoard if `tensorboard` is accessible.
    - :class:`~transformers.integrations.CodeCarbonCallback` to track the carbon emissions of your model during training if `codecarbon` is installed.

        - Note: These carbon emissions will be included in your automatically generated model card.

    See the Transformers `Callbacks <https://huggingface.co/docs/transformers/main/en/main_classes/callback>`_
    documentation for more information on the integrated callbacks and how to write your own callbacks.

    When the prepared loss is a :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss`, this trainer makes
    sure a :class:`~sentence_transformers.sparse_encoder.callbacks.splade_callbacks.SpladeRegularizerWeightSchedulerCallback`
    is registered directly after the first callback in the callback list, adding a default one if none was passed.

    Args:
        model (:class:`~sentence_transformers.sparse_encoder.model.SparseEncoder`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
        args (:class:`~sentence_transformers.sparse_encoder.training_args.SparseEncoderTrainingArguments`, *optional*):
            The arguments to tweak for training. Will default to a basic instance of
            :class:`~sentence_transformers.sparse_encoder.training_args.SparseEncoderTrainingArguments` with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
        train_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
            The dataset to use for training. Must have a format accepted by your loss function, see
            `Training Overview > Dataset Format <../../../docs/sparse_encoder/training_overview.html#dataset-format>`_.
        eval_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
            The dataset to use for evaluation. Must have a format accepted by your loss function, see
            `Training Overview > Dataset Format <../../../docs/sparse_encoder/training_overview.html#dataset-format>`_.
        loss (Optional[Union[:class:`torch.nn.Module`, Dict[str, :class:`torch.nn.Module`],\
            Callable[[:class:`~sentence_transformers.sparse_encoder.model.SparseEncoder`], :class:`torch.nn.Module`],\
            Dict[str, Callable[[:class:`~sentence_transformers.sparse_encoder.model.SparseEncoder`], :class:`torch.nn.Module`]]]], *optional*):
            The loss function to use for training. Can either be a loss class instance, a dictionary mapping
            dataset names to loss class instances, a function that returns a loss class instance given a model,
            or a dictionary mapping dataset names to functions that return a loss class instance given a model.
            In practice, the latter two are primarily used for hyper-parameter optimization. Will default to
            :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss` with :class:`~sentence_transformers.sparse_encoder.losses.SparseMultipleNegativesRankingLoss` if no ``loss`` is provided.
        evaluator (Union[:class:`~sentence_transformers.base.evaluation.BaseEvaluator`,\
            List[:class:`~sentence_transformers.base.evaluation.BaseEvaluator`]], *optional*):
            The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
            or without an ``eval_dataset``, and vice versa. Generally, the metrics that an ``evaluator`` returns
            are more useful than the loss value returned from the ``eval_dataset``. A list of evaluators will be
            wrapped in a :class:`~sentence_transformers.base.evaluation.SequentialEvaluator` to run them sequentially.
        data_collator (:class:`~sentence_transformers.sparse_encoder.data_collator.SparseEncoderDataCollator`, *optional*):
            The data collator used to turn dataset rows into model-ready batches.
        processing_class (Union[:class:`~transformers.PreTrainedTokenizerBase`, :class:`~transformers.image_processing_utils.BaseImageProcessor`,\
            :class:`~transformers.FeatureExtractionMixin`, :class:`~transformers.ProcessorMixin`], *optional*):
            The processing class (e.g. a tokenizer) handed to the base trainer. The deprecated ``tokenizer``
            keyword is accepted as an alias.
        model_init (Callable[[], :class:`~sentence_transformers.sparse_encoder.model.SparseEncoder`], *optional*):
            A function that returns a new model. Required when ``model`` is not provided.
        compute_metrics (Callable[[:class:`~transformers.EvalPrediction`], Dict], *optional*):
            A function that computes a dictionary of metrics from an evaluation prediction, as in
            :class:`transformers.Trainer`.
        callbacks (List[:class:`transformers.TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            described in the Transformers Callbacks documentation linked above.

            If you want to remove one of the default callbacks used, use the
            :meth:`~transformers.Trainer.remove_callback` method.
        optimizers (Tuple[:class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler.LambdaLR`], *optional*, defaults to ``(None, None)``):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of :class:`torch.optim.AdamW`
            on your model and a scheduler given by :func:`transformers.get_linear_schedule_with_warmup` controlled by `args`.
        optimizer_cls_and_kwargs (Tuple[Type[:class:`torch.optim.Optimizer`], Dict[str, Any]], *optional*):
            An optimizer class together with the keyword arguments to instantiate it with. See
            :class:`transformers.Trainer` for how this interacts with ``optimizers`` and ``args``.
        preprocess_logits_for_metrics (Callable[[:class:`torch.Tensor`, :class:`torch.Tensor`], :class:`torch.Tensor`], *optional*):
            A function that takes logits and labels and returns processed logits before they are cached for
            ``compute_metrics``, as in :class:`transformers.Trainer`.

    Important attributes:

        - **model** -- Always points to the core model. For this trainer it is a
          :class:`~sentence_transformers.sparse_encoder.model.SparseEncoder`.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.parallel.DistributedDataParallel`. If
          the inner model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False`.
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
    """

    # Concrete classes plugged into the generic BaseTrainer machinery for sparse encoders.
    model_class = SparseEncoder
    model_card_data_class = SparseEncoderModelCardData
    model_card_callback_class = SparseEncoderModelCardCallback
    data_collator_class = SparseEncoderDataCollator
    training_args_class = SparseEncoderTrainingArguments

    # ``tokenizer`` is declared deprecated in favour of ``processing_class``.
    @deprecated_kwargs(tokenizer="processing_class")
    def __init__(
        self,
        model: SparseEncoder | None = None,
        args: SparseEncoderTrainingArguments | None = None,
        train_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
        eval_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
        loss: nn.Module
        | dict[str, nn.Module]
        | Callable[[SparseEncoder], torch.nn.Module]
        | dict[str, Callable[[SparseEncoder], torch.nn.Module]]
        | None = None,
        evaluator: BaseEvaluator | list[BaseEvaluator] | None = None,
        data_collator: SparseEncoderDataCollator | None = None,
        processing_class: PreTrainedTokenizerBase
        | BaseImageProcessor
        | FeatureExtractionMixin
        | ProcessorMixin
        | None = None,
        model_init: Callable[[], SparseEncoder] | None = None,
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
    ) -> None:
        """Initialize the trainer by forwarding every argument unchanged to ``BaseTrainer``."""
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            loss=loss,
            evaluator=evaluator,
            data_collator=data_collator,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        # Bare annotations (no assignment): they only narrow the attribute types set by the
        # base class for static type checkers and have no runtime effect.
        self.model: SparseEncoder
        self.args: SparseEncoderTrainingArguments
        self.data_collator: SparseEncoderDataCollator

    def get_default_loss(self, model: SparseEncoder) -> torch.nn.Module:
        """
        Build the loss used when no ``loss`` argument is given.

        Args:
            model: The sparse encoder the loss is built for.

        Returns:
            torch.nn.Module: A ``SpladeLoss`` wrapping ``SparseMultipleNegativesRankingLoss`` with
            query/document regularizer weights of 5e-5 and 3e-5.
        """
        logger.info(
            "No `loss` passed, using `SpladeLoss` with `SparseMultipleNegativesRankingLoss` as the default. "
            "Note: `query_regularizer_weight` and `document_regularizer_weight` are sensitive parameters "
            "and should be tuned for your task."
        )
        return SpladeLoss(
            model=model,
            loss=SparseMultipleNegativesRankingLoss(model=model),
            query_regularizer_weight=5e-5,  # Regularizer weight for the query side
            document_regularizer_weight=3e-5,  # Regularizer weight for the document side
        )

    def prepare_loss(
        self,
        loss: Callable[[SparseEncoder], torch.nn.Module] | torch.nn.Module,
        model: SparseEncoder,
    ) -> torch.nn.Module:
        """
        Prepare ``loss`` with the base trainer and wire up SPLADE regularizer-weight scheduling.

        If the prepared loss is a ``SpladeLoss``, a ``SpladeRegularizerWeightSchedulerCallback`` is placed
        at index 1 of the callback list, i.e. directly after the first callback:

        - if no such callback is registered, a default one wrapping ``loss`` is created and a warning is logged;
        - if one is registered at an index greater than 1, it is moved to index 1;
        - otherwise (already at index 0 or 1) the callback list is left untouched.

        Args:
            loss: A loss module, or a callable that builds one from ``model``.
            model: The model the loss is prepared for.

        Returns:
            torch.nn.Module: The loss returned by the base trainer's ``prepare_loss``.
        """
        loss = super().prepare_loss(loss, model)

        # isinstance() is already False for None, so no separate None check is needed.
        if not isinstance(loss, SpladeLoss):
            return loss

        callbacks = self.callback_handler.callbacks
        splade_scheduler_callback_index = next(
            (
                idx
                for idx, callback in enumerate(callbacks)
                if isinstance(callback, SpladeRegularizerWeightSchedulerCallback)
            ),
            None,
        )

        if splade_scheduler_callback_index is None:
            # NOTE(review): only one scheduler callback is ever added here. If several SpladeLoss
            # instances are prepared (e.g. a dict of losses), later ones find this callback and are
            # not attached to a scheduler of their own; confirm whether that is intended.
            logger.warning(
                "SpladeLoss detected without SpladeRegularizerWeightSchedulerCallback. "
                "Adding default SpladeRegularizerWeightSchedulerCallback to gradually increase weight values from 0 to their maximum."
            )
            splade_callback = SpladeRegularizerWeightSchedulerCallback(loss=loss)
        elif splade_scheduler_callback_index > 1:
            # A user-supplied scheduler exists but sits too late in the list: move it forward.
            splade_callback = callbacks.pop(splade_scheduler_callback_index)
        else:
            # Already at the expected position.
            return loss

        # Insert right after the first callback (the original code describes it as the default callback
        # informing the trainer when to log, evaluate, save, etc.).
        callbacks.insert(1, splade_callback)
        return loss
