# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import time
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload

from typing_extensions import TypeVar

from vllm.inputs import (
    EmbedsInputs,
    EmbedsPrompt,
    EncoderDecoderInputs,
    ProcessorInputs,
    SingletonInputs,
    TextPrompt,
    TokenInputs,
    TokensPrompt,
)
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.counter import AtomicCounter
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats

from .embed_utils import safe_load_prompt_embeds
from .inputs import (
    DictPrompt,
    EncoderDecoderDictPrompt,
    EncoderDecoderTokPrompt,
    SingletonDictPrompt,
    SingletonTokPrompt,
    TokPrompt,
)
from .inputs.preprocess import extract_target_prompt
from .params import ChatParams, TokenizeParams

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.entrypoints.chat_utils import (
        ChatCompletionMessageParam,
        ConversationMessage,
    )
    from vllm.multimodal.cache import BaseMultiModalProcessorCache
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
        MultiModalInputs,
        MultiModalUUIDDict,
    )
    from vllm.multimodal.parse import MultiModalDataItems, MultiModalUUIDItems
    from vllm.multimodal.processing import BaseMultiModalProcessor

logger = init_logger(__name__)


_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)


class BaseRenderer(ABC, Generic[_T]):
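    """
    Convert raw API inputs (text, token, or embedding prompts, and chat
    conversations) into engine-ready `ProcessorInputs`.

    The rendering pipeline is:

    1. Convert raw inputs to prompts (`render_prompt(s)` / `render_messages`)
    2. Tokenize the prompts if necessary (`tokenize_prompt(s)`)
    3. Add extra keys to the prompts (`_apply_prompt_extras`)
    4. Convert the prompts to engine inputs (`process_for_engine`)

    Rough usage sketch (`MyRenderer` stands in for a concrete subclass):

        renderer = MyRenderer.from_config(vllm_config, tokenizer_kwargs={})
        engine_prompts = renderer.render_cmpl([TextPrompt(prompt="Hello")])
    """
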
    @classmethod
    @abstractmethod
    def from_config(
        cls,
        config: "VllmConfig",
        tokenizer_kwargs: dict[str, Any],
    ) -> "BaseRenderer":
        raise NotImplementedError

    def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
        super().__init__()

        self.config = config
        self.model_config = config.model_config
        self.api_process_rank = config.parallel_config._api_process_rank

        self.tokenizer = tokenizer

        # Lazily initialized; the offline LLM doesn't use async tokenization
        self._async_tokenizer: AsyncMicrobatchTokenizer | None = None

        self.mm_processor: BaseMultiModalProcessor | None = None
        self._mm_cache_stats: MultiModalCacheStats | None = None
        if config.model_config.is_multimodal_model:
            from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
            from vllm.multimodal.registry import MultiModalTimingRegistry

            mm_processor_cache = mm_registry.processor_cache_from_config(config)

            # Deep-copy the tokenizer so the multimodal processor gets its
            # own Rust tokenizer backend.  Without this, concurrent access
            # from AsyncMicrobatchTokenizer and call_hf_processor causes
            # "RuntimeError: Already borrowed" from the Rust RefCell.
            # See: https://github.com/huggingface/tokenizers/issues/537
            mm_tokenizer = copy.deepcopy(tokenizer)

            with set_default_torch_num_threads():
                self.mm_processor = mm_registry.create_processor(
                    config.model_config,
                    tokenizer=mm_tokenizer,
                    cache=mm_processor_cache,
                )

            if mm_processor_cache:
                self._mm_cache_stats = MultiModalCacheStats()

            # Used to generate internal request IDs for MM processing.
            # These have no relation to the engine core request IDs.
            self._mm_req_counter = AtomicCounter()
            self._mm_timing_registry = MultiModalTimingRegistry(
                config.observability_config
            )

    def get_tokenizer(self) -> _T:
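        """Return the tokenizer, raising if `skip_tokenizer_init=True` was set."""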
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

    def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
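        """Lazily construct and return the async micro-batching tokenizer."""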
        if self._async_tokenizer is None:
            self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())

        return self._async_tokenizer

    def get_mm_processor(self) -> "BaseMultiModalProcessor":
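        """Return the multi-modal processor, raising for text-only models."""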
        if self.mm_processor is None:
            raise ValueError("Multi-modal processor not available for text-only models")

        return self.mm_processor

    @property
    def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
        if self.mm_processor is None:
            return None

        return self.mm_processor.cache

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
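        """
        Return the multi-modal cache stats accumulated so far and start a new
        accumulation window, or `None` if stats are not being tracked.
        """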
        mm_cache_stats = self._mm_cache_stats
        if mm_cache_stats is None:
            return None

        self._mm_cache_stats = MultiModalCacheStats()

        return mm_cache_stats

    def update_mm_cache_stats(self) -> None:
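        """Fold the latest cache stats delta into the accumulated stats."""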
        mm_processor_cache = self.mm_processor_cache
        mm_cache_stats = self._mm_cache_stats

        if mm_processor_cache and mm_cache_stats:
            delta = mm_processor_cache.make_stats(delta=True)
            mm_cache_stats.record(delta.total, delta.hits)

    def clear_mm_cache(self) -> None:
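        """Clear the processor cache and mark the cache stats as reset."""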
        mm_processor_cache = self.mm_processor_cache
        if mm_processor_cache is not None:
            mm_processor_cache.clear_cache()

        if self._mm_cache_stats is not None:
            self._mm_cache_stats.reset = True

    def warmup(self, chat_params: ChatParams) -> None:
        """
        Warm up this renderer to avoid first-request latency.

        For chat requests:
        - Jinja2 template compilation

        For multi-modal requests:
        - Importing libraries such as librosa triggers JIT compilation.
        """
        from vllm.entrypoints.chat_utils import ChatTemplateResolutionError

        try:
            logger.debug("Warming up chat template processing...")
            start_time = time.perf_counter()

            self.render_chat([[{"role": "user", "content": "warmup"}]], chat_params)

            elapsed = time.perf_counter() - start_time
            logger.debug("Chat template warmup completed in %.3fs", elapsed)
        except ChatTemplateResolutionError:
            logger.debug("This model does not support chat template.")
        except Exception:
            logger.warning("Chat template warmup failed", exc_info=True)

        if self.mm_processor:
            from vllm.multimodal.processing import TimingContext

            model_config = self.model_config
            mm_config = model_config.get_multimodal_config()
            processor = self.mm_processor
            mm_limits = processor.info.allowed_mm_limits

            try:
                logger.debug("Warming up multi-modal processing...")
                start_time = time.perf_counter()

                processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
                    seq_len=model_config.max_model_len,
                    mm_counts=dict.fromkeys(mm_limits, 1),
                    mm_options=mm_config.limit_per_prompt,
                )
                _ = processor.apply(
                    processor_inputs, timing_ctx=TimingContext(enabled=False)
                )

                elapsed = time.perf_counter() - start_time
                logger.info("Multi-modal warmup completed in %.3fs", elapsed)
            except Exception:
                logger.warning("Multi-modal warmup failed")
            finally:
                self.clear_mm_cache()

    def shutdown(self) -> None:
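        """Close the multi-modal processor cache, if any."""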
        mm_processor_cache = self.mm_processor_cache
        if mm_processor_cache is not None:
            mm_processor_cache.close()

    def get_bos_token_id(self) -> int | None:
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for BOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.bos_token_id

    def get_eos_token_id(self) -> int | None:
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for EOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.eos_token_id

    def get_dec_start_token_id(self) -> int:
        """
        Obtain the decoder start token id employed by an encoder/decoder model,
        raising an error if it is not available.
        """
        dec_start_token_id = getattr(
            self.model_config.hf_config, "decoder_start_token_id", None
        )

        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back to <BOS> because the decoder start token id "
                "is not available."
            )
            dec_start_token_id = self.get_bos_token_id()

        if dec_start_token_id is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")

        return dec_start_token_id

    @cached_property
    def default_cmpl_tok_params(self) -> TokenizeParams:
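        """
        Default tokenization parameters for completion-style prompts, where
        special tokens are added during tokenization.
        """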
        mm_processor = self.mm_processor
        if mm_processor is not None:
            return mm_processor.info.default_tok_params

        model_config = self.model_config
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=True,
        )

    @cached_property
    def default_chat_tok_params(self) -> TokenizeParams:
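        """
        Default tokenization parameters for chat prompts, where special tokens
        are expected to come from the chat template instead of the tokenizer.
        """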
        mm_processor = self.mm_processor
        if mm_processor is not None:
            return mm_processor.info.default_tok_params

        model_config = self.model_config
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=False,
        )

    # Step 1: Convert raw inputs to prompts
    def render_prompt(
        self,
        prompt: DictPrompt | bytes,
    ) -> DictPrompt:
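        """
        Convert a raw prompt into a dict prompt, decoding raw `bytes` as
        serialized prompt embeddings.
        """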
        if isinstance(prompt, bytes):
            embeds = safe_load_prompt_embeds(self.model_config, prompt)
            prompt = EmbedsPrompt(prompt_embeds=embeds)

        return prompt

    def render_prompts(
        self,
        prompts: Sequence[DictPrompt | bytes],
    ) -> list[DictPrompt]:
        if len(prompts) == 0:
            raise ValueError("You must pass at least one prompt")

        return [self.render_prompt(prompt) for prompt in prompts]

    async def render_prompts_async(
        self,
        prompts: Sequence[DictPrompt | bytes],
    ) -> list[DictPrompt]:
        return self.render_prompts(prompts)

    @abstractmethod
    def render_messages(
        self,
        messages: list["ChatCompletionMessageParam"],
        params: ChatParams,
    ) -> tuple[list["ConversationMessage"], DictPrompt]:
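        """
        Apply the chat template to `messages`, returning the parsed
        conversation together with the resulting prompt.
        """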
        raise NotImplementedError

    async def render_messages_async(
        self,
        messages: list["ChatCompletionMessageParam"],
        params: ChatParams,
    ) -> tuple[list["ConversationMessage"], DictPrompt]:
        return self.render_messages(messages, params)

    # Step 2: Tokenize prompts if necessary
    def _tokenize_prompt(
        self,
        prompt: TextPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt:
        tokenizer = self.get_tokenizer()
        prompt_token_ids = tokenizer.encode(
            prompt["prompt"],
            **params.get_encode_kwargs(),
        )

        return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)

    async def _tokenize_prompt_async(
        self,
        prompt: TextPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt:
        tokenizer = self.get_async_tokenizer()
        prompt_token_ids = await tokenizer.encode(
            prompt["prompt"],
            **params.get_encode_kwargs(),
        )

        return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)

    def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
        tokenizer = self.get_tokenizer()
        prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"])

        return prompt

    async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
        tokenizer = self.get_async_tokenizer()
        prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"])

        return prompt

    @overload
    def _tokenize_singleton_prompt(
        self,
        prompt: TextPrompt | TokensPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt: ...

    @overload
    def _tokenize_singleton_prompt(  # type: ignore[misc]
        self,
        prompt: EmbedsPrompt,
        params: TokenizeParams,
    ) -> EmbedsPrompt: ...

    def _tokenize_singleton_prompt(
        self,
        prompt: SingletonDictPrompt,
        params: TokenizeParams,
    ) -> SingletonTokPrompt:
        if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
            prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
            prompt = self._tokenize_prompt(prompt, params)

        if params.needs_detokenization and "prompt" not in prompt:
            if "prompt_token_ids" not in prompt:
                raise RuntimeError("Cannot run detokenization on embeddings")

            prompt = self._detokenize_prompt(prompt)  # type: ignore[arg-type]

        return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]

    @overload
    async def _tokenize_singleton_prompt_async(
        self,
        prompt: TextPrompt | TokensPrompt,
        params: TokenizeParams,
    ) -> TokensPrompt: ...

    @overload
    async def _tokenize_singleton_prompt_async(  # type: ignore[misc]
        self,
        prompt: EmbedsPrompt,
        params: TokenizeParams,
    ) -> EmbedsPrompt: ...

    async def _tokenize_singleton_prompt_async(
        self,
        prompt: SingletonDictPrompt,
        params: TokenizeParams,
    ) -> SingletonTokPrompt:
        if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
            prompt = params.apply_pre_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
            prompt = await self._tokenize_prompt_async(prompt, params)

        if params.needs_detokenization and "prompt" not in prompt:
            if "prompt_token_ids" not in prompt:
                raise RuntimeError("Cannot run detokenization on embeddings")

            prompt = await self._detokenize_prompt_async(prompt)  # type: ignore[arg-type]

        return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]

    def _tokenize_enc_dec_prompt(
        self,
        prompt: EncoderDecoderDictPrompt,
        params: TokenizeParams,
    ) -> EncoderDecoderTokPrompt:
        enc_prompt, dec_prompt = (
            self._tokenize_singleton_prompt(prompt["encoder_prompt"], params),
            (
                None
                if prompt["decoder_prompt"] is None
                else self._tokenize_singleton_prompt(prompt["decoder_prompt"], params)
            ),
        )

        return EncoderDecoderTokPrompt(
            encoder_prompt=enc_prompt,
            decoder_prompt=dec_prompt,
        )

    async def _tokenize_enc_dec_prompt_async(
        self,
        prompt: EncoderDecoderDictPrompt,
        params: TokenizeParams,
    ) -> EncoderDecoderTokPrompt:
        enc_prompt, dec_prompt = await asyncio.gather(
            self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params),
            (
                asyncio.sleep(0)
                if prompt["decoder_prompt"] is None
                else self._tokenize_singleton_prompt_async(
                    prompt["decoder_prompt"], params
                )
            ),
        )

        return EncoderDecoderTokPrompt(
            encoder_prompt=enc_prompt,
            decoder_prompt=dec_prompt,
        )

    def tokenize_prompt(
        self,
        prompt: DictPrompt,
        params: TokenizeParams,
    ) -> TokPrompt:
        if "encoder_prompt" in prompt:
            return self._tokenize_enc_dec_prompt(prompt, params)  # type: ignore[arg-type]

        return self._tokenize_singleton_prompt(prompt, params)

    def tokenize_prompts(
        self,
        prompts: Sequence[DictPrompt],
        params: TokenizeParams,
    ) -> list[TokPrompt]:
        return [self.tokenize_prompt(prompt, params) for prompt in prompts]

    async def tokenize_prompt_async(
        self,
        prompt: DictPrompt,
        params: TokenizeParams,
    ) -> TokPrompt:
        if "encoder_prompt" in prompt:
            return await self._tokenize_enc_dec_prompt_async(prompt, params)  # type: ignore[arg-type]

        return await self._tokenize_singleton_prompt_async(prompt, params)

    async def tokenize_prompts_async(
        self,
        prompts: Sequence[DictPrompt],
        params: TokenizeParams,
    ) -> list[TokPrompt]:
        return await asyncio.gather(
            *(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
        )

    # Step 3: Add extra keys to the prompts
    def _apply_prompt_extras(
        self,
        prompts: Sequence[TokPrompt],
        prompt_extras: dict[str, Any] | None,
    ):
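        """Merge `prompt_extras` into the target prompt of each request."""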
        if not prompt_extras:
            return

        for prompt in prompts:
            target_prompt = extract_target_prompt(self.model_config, prompt)
            target_prompt.update(prompt_extras)  # type: ignore[arg-type]

    # Step 4: Convert to engine inputs
    def _validate_mm_uuids(
        self,
        mm_data: "MultiModalDataDict",
        mm_data_items: "MultiModalDataItems",
        mm_uuid_items: "MultiModalUUIDItems",
    ) -> None:
        # NOTE: Keys corresponding to `None` in `mm_data` don't appear in
        # `mm_data_items`
        modalities = mm_data.keys() | mm_uuid_items.keys()

        for modality in modalities:
            data_items = mm_data_items.get(modality)
            uuid_items = mm_uuid_items.get(modality)

            if data_items is None:
                if uuid_items is None:
                    raise ValueError(
                        f"multi_modal_data[{modality!r}] is empty, so "
                        f"multi_modal_uuids[{modality!r}] must be provided."
                    )

            elif uuid_items is not None:
                if len(data_items) != len(uuid_items):
                    raise ValueError(
                        f"If given, multi_modal_uuids[{modality!r}] must have "
                        f"same length as multi_modal_data[{modality!r}], but "
                        f"got {len(uuid_items)} vs {len(data_items)}."
                    )

                for i, item in enumerate(data_items):
                    if item is None and uuid_items[i] is None:
                        raise ValueError(
                            f"multi_modal_data[{modality!r}][{i}] is empty, so "
                            f"multi_modal_uuids[{modality!r}][{i}] must be "
                            f"provided."
                        )

    def _process_mm_uuids(
        self,
        mm_data: "MultiModalDataDict",
        mm_data_items: "MultiModalDataItems",
        mm_uuid_items: "MultiModalUUIDItems",
        mm_req_id: str,
    ):
        model_config = self.model_config

        # NOTE: When users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
        if (
            model_config.multimodal_config
            and model_config.multimodal_config.mm_processor_cache_gb == 0
            and not self.config.cache_config.enable_prefix_caching
        ):
            mm_uuid_items = {
                modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
                for modality, data_count in mm_data_items.get_all_counts().items()
            }

        self._validate_mm_uuids(mm_data, mm_data_items, mm_uuid_items)

        return mm_uuid_items

    # TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
    def _process_multimodal(
        self,
        prompt: list[int] | str,
        mm_data: "MultiModalDataDict",
        mm_uuids: "MultiModalUUIDDict | None",
        mm_processor_kwargs: Mapping[str, object] | None,
        tokenization_kwargs: dict[str, Any] | None,
    ) -> "MultiModalInputs":
        from vllm.multimodal.parse import parse_mm_uuids
        from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs

        mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}"

        mm_processor = self.get_mm_processor()

        mm_data_items = mm_processor.info.parse_mm_data(mm_data)
        mm_uuid_items = parse_mm_uuids(mm_uuids)

        mm_uuid_items = self._process_mm_uuids(
            mm_data, mm_data_items, mm_uuid_items, mm_req_id
        )

        mm_processor_inputs = MMProcessorInputs(
            prompt,
            mm_data_items,
            mm_uuid_items,
            hf_processor_mm_kwargs=mm_processor_kwargs or {},
            tokenization_kwargs=tokenization_kwargs or {},
        )
        mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)

        with set_default_torch_num_threads():
            mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)

        self.update_mm_cache_stats()

        return mm_inputs

    def _process_tokens(
        self,
        prompt: TokensPrompt,
    ) -> "TokenInputs | MultiModalInputs":
        prompt_token_ids = prompt["prompt_token_ids"]

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := prompt.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_token_ids,
                multi_modal_data,
                mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
                tokenization_kwargs=None,  # Tokenization already done in Step 2
                mm_uuids=prompt.get("multi_modal_uuids"),
            )
        else:
            inputs = token_inputs(prompt_token_ids)

        if prompt_text := prompt.get("prompt"):
            inputs["prompt"] = prompt_text
        if cache_salt := prompt.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _process_embeds(
        self,
        prompt: EmbedsPrompt,
    ) -> EmbedsInputs:
        if not self.model_config.enable_prompt_embeds:
            raise ValueError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
            )

        prompt_embeds = prompt["prompt_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously infer the intent and squeeze out the
        # batch dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Moving them to CPU here ensures that there is
        # no hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()

        return embeds_inputs(
            prompt_embeds=prompt_embeds,
            cache_salt=prompt.get("cache_salt"),
        )

    def _process_singleton(
        self,
        prompt: SingletonTokPrompt,
    ) -> SingletonInputs:
        if "prompt_embeds" in prompt:
            return self._process_embeds(prompt)  # type: ignore[arg-type]

        return self._process_tokens(prompt)  # type: ignore[arg-type]

    def _process_enc_dec(
        self,
        prompt: EncoderDecoderTokPrompt,
    ) -> EncoderDecoderInputs:
        enc_prompt = prompt["encoder_prompt"]
        dec_prompt = prompt["decoder_prompt"]

        return build_enc_dec_inputs(
            encoder_inputs=self._process_singleton(enc_prompt),
            decoder_inputs=(
                None if dec_prompt is None else self._process_singleton(dec_prompt)
            ),
            decoder_start_token_id=self.get_dec_start_token_id(),
        )

    def process_for_engine(
        self, prompt: TokPrompt, arrival_time: float
    ) -> ProcessorInputs:
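        """
        Convert a tokenized prompt into engine inputs and record its
        arrival time.
        """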
        engine_prompt: ProcessorInputs
        if "encoder_prompt" in prompt:
            engine_prompt = self._process_enc_dec(prompt)  # type: ignore[arg-type]
        else:
            engine_prompt = self._process_singleton(prompt)

        engine_prompt["arrival_time"] = arrival_time

        return engine_prompt

    # Top-level methods
    def render_cmpl(
        self,
        prompts: Sequence[DictPrompt | bytes],
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
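        """
        Render completion-style prompts end-to-end: convert raw inputs,
        tokenize, apply `prompt_extras`, and produce engine inputs.
        """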
        arrival_time = time.time()

        if tok_params is None:
            tok_params = self.default_cmpl_tok_params

        dict_prompts = self.render_prompts(prompts)
        tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)

        self._apply_prompt_extras(tok_prompts, prompt_extras)

        return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]

    async def render_cmpl_async(
        self,
        prompts: Sequence[DictPrompt | bytes],
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        arrival_time = time.time()

        if tok_params is None:
            tok_params = self.default_cmpl_tok_params

        dict_prompts = await self.render_prompts_async(prompts)
        tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)

        self._apply_prompt_extras(tok_prompts, prompt_extras)

        return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]

    def render_chat(
        self,
        conversations: Sequence[list["ChatCompletionMessageParam"]],
        chat_params: ChatParams,
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
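        """
        Render chat conversations end-to-end: apply the chat template,
        tokenize, apply `prompt_extras`, and return the parsed conversations
        together with the engine inputs.
        """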
        arrival_time = time.time()

        if tok_params is None:
            tok_params = self.default_chat_tok_params

        rendered = [
            self.render_messages(conversation, chat_params)
            for conversation in conversations
        ]

        out_conversations = list[list["ConversationMessage"]]()
        dict_prompts = list[DictPrompt]()
        for conv, prompt in rendered:
            out_conversations.append(conv)
            dict_prompts.append(prompt)

        tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)

        self._apply_prompt_extras(tok_prompts, prompt_extras)

        eng_prompts = [
            self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
        ]

        return out_conversations, eng_prompts

    async def render_chat_async(
        self,
        conversations: Sequence[list["ChatCompletionMessageParam"]],
        chat_params: ChatParams,
        tok_params: TokenizeParams | None = None,
        *,
        prompt_extras: dict[str, Any] | None = None,
    ):
        arrival_time = time.time()

        if tok_params is None:
            tok_params = self.default_chat_tok_params

        rendered = [
            self.render_messages_async(conversation, chat_params)
            for conversation in conversations
        ]

        out_conversations = list[list["ConversationMessage"]]()
        dict_prompts = list[DictPrompt]()
        for conv, prompt in await asyncio.gather(*rendered):
            out_conversations.append(conv)
            dict_prompts.append(prompt)

        tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)

        self._apply_prompt_extras(tok_prompts, prompt_extras)

        eng_prompts = [
            self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
        ]

        return out_conversations, eng_prompts
