from __future__ import annotations

import inspect
import logging
import os
import re
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Callable
from contextlib import nullcontext
from functools import partial
from typing import Any

import torch
from torch import nn
from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, RandomSampler
from transformers import EvalPrediction, PreTrainedTokenizerBase, Trainer, TrainerCallback
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor
from transformers.integrations import WandbCallback
from transformers.processing_utils import ProcessorMixin
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import EvalLoopOutput

from sentence_transformers.base.data_collator import BaseDataCollator
from sentence_transformers.base.evaluation import BaseEvaluator, SequentialEvaluator
from sentence_transformers.base.model import BaseModel
from sentence_transformers.base.model_card import BaseModelCardCallback, BaseModelCardData
from sentence_transformers.base.modules import Router
from sentence_transformers.base.sampler import (
    DefaultBatchSampler,
    GroupByLabelBatchSampler,
    MultiDatasetDefaultBatchSampler,
    NoDuplicatesBatchSampler,
    ProportionalBatchSampler,
    RoundRobinBatchSampler,
)
from sentence_transformers.base.training_args import BaseTrainingArguments, BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.util import disable_logging, fullname, is_datasets_available, is_training_available
from sentence_transformers.util.decorators import deprecated_kwargs

if is_datasets_available():
    from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, Value

logger = logging.getLogger(__name__)

# The TrackioCallback is only available in the v4.54+ of transformers, but I'd like to keep Sentence Transformers
# compatible with older versions of transformers as well, so we import it conditionally
try:
    from transformers.integrations import TrackioCallback
except ImportError:
    TrackioCallback = None


class BaseTrainer(Trainer, ABC):
    """
    BaseTrainer is a simple but feature-complete training and eval loop for PyTorch
    based on the 🤗 Transformers :class:`~transformers.Trainer`.

    This trainer integrates support for various :class:`transformers.TrainerCallback` subclasses, such as:

    - :class:`~transformers.integrations.WandbCallback` to automatically log training metrics to W&B if `wandb` is installed
    - :class:`~transformers.integrations.TensorBoardCallback` to log training metrics to TensorBoard if `tensorboard` is accessible.
    - :class:`~transformers.integrations.CodeCarbonCallback` to track the carbon emissions of your model during training if `codecarbon` is installed.

        - Note: These carbon emissions will be included in your automatically generated model card.

    See the Transformers `Callbacks <https://huggingface.co/docs/transformers/main/en/main_classes/callback>`_
    documentation for more information on the integrated callbacks and how to write your own callbacks.

    Args:
        model (:class:`~sentence_transformers.base.model.BaseModel`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
        args (:class:`~sentence_transformers.base.training_args.BaseTrainingArguments`, *optional*):
            The arguments to tweak for training. Will default to a basic instance of
            :class:`~sentence_transformers.base.training_args.BaseTrainingArguments` with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
        train_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
            The dataset to use for training. Must have a format accepted by your loss function.
        eval_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
            The dataset to use for evaluation. Must have a format accepted by your loss function.
        loss (Optional[Union[:class:`torch.nn.Module`, Dict[str, :class:`torch.nn.Module`],\
            Callable[[:class:`~sentence_transformers.base.model.BaseModel`], :class:`torch.nn.Module`],\
            Dict[str, Callable[[:class:`~sentence_transformers.base.model.BaseModel`]]]], *optional*):
            The loss function to use for training. Can either be a loss class instance, a dictionary mapping
            dataset names to loss class instances, a function that returns a loss class instance given a model,
            or a dictionary mapping dataset names to functions that return a loss class instance given a model.
            In practice, the latter two are primarily used for hyper-parameter optimization.
        evaluator (Union[:class:`~sentence_transformers.base.evaluation.BaseEvaluator`,\
            List[:class:`~sentence_transformers.base.evaluation.BaseEvaluator`]], *optional*):
            The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
            or without an ``eval_dataset``, and vice versa. Generally, the metrics that an ``evaluator`` returns
            are more useful than the loss value returned from the ``eval_dataset``. A list of evaluators will be
            wrapped in a :class:`~sentence_transformers.base.evaluation.SequentialEvaluator` to run them sequentially.
        callbacks (List[:class:`transformers.TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            described in the Transformers `Callbacks <https://huggingface.co/docs/transformers/main/en/main_classes/callback>`__
            documentation.

            If you want to remove one of the default callbacks used, use the
            :meth:`~transformers.Trainer.remove_callback` method.
        optimizers (`Tuple[:class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler.LambdaLR`]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of :class:`torch.optim.AdamW`
            on your model and a scheduler given by :func:`transformers.get_linear_schedule_with_warmup` controlled by `args`.

    Important attributes:

        - **model** -- Always points to the core model, a
          :class:`~sentence_transformers.base.model.BaseModel` instance.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False` .
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)

    """

    model_class = BaseModel
    model_card_data_class = BaseModelCardData
    model_card_callback_class = BaseModelCardCallback
    data_collator_class = BaseDataCollator
    training_args_class = BaseTrainingArguments

    @deprecated_kwargs(tokenizer="processing_class")
    def __init__(
        self,
        model: BaseModel | None = None,
        args: BaseTrainingArguments | None = None,
        train_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
        eval_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
        loss: nn.Module
        | dict[str, nn.Module]
        | Callable[[BaseModel], torch.nn.Module]
        | dict[str, Callable[[BaseModel], torch.nn.Module]]
        | None = None,
        evaluator: BaseEvaluator | list[BaseEvaluator] | None = None,
        data_collator: BaseDataCollator | None = None,
        processing_class: PreTrainedTokenizerBase
        | BaseImageProcessor
        | FeatureExtractionMixin
        | ProcessorMixin
        | None = None,
        model_init: Callable[[], BaseModel] | None = None,
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
    ) -> None:
        """
        Validate the inputs, fill in defaults (training arguments, processing class, data collator, loss),
        initialize the parent Transformers ``Trainer`` and register the model card callback.

        The arguments are documented in the class docstring.

        Raises:
            RuntimeError: If the training extras (``accelerate`` and ``datasets``) are not installed, or if neither
                ``model`` nor ``model_init`` is provided.
            ValueError: If ``args`` is not an instance of ``training_args_class``; if a train/eval
                ``IterableDataset`` has no Features; if ``args.eval_strategy`` is not ``"no"`` while neither
                ``eval_dataset`` nor ``evaluator`` is given; or if ``loss`` is a dict that does not match the
                (dict-shaped) datasets.
        """
        # Fail early with an actionable message if the optional training dependencies are missing
        if not is_training_available():
            raise RuntimeError(
                f"To train a {self.model_class.__name__} model, you need to install the `accelerate` and `datasets` modules. "
                "You can do so with the `train` extra:\n"
                'pip install -U "sentence-transformers[train]"'
            )

        # Default to a basic training-arguments instance; reject arguments of the wrong class
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `args` passed, using `{self.training_args_class.__name__}(output_dir={output_dir})`.")
            args = self.training_args_class(output_dir=output_dir)
        elif not isinstance(args, self.training_args_class):
            raise ValueError(
                f"Please pass an instance of `{fullname(self.training_args_class)}` as the `args` argument."
            )

        # Resolve a concrete model up front: it is needed below for the model card, the processing class,
        # the data collator and the loss
        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError(f"`{self.__class__.__name__}` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                logger.warning(
                    f"`{self.__class__.__name__}` requires either a `model` or `model_init` argument, but not both. "
                    "`model_init` will overwrite your model when calling the `train` method."
                )
            self.model_init = model_init

        if compute_metrics is not None:
            logger.warning(
                f"`compute_metrics` is currently not compatible with the {self.__class__.__name__}. Please use the "
                "`evaluator` argument instead for detailed evaluation metrics, or the `eval_dataset` argument for "
                "the evaluation loss."
            )

        # Get a dictionary of the default training arguments, so we can determine which arguments have been changed
        # for the model card
        default_args_dict = self.training_args_class(
            output_dir="unused", accelerator_config={"use_configured_state": True}
        ).to_dict()

        # If the model ID is set via the training arguments (`hub_model_id`) but not via the model card data,
        # then we can set it here for the model card regardless
        if args.hub_model_id and not model.model_card_data.model_id:
            model.model_card_data.set_model_id(args.hub_model_id)

        # Default the processing class to the model's own processor, if it is one of the supported processor types
        if (
            processing_class is None
            and hasattr(model, "processor")
            and isinstance(
                model.processor, (PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin)
            )
        ):
            processing_class = model.processor

        if data_collator is None:
            data_collator = self.get_data_collator(model=model, args=args, processing_class=processing_class)

        # Reject IterableDatasets without Features (column_names is None). The first sample is peeked at only to
        # suggest a naive Features spec in the error message.
        # NOTE(review): `next(iter(dataset))` raises StopIteration for an empty IterableDataset - confirm acceptable.
        for dataset_name, dataset in zip(["train", "eval"], [train_dataset, eval_dataset]):
            if isinstance(dataset, IterableDataset) and dataset.column_names is None:
                sample = next(iter(dataset))
                naive_type_mapping = {str: "string", int: "int64", float: "float32", bool: "bool"}
                example_features = {
                    key: Value(naive_type_mapping.get(type(value), "null")) for key, value in sample.items()
                }
                raise ValueError(
                    f"The provided `{dataset_name}_dataset` must have Features. Specify them with e.g.:\n"
                    f"{dataset_name}_dataset = {dataset_name}_dataset.cast(Features({example_features}))\n"
                    "or by providing the Features to the IterableDataset initialization method. See the Datasets "
                    "documentation for more information on dataset Features: "
                    "https://huggingface.co/docs/datasets/en/about_dataset_features"
                )

        # Plain dicts of datasets are normalized to DatasetDict
        if isinstance(train_dataset, dict) and not isinstance(train_dataset, DatasetDict):
            train_dataset = DatasetDict(train_dataset)
        if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, DatasetDict):
            eval_dataset = DatasetDict(eval_dataset)

        # super.__init__() will still raise a ValueError if `eval_dataset` is None, `evaluator` is None,
        # while eval_strategy is not "no", so let's get ahead of it with a more useful ST-specific error message
        if eval_dataset is None and evaluator is None and args.eval_strategy != "no":
            raise ValueError(
                f"You have set `args.eval_strategy` to {args.eval_strategy}, but you didn't provide an `eval_dataset` or an `evaluator`. "
                f"Either provide an `eval_dataset` or an `evaluator` to `{self.__class__.__name__}`, "
                "or set `args.eval_strategy='no'` to skip evaluation."
            )

        super().__init__(
            # When `model_init` is set, no model is passed so the parent Trainer works from `model_init` instead
            model=None if self.model_init else model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            # With only an `evaluator`, pass a placeholder so the parent does not complain about a missing
            # eval dataset; it is reset to None right after
            eval_dataset=eval_dataset if eval_dataset is not None or evaluator is None else "dummy",
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        # If the eval_dataset is "dummy", then we set it back to None
        if self.eval_dataset == "dummy":
            self.eval_dataset = None

        # If losses return dictionaries, then we want to be able to accumulate the loss components
        # before merging them into a single loss (required by the base Trainer)
        self.accum_loss_components = {"train": {}, "eval": {}}

        # Every Sentence Transformer model can always return a loss, so we set this to True
        # to avoid having to specify it in the data collator or model's forward
        self.can_return_loss = True

        # Annotation-only statements that narrow the attribute types for type checkers
        self.model: BaseModel
        self.args: BaseTrainingArguments
        self.data_collator: BaseDataCollator

        # Set the W&B or Trackio project via environment variables if it's not already set
        if any(isinstance(callback, WandbCallback) for callback in self.callback_handler.callbacks):
            os.environ.setdefault("WANDB_PROJECT", "sentence-transformers")
        if TrackioCallback is not None and any(
            isinstance(callback, TrackioCallback) for callback in self.callback_handler.callbacks
        ):
            os.environ.setdefault("TRACKIO_PROJECT", "sentence-transformers")

        # Fall back to the subclass-specific default loss
        if loss is None:
            loss = self.get_default_loss(self.model)

        # NOTE(review): losses are prepared with the local `model`. In the `model_init` path this is the instance
        # created above by `self.call_model_init()`, which may not be the same object as `self.model` set by the
        # parent Trainer - confirm. `call_model_init` re-targets the loss at the new model when it runs again.
        if isinstance(loss, dict):
            self.loss = {dataset_name: self.prepare_loss(loss_fn, model) for dataset_name, loss_fn in loss.items()}
            # A dict of losses requires dict-shaped datasets whose keys are all covered by the loss dict
            for dataset_name, dataset in zip(["train", "eval"], [train_dataset, eval_dataset]):
                if dataset is None:
                    continue
                if not isinstance(dataset, dict):
                    raise ValueError(
                        f"If the provided `loss` is a dict, then the `{dataset_name}_dataset` must be a `DatasetDict`."
                    )
                if missing := set(dataset.keys()) - set(loss.keys()):
                    raise ValueError(
                        f"If the provided `loss` is a dict, then all keys from the `{dataset_name}_dataset` dictionary must occur in `loss` also. "
                        f"Currently, {sorted(missing)} occur{'s' if len(missing) == 1 else ''} in `{dataset_name}_dataset` but not in `loss`."
                    )
        else:
            self.loss = self.prepare_loss(loss, model)

        # If evaluator is a list, we wrap it in a SequentialEvaluator
        if evaluator is not None and not isinstance(evaluator, BaseEvaluator):
            evaluator = SequentialEvaluator(evaluator)
        self.evaluator = evaluator

        # Preprocess the (possibly DatasetDict-normalized) local datasets, but only if the parent kept them
        if self.train_dataset is not None:
            self.train_dataset = self.preprocess_dataset(train_dataset, dataset_name="train")
        if self.eval_dataset is not None:
            self.eval_dataset = self.preprocess_dataset(eval_dataset, dataset_name="eval")
        self.add_model_card_callback(default_args_dict)

    def get_data_collator(
        self,
        model: BaseModel,
        args: BaseTrainingArguments,
        processing_class: PreTrainedTokenizerBase
        | BaseImageProcessor
        | FeatureExtractionMixin
        | ProcessorMixin
        | None = None,
    ) -> BaseDataCollator:
        """
        Load the data collator for the trainer.

        Args:
            model (:class:`~sentence_transformers.base.model.BaseModel`):
                The model to train, evaluate or use for predictions. Its ``preprocess`` method is handed to the
                collator.
            args (:class:`~sentence_transformers.base.training_args.BaseTrainingArguments`):
                The arguments to tweak for training. ``router_mapping`` and ``prompts`` are forwarded to the collator.
            processing_class (Union[:class:`transformers.PreTrainedTokenizerBase`, :class:`transformers.BaseImageProcessor`, :class:`transformers.FeatureExtractionMixin`, :class:`transformers.ProcessorMixin`], *optional*):
                The processing class to use for tokenization or image processing. Not used by this default
                implementation, but available to overriding subclasses.
        Returns:
            :class:`BaseDataCollator`: The data collator to use for the trainer

        Raises:
            ValueError: If one of the model's direct child modules is a
                :class:`~sentence_transformers.base.modules.Router` (or a subclass of it) while
                ``args.router_mapping`` is empty.

        .. note::

            This method can be overridden by subclassing the trainer to use a custom data collator.
        """
        # Use isinstance (not an exact class comparison) so subclasses of Router are detected as well,
        # consistent with the Router check in `prepare_loss`
        has_router = any(isinstance(module, Router) for module in model.children())
        if has_router and not args.router_mapping:
            raise ValueError(
                "You are using a Router module in your model, but you did not provide a `router_mapping` in the "
                "training arguments. This means that the Router module will not be able to route the inputs to "
                "the correct submodules. Please provide a `router_mapping` that maps column names to routes, "
                "e.g. {'column_one': 'query', 'column_two': 'document', 'column_three': 'document'}."
            )

        return self.data_collator_class(
            preprocess_fn=model.preprocess,
            router_mapping=args.router_mapping,
            prompts=args.prompts,
        )

    def add_model_card_callback(self, default_args_dict: dict[str, Any]) -> None:
        """
        Register the callback that gathers the information used to generate the model card automatically.

        Called from the trainer's ``__init__``. An instance of ``model_card_callback_class`` is created, added to
        the trainer's callbacks, and its ``on_init_end`` hook is then invoked directly with the current training
        arguments, state and control, plus the model and the trainer itself.

        Args:
            default_args_dict (Dict[str, Any]): The default training arguments as a dictionary, used to work out
                which arguments were changed by the user so they can be reported in the model card.

        .. note::

            Override this method in a trainer subclass to remove or customize this callback.
        """
        card_callback = self.model_card_callback_class(default_args_dict)
        self.add_callback(card_callback)
        # Invoke the init-end hook by hand for this callback, which is registered after the parent init;
        # presumably the parent Trainer's own init-end dispatch has already happened by now (confirm)
        hook_kwargs = {"model": self.model, "trainer": self}
        card_callback.on_init_end(self.args, self.state, self.control, **hook_kwargs)

    def call_model_init(self, trial=None) -> BaseModel:
        """
        Create a model through the parent ``call_model_init`` and point the trainer's loss(es) at it.

        If ``self.loss`` has not been assigned yet (e.g. while ``__init__`` is still running), the new model is
        returned untouched. Otherwise, each loss (or the single loss) is handled as follows:

        - not yet an ``nn.Module`` (a factory callable): it is called with the new model;
        - an ``nn.Module`` exposing a ``model`` attribute: the model inside it is replaced via
          ``override_model_in_loss``;
        - any other ``nn.Module``: left as is.

        Args:
            trial: Hyperparameter-search trial, forwarded to the parent implementation.

        Returns:
            The newly created model.
        """
        new_model = super().call_model_init(trial=trial)
        if not hasattr(self, "loss"):
            return new_model

        def retarget(loss_fn):
            # Factory callables are instantiated for the new model
            if not isinstance(loss_fn, torch.nn.Module):
                return loss_fn(new_model)
            # Initialized losses that store a model get the new model swapped in
            if hasattr(loss_fn, "model"):
                return self.override_model_in_loss(loss_fn, new_model)
            return loss_fn

        if isinstance(self.loss, dict):
            # Multi-loss training: re-target every per-dataset loss
            for dataset_key, loss_fn in self.loss.items():
                self.loss[dataset_key] = retarget(loss_fn)
        else:
            self.loss = retarget(self.loss)
        return new_model

    def override_model_in_loss(self, loss: torch.nn.Module, model: BaseModel) -> torch.nn.Module:
        """
        Recursively replace the ``BaseModel`` stored as a ``model`` child of ``loss`` (or of any nested submodule)
        with ``model``.

        A child named ``model`` that is a ``BaseModel`` is replaced and not descended into; every other child
        module is processed recursively.

        Args:
            loss (torch.nn.Module): The loss module to update in place.
            model (BaseModel): The model (possibly wrapped) to store in the loss.

        Returns:
            torch.nn.Module: The same ``loss`` object, after modification.
        """
        for child_name, submodule in loss.named_children():
            if child_name == "model" and isinstance(submodule, BaseModel):
                loss.model = model
                continue
            if isinstance(submodule, torch.nn.Module):
                # Re-assign the (in-place updated) child, mirroring the recursive return value
                setattr(loss, child_name, self.override_model_in_loss(submodule, model))
        return loss

    def prepare_loss(
        self,
        loss: Callable[[BaseModel], torch.nn.Module] | torch.nn.Module,
        model: BaseModel,
    ) -> torch.nn.Module:
        """
        Turn ``loss`` into an ``nn.Module`` on the model's device and configure the model's input modules for it.

        Args:
            loss: Either an initialized loss module, or a callable that builds one from ``model``.
            model (BaseModel): The model the loss is used with.

        Returns:
            torch.nn.Module: The loss module, moved to ``model.device``.
        """
        loss_module = loss if isinstance(loss, torch.nn.Module) else loss(model)
        loss_module = loss_module.to(model.device)

        # Only losses that minibatch VLM inputs need per-sample media counting in Transformer.preprocess
        if not getattr(loss_module, "requires_media_counts", False):
            return loss_module

        first_module = model[0]
        if isinstance(first_module, Router):
            # Each route of a Router has its own input module
            input_modules = [route[0] for route in first_module.sub_modules.values()]  # type: ignore[index]
        else:
            input_modules = [first_module]

        for input_module in input_modules:
            if hasattr(input_module, "track_media_counts"):
                input_module.track_media_counts = True
        return loss_module

    @abstractmethod
    def get_default_loss(self, model: BaseModel) -> torch.nn.Module:
        """
        Build the loss used when no ``loss`` argument is given to the trainer.

        Subclasses must implement this. ``__init__`` calls it with the trainer's model and then passes the
        result through ``prepare_loss``.

        Args:
            model (BaseModel): The model to build the default loss for.

        Returns:
            torch.nn.Module: The default loss module.
        """

    def compute_loss(
        self,
        model: BaseModel,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch=None,
    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
        """
        Computes the loss for the BaseModel model.

        It uses ``self.loss`` to compute the loss, which can be a single loss function or a dictionary of loss functions
        for different datasets. If the loss is a dictionary, the dataset name is expected to be passed in the inputs
        under the key "dataset_name". This is done automatically in the ``add_dataset_name_column`` method.
        Note that even if ``return_outputs = True``, the outputs will be empty, as the BaseModel losses do not
        return outputs.

        If the loss returns a dictionary of loss components, the components are tracked for logging via
        ``track_loss_components`` and summed into a single scalar loss.

        Args:
            model (BaseModel): The BaseModel model.
            inputs (Dict[str, Union[torch.Tensor, Any]]): The input data for the model. The "dataset_name" entry,
                if present, is removed from this dictionary.
            return_outputs (bool, optional): Whether to return the outputs along with the loss. Defaults to False.
            num_items_in_batch (int, optional): The number of items in the batch. Defaults to None. Unused, but required by the transformers Trainer.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]: The computed loss. If `return_outputs` is True, returns a tuple of loss and outputs. Otherwise, returns only the loss.

        Raises:
            ValueError: If ``self.loss`` is a dictionary of losses but ``inputs`` carries no "dataset_name", so the
                loss to use cannot be determined.
        """
        dataset_name = inputs.pop("dataset_name", None)
        features, labels = self.collect_features(inputs)
        loss_fn = self.loss

        if isinstance(loss_fn, dict):
            # Multi-loss training: pick the loss registered for the dataset this batch came from. Without a
            # dataset name, the dict itself would otherwise be called below and fail with an opaque TypeError.
            if not dataset_name:
                raise ValueError(
                    "The `loss` is a dictionary mapping dataset names to losses, but the inputs do not contain a "
                    "`dataset_name`, so the loss to use cannot be determined. Make sure the dataset is a "
                    "`DatasetDict` (or a dictionary of datasets) whose keys also occur in `loss`."
                )
            loss_fn = loss_fn[dataset_name]

        # Insert the wrapped (e.g. distributed or compiled) model into the loss function,
        # if the loss stores the model. Only called once per process
        if (
            model == self.model_wrapped
            and hasattr(loss_fn, "model")  # Only if the loss stores the model
            and loss_fn.model != model  # Only if the wrapped model is not already stored
        ):
            loss_fn = self.override_model_in_loss(loss_fn, model)
        loss = loss_fn(features, labels)
        if isinstance(loss, dict):
            # Keep the individual components for logging, then merge them into the single loss the Trainer needs
            self.track_loss_components(loss)
            loss = torch.stack(list(loss.values())).sum()
        if return_outputs:
            # During prediction/evaluation, `compute_loss` will be called with `return_outputs=True`.
            # However, Sentence Transformer losses do not return outputs, so we return an empty dictionary.
            # This does not result in any problems, as the BaseTrainingArguments sets
            # `prediction_loss_only=True` which means that the output is not used.
            return loss, {}
        return loss

    def track_loss_components(self, loss: dict[str, torch.Tensor]) -> None:
        """
        Accumulate the components of a dictionary-valued loss so that ``log`` can report their averages.

        The running sums are stored in ``self.accum_loss_components["train"]`` or ``["eval"]`` (chosen by
        ``self.model.training``), together with a ``"steps"`` counter of how many times this was called.

        Args:
            loss (Dict[str, torch.Tensor]): Mapping from component name to scalar loss tensor.
        """
        # Nothing to accumulate; also avoids referencing an undefined device below
        if not loss:
            return

        training_type = "train" if self.model.training else "eval"
        accumulated = self.accum_loss_components[training_type]
        for key, value in loss.items():
            # Only the numeric value is logged: detach so the running sums do not hold on to the autograd graph
            # of every step since the last log
            value = value.detach()
            # if loss is nan or inf simply add the average of previous logged losses
            if self.args.logging_nan_inf_filter and (torch.isnan(value) or torch.isinf(value)):
                if key not in accumulated:
                    value = torch.tensor(0.0, dtype=value.dtype, device=value.device)
                else:
                    value = accumulated[key] / (1 + self.state.global_step - self._globalstep_last_logged)

            if key not in accumulated:
                accumulated[key] = value
            else:
                accumulated[key] = accumulated[key] + value

        if "steps" not in accumulated:
            accumulated["steps"] = torch.tensor(0, dtype=int, device=value.device)
        accumulated["steps"] += 1

    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """
        Log ``logs``, extended with the averaged loss components accumulated by ``track_loss_components``.

        When ``logs`` holds ``"loss"`` (training) or ``"eval_loss"`` (evaluation), the accumulated components for
        that phase are gathered across processes, divided by the gathered number of accumulated calls, added to a
        copy of ``logs`` (evaluation keys get an ``eval_`` prefix), and then reset to zero. If no calls were
        accumulated since the last reset, no components are added.

        Args:
            logs (Dict[str, float]): The values to log.
            start_time (float, optional): Forwarded to the parent ``log`` when not None.
        """
        training_type = None
        if "loss" in logs:
            training_type = "train"
        elif "eval_loss" in logs:
            training_type = "eval"

        if training_type:
            # If we don't copy the logs, we'll include the loss components in the on_evaluate as well,
            # whereas we prefer to have them only in the on_log
            logs = logs.copy()
            # Transformers v4/v5 compatibility: v5.2 moves _nested_gather to `transformers.trainer_pt_utils`,
            # see https://github.com/huggingface/transformers/pull/43744
            if hasattr(self, "_nested_gather"):
                accum_losses = self._nested_gather(self.accum_loss_components[training_type])
            else:
                from transformers.trainer_pt_utils import nested_gather

                accum_losses = nested_gather(
                    self.accum_loss_components[training_type], parallel_mode=self.args.parallel_mode
                )
            if "steps" in accum_losses:
                steps = accum_losses.get("steps").sum().item()
                # Guard against 0/0 (NaN) when nothing was accumulated since the last reset, e.g. for an empty
                # evaluation set; the stored components were already reset to zero in that case.
                if steps > 0:
                    self.accum_loss_components[training_type]["steps"] *= 0

                    for key, value in accum_losses.items():
                        if key == "steps":
                            continue
                        log_key = f"{training_type}_{key}" if training_type == "eval" else key
                        logs[log_key] = round((value.sum() / steps).item(), 4)
                        self.accum_loss_components[training_type][key] = torch.tensor(
                            0.0, dtype=value.dtype, device=value.device
                        )

        # The 'start_time' argument was added in transformers v4.47.0, before which the super().log() method
        # would not accept it. If None, we just call the super().log() method without it so that it works with all versions.
        if start_time is not None:
            return super().log(logs, start_time)
        else:
            return super().log(logs)

    def collect_features(
        self, inputs: dict[str, torch.Tensor | Any]
    ) -> tuple[list[dict[str, torch.Tensor]], torch.Tensor | None]:
        """Split the collated dataloader inputs into one feature dictionary per input column, plus the labels.

        A column whose name ends in ``_<suffix>`` for one of the known feature suffixes defines a prefix. All keys
        sharing that prefix are gathered into one feature dictionary with the prefix stripped. Each prefix is used
        once, in the order it is first encountered. The labels are ``inputs["label"]``, or None if absent.

        Example::

            >>> list(inputs.keys())
            ['return_loss', 'label', 'sentence_0_input_ids', 'sentence_0_token_type_ids', 'sentence_0_attention_mask', 'sentence_1_input_ids', 'sentence_1_token_type_ids', 'sentence_1_attention_mask']
            >>> features, labels = self.collect_features(inputs)
            >>> len(features)
            2
            >>> list(features[0].keys())
            ['input_ids', 'token_type_ids', 'attention_mask']
            >>> list(features[1].keys())
            ['input_ids', 'token_type_ids', 'attention_mask']
            >>> torch.equal(labels, inputs["label"])
            True
        """
        # Column-name endings that mark the column as belonging to a feature
        feature_suffixes = (
            "input_ids",  # text (Transformers)
            "sentence_embedding",  # BoW
            "pixel_values",  # image (CLIPModel, etc.)
            "input_features",  # audio (Whisper, etc.)
            "input_values",  # audio (Wav2Vec2, HuBERT, etc.)
            "pixel_values_videos",  # video
        )

        collected = []
        handled_prefixes = set()
        for column_name in inputs:
            # First matching suffix wins; the prefix keeps its trailing underscore (e.g. "sentence_0_")
            prefix = next(
                (column_name[: -len(suffix)] for suffix in feature_suffixes if column_name.endswith("_" + suffix)),
                None,
            )
            if prefix is None or prefix in handled_prefixes:
                continue
            handled_prefixes.add(prefix)
            cut = len(prefix)
            collected.append({name[cut:]: tensor for name, tensor in inputs.items() if name.startswith(prefix)})

        return collected, inputs.get("label", None)

    def evaluate(
        self,
        eval_dataset: Dataset | dict[str, Dataset] | None = None,
        ignore_keys: list[str] | None = None,
        metric_key_prefix: str = "eval",
    ) -> dict[str, float]:
        """
        Run evaluation and return the metrics from the parent ``Trainer.evaluate``.

        A user-provided ``eval_dataset`` is first passed through ``preprocess_dataset`` (as the ``"eval"``
        dataset); otherwise ``self.eval_dataset``, which ``__init__`` already preprocessed, is used.

        Args:
            eval_dataset: Optional dataset (or dict of datasets) to evaluate on instead of ``self.eval_dataset``.
            ignore_keys: Forwarded to the parent ``evaluate``.
            metric_key_prefix: Prefix for the metric names, forwarded to the parent ``evaluate``.

        Returns:
            Dict[str, float]: The evaluation metrics.
        """
        dataset_to_use = (
            self.eval_dataset
            if eval_dataset is None
            else self.preprocess_dataset(eval_dataset, dataset_name="eval")
        )
        return super().evaluate(dataset_to_use, ignore_keys, metric_key_prefix)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: bool | None = None,
        ignore_keys: list[str] | None = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """Run ``Trainer.evaluation_loop`` and merge the metrics of ``self.evaluator`` into its output.

        When several eval datasets are used during training, the evaluator only runs alongside the first one,
        and its metrics are reported under the plain ``"eval"`` prefix instead of ``"eval_<name>"``.

        Returns:
            The ``EvalLoopOutput`` from the parent loop, with the evaluator metrics (each key prefixed with
            ``metric_key_prefix + "_"``) added to ``metrics``.
        """
        loop_output = super().evaluation_loop(
            dataloader=dataloader,
            description=description,
            prediction_loss_only=prediction_loss_only,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        # Without an evaluator, the parent loop's output is all there is to report
        if self.evaluator is None:
            return loop_output

        # With a dict of eval datasets during training, only run the evaluator for the first dataset, and report
        # it as "eval" rather than e.g. "eval_multi_nli"
        is_multi_eval_in_train = (
            self.is_in_train and isinstance(self.eval_dataset, dict) and metric_key_prefix.startswith("eval_")
        )
        if is_multi_eval_in_train:
            first_dataset_name = list(self.eval_dataset.keys())[0]
            if metric_key_prefix[5:] != first_dataset_name:
                return loop_output
            metric_key_prefix = "eval"

        # Non-main processes run the evaluator with INFO-level logging disabled
        logging_context = nullcontext() if self.is_local_process_zero() else disable_logging(logging.INFO)
        with logging_context:
            evaluator_output_dir = self.args.output_dir
            if evaluator_output_dir is not None:
                evaluator_output_dir = os.path.join(evaluator_output_dir, "eval")
                os.makedirs(evaluator_output_dir, exist_ok=True)
            evaluator_metrics = self.evaluator(
                self.model, output_path=evaluator_output_dir, epoch=self.state.epoch, steps=self.state.global_step
            )
        # A bare (non-dict) evaluator result is reported under the "evaluator" key
        if not isinstance(evaluator_metrics, dict):
            evaluator_metrics = {"evaluator": evaluator_metrics}

        # Make sure every key carries the metric_key_prefix + '_' prefix
        key_prefix = f"{metric_key_prefix}_"
        for metric_name in list(evaluator_metrics):
            if not metric_name.startswith(key_prefix):
                evaluator_metrics[key_prefix + metric_name] = evaluator_metrics.pop(metric_name)

        loop_output.metrics.update(evaluator_metrics)
        return loop_output

    def _load_best_model(self) -> None:
        """Load the weights of ``self.state.best_model_checkpoint`` into ``self.model``.

        Before loading, the checkpoint's step (the suffix after the last ``-`` in its path) is recorded on the
        model card data. That step is informational only, so failures there are logged at debug level and ignored.
        A failure to load the checkpoint itself is logged as an error and not raised.
        """
        checkpoint = self.state.best_model_checkpoint
        logger.info(f"Loading best model from {checkpoint} (score: {self.state.best_metric}).")

        if checkpoint:
            try:
                step = int(checkpoint.rsplit("-", 1)[-1])
                self.model.model_card_data.set_best_model_step(step)
            except Exception as exc:
                # Best effort: a missing step only affects the model card, so don't interrupt training, but don't
                # hide the problem completely either.
                logger.debug(f"Could not record the best model step from checkpoint {checkpoint!r}: {exc}")

        try:
            self._load_from_checkpoint(checkpoint)
        except Exception as exc:
            logger.error(f"Could not load the best model from {checkpoint}. Error: {str(exc)}")

    def validate_column_names(self, dataset: Dataset, dataset_name: str | None = None) -> None:
        """Raise if a dataset uses a column name that is reserved for internal use.

        Args:
            dataset: A dataset exposing ``column_names``, or a dict mapping dataset names to such datasets (each
                of which is validated under its own name).
            dataset_name: Optional name of the dataset, used only in the error message.

        Raises:
            ValueError: If the dataset has a ``"return_loss"`` or ``"dataset_name"`` column.
        """
        if isinstance(dataset, dict):
            for inner_name, inner_dataset in dataset.items():
                self.validate_column_names(inner_dataset, dataset_name=inner_name)
            return

        # ``column_names`` can be None (e.g. an IterableDataset whose features are not known yet); there is
        # nothing to validate up front in that case, and set(None) would raise a TypeError.
        column_names = dataset.column_names or []
        if overlap := set(column_names) & {"return_loss", "dataset_name"}:
            raise ValueError(
                f"The following column names are invalid in your {dataset_name + ' ' if dataset_name else ''}dataset: {list(overlap)}."
                " Avoid using these column names, as they are reserved for internal use."
            )

    def get_batch_sampler(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] | None = None,
        generator: torch.Generator | None = None,
        seed: int = 0,
    ) -> BatchSampler | None:
        """
        Returns the appropriate batch sampler based on the ``batch_sampler`` argument in ``self.args``.
        This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the ``batch_sampler``
        to create the :class:`torch.utils.data.DataLoader`.

        .. note::
            Override this method to provide a custom batch sampler.

        Args:
            dataset (Dataset): The dataset to sample from.
            batch_size (int): Number of samples per batch.
            drop_last (bool): If True, drop the last incomplete batch if the dataset size
                is not divisible by the batch size.
            valid_label_columns (List[str]): List of column names to check for labels.
                The first column name from ``valid_label_columns`` found in the dataset will
                be used as the label column.
            generator (torch.Generator, optional): Optional random number generator for shuffling
                the indices.
            seed (int): Seed for the random number generator to ensure reproducibility. Defaults to 0.

        Returns:
            The batch sampler. Returns None for an ``IterableDataset`` when ``batch_sampler`` is a
            :class:`BatchSamplers` value, and also for an unrecognized enum value.
        """
        sampler_kwargs = dict(
            batch_size=batch_size,
            drop_last=drop_last,
            valid_label_columns=valid_label_columns,
            generator=generator,
            seed=seed,
        )
        sampler_option = self.args.batch_sampler

        # Classes (such as DefaultBatchSampler subclasses) and any other callables act as sampler factories
        if callable(sampler_option):
            return sampler_option(dataset, **sampler_kwargs)

        # The remaining options are BatchSamplers enum values, none of which support IterableDatasets
        if isinstance(dataset, IterableDataset):
            if sampler_option != BatchSamplers.BATCH_SAMPLER:
                logger.warning("When using an IterableDataset, you cannot specify a batch sampler.")
            return None

        if sampler_option == BatchSamplers.NO_DUPLICATES:
            return NoDuplicatesBatchSampler(dataset, **sampler_kwargs)
        if sampler_option == BatchSamplers.NO_DUPLICATES_HASHED:
            return NoDuplicatesBatchSampler(dataset, precompute_hashes=True, **sampler_kwargs)
        if sampler_option == BatchSamplers.GROUP_BY_LABEL:
            return GroupByLabelBatchSampler(dataset, **sampler_kwargs)
        if sampler_option == BatchSamplers.BATCH_SAMPLER:
            return DefaultBatchSampler(RandomSampler(dataset, generator=generator), **sampler_kwargs)
        return None

    def get_multi_dataset_batch_sampler(
        self,
        dataset: ConcatDataset,
        batch_samplers: list[BatchSampler],
        generator: torch.Generator | None = None,
        seed: int | None = 0,
    ) -> BatchSampler:
        """
        Returns the appropriate multi-dataset batch sampler based on the ``multi_dataset_batch_sampler`` argument
        in ``self.args``. This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the
        ``batch_sampler`` to create the :class:`torch.utils.data.DataLoader`.

        .. note::
            Override this method to provide a custom multi-dataset batch sampler.

        Args:
            dataset (ConcatDataset): The concatenation of all datasets.
            batch_samplers (List[BatchSampler]): List of batch samplers for each dataset in the concatenated dataset.
            generator (torch.Generator, optional): Optional random number generator for shuffling the indices.
            seed (int, optional): Optional seed for the random number generator

        Returns:
            The multi-dataset batch sampler, or None if the option is an unrecognized
            :class:`MultiDatasetBatchSamplers` value.
        """
        sampler_kwargs = {"batch_samplers": batch_samplers, "generator": generator, "seed": seed}
        sampler_option = self.args.multi_dataset_batch_sampler

        # Classes (such as MultiDatasetDefaultBatchSampler subclasses) and any other callables act as factories
        if callable(sampler_option):
            return sampler_option(dataset, **sampler_kwargs)

        # Otherwise the option is a MultiDatasetBatchSamplers enum value
        if sampler_option == MultiDatasetBatchSamplers.ROUND_ROBIN:
            return RoundRobinBatchSampler(dataset=dataset, **sampler_kwargs)
        if sampler_option == MultiDatasetBatchSamplers.PROPORTIONAL:
            return ProportionalBatchSampler(dataset=dataset, **sampler_kwargs)
        return None

    def _build_dataloader(
        self,
        dataset: Dataset | DatasetDict | IterableDataset,
        batch_size: int,
        dataset_kind: str,
    ) -> DataLoader:
        """Shared logic for building train/eval/test DataLoaders.

        Worker count, memory pinning, worker persistence and prefetching come from ``self.args``. Batching
        depends on the dataset type:

        - ``IterableDataset``: plain ``batch_size``/``drop_last`` batching (custom batch samplers are unsupported).
        - ``DatasetDict``: one batch sampler per sub-dataset from :meth:`get_batch_sampler`, combined by
          :meth:`get_multi_dataset_batch_sampler` over a ``ConcatDataset`` of the sub-datasets.
        - ``Dataset``: a single batch sampler from :meth:`get_batch_sampler`.

        Args:
            dataset: The dataset to build a DataLoader for.
            batch_size: The batch size to use.
            dataset_kind: A label for error messages, e.g. "train", "eval", or "test".

        Returns:
            A prepared DataLoader for the given dataset.

        Raises:
            ValueError: For an ``IterableDatasetDict``, a ``DatasetDict`` containing an ``IterableDataset``, or any
                other unsupported dataset type.
        """
        data_collator = self.data_collator
        seed = self.args.seed

        generator = torch.Generator()
        if seed is not None:
            generator.manual_seed(seed)
        # The batch samplers take their own seed (get_batch_sampler defaults it to 0 when omitted), so forward
        # args.seed explicitly to keep them consistent with the generator and the multi-dataset sampler.
        sampler_seed = seed if seed is not None else 0

        dataloader_params = {
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            "prefetch_factor": self.args.dataloader_prefetch_factor,
        }

        if isinstance(dataset, IterableDataset):
            dataloader_params.update(
                {
                    "batch_size": batch_size,
                    "drop_last": self.args.dataloader_drop_last,
                }
            )
            if self.args.batch_sampler != BatchSamplers.BATCH_SAMPLER:
                logger.warning("When using an IterableDataset, you cannot specify a batch sampler.")

        elif isinstance(dataset, IterableDatasetDict):
            raise ValueError(
                "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
            )

        elif isinstance(dataset, DatasetDict):
            for sub_dataset in dataset.values():
                if isinstance(sub_dataset, IterableDataset):
                    raise ValueError(
                        "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
                    )

            batch_samplers = [
                self.get_batch_sampler(
                    sub_dataset,
                    batch_size=batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    valid_label_columns=data_collator.valid_label_columns,
                    generator=generator,
                    seed=sampler_seed,
                )
                for sub_dataset in dataset.values()
            ]

            dataset = ConcatDataset(dataset.values())
            batch_sampler = self.get_multi_dataset_batch_sampler(
                dataset=dataset,
                batch_samplers=batch_samplers,
                generator=generator,
                seed=seed,
            )
            dataloader_params["batch_sampler"] = batch_sampler

        elif isinstance(dataset, Dataset):
            batch_sampler = self.get_batch_sampler(
                dataset,
                batch_size=batch_size,
                drop_last=self.args.dataloader_drop_last,
                valid_label_columns=data_collator.valid_label_columns,
                generator=generator,
                seed=sampler_seed,
            )
            dataloader_params["batch_sampler"] = batch_sampler
        else:
            raise ValueError(
                f"Unsupported `{dataset_kind}_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for {dataset_kind}."
            )

        return DataLoader(dataset, **dataloader_params)

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        The DataLoader is built from ``self.train_dataset`` by ``_build_dataloader`` using
        ``self.args.train_batch_size``, prepared with the accelerator, and cached on ``self._train_dataloader``.

        Subclass and override this method if you want to inject some custom behavior.

        Raises:
            ValueError: If no ``train_dataset`` was provided.
        """
        train_dataset = self.train_dataset
        if train_dataset is None:
            raise ValueError(f"Training requires specifying a train_dataset to the {self.__class__.__name__}.")

        # With 'even_batches' enabled, the accelerator pads the last batch using the first samples, which breaks
        # multi-dataset training, so it is disabled here. Evaluation keeps it enabled, as disabling it hangs there.
        self.accelerator.even_batches = False
        train_dataloader = self._build_dataloader(train_dataset, self.args.train_batch_size, dataset_kind="train")
        self._train_dataloader = self.accelerator.prepare(train_dataloader)
        return self._train_dataloader

    def get_eval_dataloader(self, eval_dataset: Dataset | DatasetDict | IterableDataset | None = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`Dataset | DatasetDict | IterableDataset`, *optional*):
                If provided, will override `self.eval_dataset`.

        Returns:
            The accelerator-prepared DataLoader. If no eval dataset exists at all but an evaluator is set, an empty
            (unprepared) ``DataLoader([])`` is returned instead.

        Raises:
            ValueError: If there is neither an eval dataset nor an evaluator.
        """
        if eval_dataset is None:
            eval_dataset = self.eval_dataset

        if eval_dataset is None:
            # An evaluator can run without an eval_dataset, so don't fail in that case
            if self.evaluator is not None:
                return DataLoader([])
            raise ValueError(f"Evaluation requires specifying an eval_dataset to the {self.__class__.__name__}.")

        # 'even_batches' pads the last batch with the first samples. Training disables it for multi-dataset
        # training, but evaluation hangs without it, so it is switched back on here.
        self.accelerator.even_batches = True
        eval_dataloader = self._build_dataloader(eval_dataset, self.args.eval_batch_size, dataset_kind="eval")
        return self.accelerator.prepare(eval_dataloader)

    def get_test_dataloader(self, test_dataset: Dataset | DatasetDict | IterableDataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`Dataset | DatasetDict | IterableDataset`):
                The test dataset to use. It is batched with ``self.args.eval_batch_size``.

        Returns:
            The accelerator-prepared DataLoader.
        """
        # 'even_batches' pads the last batch with the first samples. Training disables it for multi-dataset
        # training, but evaluation hangs without it, so it is switched back on here.
        self.accelerator.even_batches = True
        test_dataloader = self._build_dataloader(test_dataset, self.args.eval_batch_size, dataset_kind="test")
        return self.accelerator.prepare(test_dataloader)

    def _save(self, output_dir: str | None = None, state_dict=None) -> None:
        """Save the model, the processing class (if any), and the training arguments to ``output_dir``.

        ``output_dir`` defaults to ``self.args.output_dir``. ``state_dict`` is accepted for signature
        compatibility but is not used here.
        """
        # Only process zero executes this method, so there is no need to check the process rank here.
        target_dir = self.args.output_dir if output_dir is None else output_dir
        os.makedirs(target_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {target_dir}")

        # Transformers v5.0.0 dropped `save_safetensors` from the Training Arguments, so only forward it if present
        save_kwargs = {}
        if hasattr(self.args, "save_safetensors"):
            save_kwargs["safe_serialization"] = self.args.save_safetensors
        self.model.save_pretrained(target_dir, **save_kwargs)

        # Transformers v4.46.0 generalized the `tokenizer` attribute into `processing_class`
        processing_class = self.processing_class
        if processing_class is not None:
            processing_class.save_pretrained(target_dir)

        # Keep the training arguments next to the trained model
        torch.save(self.args, os.path.join(target_dir, TRAINING_ARGS_NAME))

    def _load_from_checkpoint(self, checkpoint_path: str) -> None:
        model_class = self.model.__class__
        loaded_model = model_class(checkpoint_path, trust_remote_code=self.model.trust_remote_code)
        self.model.load_state_dict(loaded_model.state_dict())

    def preprocess_dataset(
        self, dataset: DatasetDict | Dataset | None = None, dataset_name: str | None = None
    ) -> DatasetDict | Dataset | None:
        """
        Preprocess the dataset by optionally lazily adding a dataset name column, required for multi-dataset training
        with multiple losses or for dataset-specific router mappings.

        Reserved column names are validated first. The returned object is tagged so that repeated calls return it
        unchanged, unless it cannot carry attributes (e.g. a plain ``dict`` of datasets). In that case it is simply
        validated again on the next call.

        Args:
            dataset (DatasetDict | Dataset | None): The dataset to preprocess. If None, no preprocessing is done.
            dataset_name (str | None): The name of the dataset, used for multi-dataset training with multiple losses.

        Returns:
            DatasetDict | Dataset | None: The preprocessed dataset, perhaps with dataset names added as a lazy column.

        Raises:
            ValueError: If the dataset uses a reserved column name (see :meth:`validate_column_names`).
        """
        # Nothing to do for a missing dataset, or for one we've already preprocessed (avoid stacking transforms)
        if dataset is None or hasattr(dataset, "_sentence_transformers_preprocessed"):
            return dataset

        # Validate that reserved column names are not used in the dataset
        self.validate_column_names(dataset, dataset_name=dataset_name)

        if self.should_dataset_name_column_be_added(dataset, self.args, self.loss):
            dataset = self.add_dataset_name_column(dataset)

        # Tag the dataset as preprocessed, to ensure that we don't apply the map or transform multiple times.
        try:
            dataset._sentence_transformers_preprocessed = True
        except AttributeError:
            # Plain builtins such as ``dict`` (accepted by ``evaluate``) do not support attribute assignment; they
            # are just re-validated if they are passed in again.
            pass

        return dataset

    def should_dataset_name_column_be_added(
        self,
        dataset: DatasetDict | Dataset | None,
        args: BaseTrainingArguments,
        loss: nn.Module | dict[str, nn.Module],
    ) -> bool:
        """
        We should add a dataset name column to the dataset, if the dataset is a DatasetDict, *and* one of:

        a. The loss is a dictionary, or
        b. The prompts contain a mapping of dataset names, or
        c. The router_mapping contains a mapping of dataset names.

        For (c), only the first value of ``args.router_mapping`` is inspected: a nested dict there means the
        mapping is keyed by dataset name.

        Returns:
            bool: Always a real ``bool``, never another falsy value such as ``None`` or ``{}``.
        """
        if not isinstance(dataset, (DatasetDict, IterableDatasetDict)):
            return False

        if isinstance(loss, dict):
            return True

        if args.prompts and isinstance(args.prompts, dict):
            return True

        router_mapping = args.router_mapping
        return bool(
            router_mapping
            and isinstance(router_mapping, dict)
            and isinstance(next(iter(router_mapping.values())), dict)
        )

    def add_dataset_name_column(
        self,
        dataset: DatasetDict | IterableDatasetDict | Dataset | IterableDataset,
        dataset_name: str | None = None,
    ) -> DatasetDict | Dataset | None:
        """Lazily add a ``dataset_name`` column holding ``dataset_name`` to the dataset.

        For a (Iterable)DatasetDict, every sub-dataset is processed under its own key and the dict is updated
        in place. For a ``Dataset``, a transform is installed with ``set_transform``, chained after any
        existing transform. For an ``IterableDataset``, a batched ``map`` is applied instead.

        Args:
            dataset: The dataset (or dict of datasets) to update.
            dataset_name: The value to put in the new column. If None, a non-dict dataset is returned unchanged.

        Returns:
            The dataset with the lazy ``dataset_name`` column (a new object for ``IterableDataset``).

        Raises:
            ValueError: For unsupported dataset types.
        """
        if isinstance(dataset, (IterableDatasetDict, DatasetDict)):
            for inner_name, inner_dataset in dataset.items():
                dataset[inner_name] = self.add_dataset_name_column(
                    dataset=inner_dataset,
                    dataset_name=inner_name,
                )
            return dataset

        # If the dataset name is None, we don't need to do anything
        if dataset_name is None:
            return dataset

        # If we have a Dataset, we can set the transform directly; any existing transform is forwarded through
        # the format kwargs so that it still runs before the dataset name is added.
        if isinstance(dataset, Dataset):
            dataset.set_transform(
                partial(
                    self.add_dataset_name_transform,
                    dataset_name=dataset_name,
                    **dataset._format_kwargs,
                )
            )

        # ... otherwise, we have an IterableDataset and we need to map it, which performs the same operation as above
        elif isinstance(dataset, IterableDataset):
            # Declare the new column in the output features when they are known. Iterable datasets may not know
            # their features (None), in which case `map` infers them. Work on a copy so the source dataset's
            # features are not mutated.
            features = dataset.features
            if features is not None:
                features = features.copy()
                features["dataset_name"] = Value("string")

            dataset = dataset.map(
                partial(
                    self.add_dataset_name_transform,
                    dataset_name=dataset_name,
                ),
                batched=True,
                features=features,
            )
        else:
            raise ValueError(
                "Unsupported `dataset` type. Use a Dataset, DatasetDict, IterableDataset, or IterableDatasetDict."
            )
        return dataset

    @staticmethod
    def add_dataset_name_transform(
        batch: dict[str, list[Any]],
        dataset_name: str | None = None,
        transform: Callable[[dict[str, list[Any]]], dict[str, list[Any]]] | None = None,
        **kwargs,
    ) -> dict[str, list[Any]]:
        """A transform/map function that adds the dataset name to the batch.

        Args:
            batch (dict[str, list[Any]]): The batch of data, where each key is a column name and each value
                is a list of values.
            dataset_name (str | None, optional): The name of this dataset, only if there are multiple datasets
                that use a different loss. Defaults to None.
            transform (Callable[[dict[str, list[Any]]], dict[str, list[Any]]], optional): An optional transform
                function to apply on the batch before adding the dataset name. Defaults to None.

        Returns:
            dict[str, list[Any]]: The "just-in-time" transformed batch with the dataset name added.
        """
        # If the dataset is a Dataset(Dict), then we use set_transform and we want to also apply any
        # previous transform if it exists
        if transform:
            batch = transform(batch)

        # Return if 1) the batch has no columns, 2) if it's empty, or 3) if there is no dataset name
        if not batch or not list(batch.values())[0] or dataset_name is None:
            return batch

        # Add the dataset name to the batch
        batch_size = len(list(batch.values())[0])
        batch["dataset_name"] = [dataset_name] * batch_size
        return batch

    def create_model_card(
        self,
        language: str | None = None,
        license: str | None = None,
        tags: str | list[str] | None = None,
        model_name: str | None = None,
        finetuned_from: str | None = None,
        tasks: str | list[str] | None = None,
        dataset_tags: str | list[str] | None = None,
        dataset: str | list[str] | None = None,
        dataset_args: str | list[str] | None = None,
        **kwargs,
    ) -> None:
        """Write the model card to ``self.args.output_dir`` (world process zero only).

        ``language``, ``license`` and ``tags`` are recorded on the model card data when truthy. The remaining
        keyword arguments are accepted for compatibility with ``Trainer.create_model_card`` but are not used here.
        """
        if not self.is_world_process_zero():
            return

        # Apply the provided card fields in a fixed order, skipping empty ones
        card_updates = (
            (language, "set_language"),
            (license, "set_license"),
            (tags, "add_tags"),
        )
        for value, setter_name in card_updates:
            if value:
                getattr(self.model.model_card_data, setter_name)(value)

        self.model._create_model_card(self.args.output_dir, model_name=model_name)

    def get_optimizer_cls_and_kwargs(
        self, args: BaseTrainingArguments, model: BaseModel | None = None
    ) -> tuple[Any, Any]:
        """
        We have to override the optimizer_grouped_parameters because the Trainer superclass bases it on the `model`
        itself, but the BaseModel losses can have weights that should be updated as well, e.g.
        SoftmaxLoss (see #2872).

        Additionally, every regex pattern in ``args.learning_rate_mapping`` moves the matching parameters into
        dedicated parameter groups with their own learning rate. The decay/no-decay split is preserved. When
        several patterns match the same parameter, the last pattern wins.

        This method requires `transformers` >= 4.43.0.

        Args:
            args: The training arguments; ``weight_decay`` and ``learning_rate_mapping`` are read from them.
            model: Accepted for signature compatibility but not used; parameters are collected from ``self.loss``.

        Returns:
            The optimizer class and its keyword arguments, including the parameter groups.

        Raises:
            ValueError: If a pattern in ``args.learning_rate_mapping`` matches no parameter.
        """
        loss_model = nn.Sequential(OrderedDict(self.loss)) if isinstance(self.loss, dict) else self.loss
        optimizer_cls, optimizer_kwargs = super().get_optimizer_cls_and_kwargs(args, loss_model)

        param_kwarg_keys = {"params", "model", "optimizer_dict"}
        # A set gives O(1) membership checks below
        decay_parameters = set(self.get_decay_parameter_names(loss_model))

        # If the kwargs were not overridden by the super() call, then we should override them here so that the potential
        # weights in the loss(es) can also be updated.
        if not param_kwarg_keys & set(optimizer_kwargs):
            optimizer_kwargs["optimizer_dict"] = [
                {
                    "params": [
                        p for n, p in loss_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in loss_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

        # One of "params", "model", or "optimizer_dict" is now in the optimizer_kwargs; the loop below never adds
        # new keys, so the key holding the parameter groups can be determined once.
        present_keys = param_kwarg_keys & set(optimizer_kwargs)
        optimizer_param_key = present_keys.pop() if present_keys else "optimizer_dict"

        for parameter_pattern, learning_rate in (args.learning_rate_mapping or {}).items():
            # Get parameters that match the pattern
            matching_params = {n: p for n, p in loss_model.named_parameters() if re.search(parameter_pattern, n)}
            if not matching_params:
                raise ValueError(
                    f"No parameters found matching the pattern '{parameter_pattern}' in the model. "
                    "Please check the pattern and ensure it matches some of the model's parameters."
                )

            # Remove matching parameters from all existing groups (including groups added for earlier patterns),
            # comparing by identity, so that each parameter ends up in exactly one group.
            matching_ids = {id(p) for p in matching_params.values()}
            for group in optimizer_kwargs[optimizer_param_key]:
                if "params" in group:
                    group["params"] = [p for p in group["params"] if id(p) not in matching_ids]

            # Add new optimizer groups with the matching parameters, keeping the decay/no-decay split
            params_with_decay = [p for n, p in matching_params.items() if n in decay_parameters]
            params_without_decay = [p for n, p in matching_params.items() if n not in decay_parameters]

            if params_with_decay:
                optimizer_kwargs[optimizer_param_key].append(
                    {
                        "params": params_with_decay,
                        "lr": learning_rate,
                        "weight_decay": args.weight_decay,
                    }
                )

            if params_without_decay:
                optimizer_kwargs[optimizer_param_key].append(
                    {
                        "params": params_without_decay,
                        "lr": learning_rate,
                        "weight_decay": 0.0,
                    }
                )

        return optimizer_cls, optimizer_kwargs
