# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Handler for the /v1/chat/completions endpoint.

Supports streaming (SSE via DirectStreamer) and non-streaming (JSON) responses.
"""

import asyncio
import json
import time
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING

from ...utils import logging
from ...utils.import_utils import is_serve_available


if is_serve_available():
    from fastapi.responses import JSONResponse, StreamingResponse
    from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall
    from openai.types.chat.chat_completion import Choice
    from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
    from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
    from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
    from openai.types.completion_usage import CompletionUsage


from .utils import (
    BaseGenerateManager,
    BaseHandler,
    Modality,
    _StreamError,
    get_tool_call_config,
    parse_tool_calls,
)


if TYPE_CHECKING:
    from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin


class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
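    """The OpenAI streaming request schema extended with transformers-specific fields."""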
    generation_config: str
    seed: int


# Fields accepted by the OpenAI schema but not yet supported.
# Receiving these raises an error to avoid silent misbehaviour.
# NOTE: "stop" is NOT in this set — we map it to stop_strings.
UNUSED_CHAT_COMPLETION_FIELDS = {
    "audio",
    "function_call",
    "functions",
    "logprobs",
    "max_completion_tokens",
    "metadata",
    "modalities",
    "n",
    "parallel_tool_calls",
    "prediction",
    "presence_penalty",
    "reasoning_effort",
    "response_format",
    "service_tier",
    "store",
    "stream_options",
    "tool_choice",
    "top_logprobs",
    "user",
    "web_search_options",
}


logger = logging.get_logger(__name__)


class ChatCompletionHandler(BaseHandler):
    """Handler for the `/v1/chat/completions` endpoint.

    Supports both streaming (SSE) and non-streaming (JSON) responses.
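
    Example (a minimal client-side sketch; assumes the server was started locally, e.g. with
    `transformers serve`, and is reachable at `http://localhost:8000/v1`; adjust host and port
    to your setup):

    ```python
    from openai import OpenAI

    # Local servers typically ignore the API key, but the client requires a non-empty value.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
    completion = client.chat.completions.create(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id, use one your server can load
        messages=[{"role": "user", "content": "Hello!"}],
        stream=False,
    )
    print(completion.choices[0].message.content)
    ```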
    """

    _valid_params_class = TransformersCompletionCreateParamsStreaming
    _unused_fields = UNUSED_CHAT_COMPLETION_FIELDS

    async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse:
        """Validate the request, load the model, and dispatch to streaming or non-streaming.

        Args:
            body (`dict`): The raw JSON request body (OpenAI chat completion format).
            request_id (`str`): Unique request identifier (from header or auto-generated).

        Returns:
            `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``.
        """
        self._validate_request(body)

        model_id, model, processor = self._resolve_model(body)
        modality = self.model_manager.get_model_modality(model, processor=processor)
        use_cb = self.generation_state.use_continuous_batching(model, modality)
        logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}")
        gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb)
        processor_inputs = self.get_processor_inputs_from_messages(body["messages"], modality)

        has_video = any(
            c.get("type") == "video"
            for msg in processor_inputs
            for c in (msg.get("content") if isinstance(msg.get("content"), list) else [])
        )
        # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise
        chat_template_kwargs = {}
        if has_video:
            chat_template_kwargs["num_frames"] = 32
        inputs = processor.apply_chat_template(
            processor_inputs,
            add_generation_prompt=True,
            tools=body.get("tools"),
            return_tensors=None if use_cb else "pt",
            return_dict=True,
            tokenize=True,
            load_audio_from_video=modality == Modality.MULTIMODAL and has_video,
            **chat_template_kwargs,
        )
        if not use_cb:
            inputs = inputs.to(model.device)  # type: ignore[union-attr]

        gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb)
        # TODO: remove when CB supports per-request generation config
        if use_cb:
            gen_manager.init_cb(model, gen_config)

        tool_config = get_tool_call_config(processor, model) if body.get("tools") else None

        streaming = body.get("stream")
        if streaming:
            return self._streaming(
                request_id,
                model,
                processor,
                model_id,
                inputs,
                gen_config,
                gen_manager=gen_manager,
                tool_config=tool_config,
            )
        else:
            return await self._non_streaming(
                request_id,
                model,
                processor,
                model_id,
                inputs,
                gen_config,
                gen_manager=gen_manager,
                tool_config=tool_config,
            )

    # ----- streaming -----

    def _streaming(
        self,
        request_id: str,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        model_id: str,
        inputs: dict,
        gen_config: "GenerationConfig",
        gen_manager: BaseGenerateManager,
        tool_config: dict | None = None,
    ) -> StreamingResponse:
        """Stream tokens as SSE via DirectStreamer."""
        queue, streamer = gen_manager.generate_streaming(
            model,
            processor,
            inputs,
            gen_config,
            request_id=request_id,
            tool_config=tool_config,
        )
        input_ids = inputs["input_ids"]
        # CB returns plain lists, regular path returns tensors
        input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1]

        async def sse_gen() -> AsyncGenerator[str, None]:
            try:
                yield self._build_chunk_sse(request_id, role="assistant", model=model_id)

                done = False
                while not done:
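                    # Block for the next item, then drain anything already queued so several
                    # tokens can be flushed in a single SSE write.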
                    text = await queue.get()
                    batch = [text]
                    try:
                        while True:
                            batch.append(queue.get_nowait())
                    except asyncio.QueueEmpty:
                        pass

                    sse_parts: list[str] = []
                    for item in batch:
                        if item is None:
                            done = True
                            break
                        if isinstance(item, _StreamError):
                            # json.dumps escapes the error message so the SSE payload stays valid JSON.
                            sse_parts.append(f"data: {json.dumps({'error': item.msg})}\n\n")
                            yield "".join(sse_parts)
                            return

                        sse_parts.append(self._build_chunk_sse(request_id, model=model_id, content=item))

                    if sse_parts:
                        yield "".join(sse_parts)

                # Tool calls are parsed after generation completes (not during streaming),
                # because the full token sequence is needed for reliable parsing.
                has_tool_calls = False
                if tool_config:
                    parsed = parse_tool_calls(processor, streamer.generated_token_ids, tool_config["schema"])
                    if parsed:
                        has_tool_calls = True
                        for i, tc in enumerate(parsed):
                            yield self._build_chunk_sse(
                                request_id,
                                model=model_id,
                                tool_calls=[
                                    ChoiceDeltaToolCall(
                                        index=i,
                                        type="function",
                                        id=f"{request_id}_tool_call_{i}",
                                        function={"name": tc["name"], "arguments": tc["arguments"]},
                                    )
                                ],
                            )

                hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens
                if has_tool_calls:
                    finish_reason = "tool_calls"
                elif hit_max:
                    finish_reason = "length"
                else:
                    finish_reason = "stop"
                usage = CompletionUsage(
                    prompt_tokens=input_len,
                    completion_tokens=streamer.total_tokens,
                    total_tokens=input_len + streamer.total_tokens,
                )
                yield self._build_chunk_sse(
                    request_id,
                    finish_reason=finish_reason,
                    model=model_id,
                    usage=usage,
                )
            except (GeneratorExit, asyncio.CancelledError):
                # Client disconnected — abort generation to free GPU.
                # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed.
                streamer.cancel()
                raise

        return StreamingResponse(sse_gen(), media_type="text/event-stream")

    # ----- non-streaming -----

    async def _non_streaming(
        self,
        request_id: str,
        model: "PreTrainedModel",
        processor: "ProcessorMixin | PreTrainedTokenizerFast",
        model_id: str,
        inputs: dict,
        gen_config: "GenerationConfig",
        gen_manager: BaseGenerateManager,
        tool_config: dict | None = None,
    ) -> JSONResponse:
        """Run generation and return a JSONResponse."""
        content, input_len, generated_ids = await gen_manager.generate_non_streaming(
            model, processor, inputs, gen_config, request_id=request_id
        )

        hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens
        completion_tokens = len(generated_ids)
        usage = CompletionUsage(
            prompt_tokens=input_len,
            completion_tokens=completion_tokens,
            total_tokens=input_len + completion_tokens,
        )

        tool_calls = None
        if tool_config is not None:
            parsed = parse_tool_calls(processor, generated_ids, tool_config["schema"])
            if parsed:
                tool_calls = [
                    ChatCompletionMessageToolCall(
                        id=f"{request_id}_tool_call_{i}",
                        type="function",
                        function={"name": tc["name"], "arguments": tc["arguments"]},
                    )
                    for i, tc in enumerate(parsed)
                ]

        if tool_calls is not None:
            finish_reason = "tool_calls"
        elif hit_max:
            finish_reason = "length"
        else:
            finish_reason = "stop"

        return JSONResponse(
            self._build_completion(
                request_id,
                content,
                model_id,
                finish_reason=finish_reason,
                usage=usage,
                tool_calls=tool_calls,
            ),
            media_type="application/json",
        )

    # ----- helpers -----

    def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
        """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``,
        ``stop``) on top of the base generation config."""
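        # Illustrative mapping (values are hypothetical):
        #   {"max_tokens": 128, "frequency_penalty": 0.5, "stop": ["###"]}
        #   -> max_new_tokens=128, repetition_penalty=1.5, stop_strings=["###"]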
        generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)

        if body.get("max_tokens") is not None:
            generation_config.max_new_tokens = int(body["max_tokens"])
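        # NOTE: OpenAI's `frequency_penalty` is an additive, count-based penalty with no exact
        # transformers equivalent; it is approximated with the multiplicative `repetition_penalty`.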
        if body.get("frequency_penalty") is not None:
            generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"])
        if body.get("logit_bias") is not None:
            generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()}
        if body.get("stop") is not None:
            generation_config.stop_strings = body["stop"]

        return generation_config

    # ----- response builders -----

    def _build_completion(
        self,
        request_id: str,
        content: str,
        model_id: str,
        finish_reason: str,
        usage: CompletionUsage | None = None,
        tool_calls: list[ChatCompletionMessageToolCall] | None = None,
    ) -> dict:
        """Build a non-streaming ChatCompletion response dict.

        Args:
            request_id (`str`): Unique request identifier.
            content (`str`): The generated text.
            model_id (`str`): Model ID to include in the response.
            finish_reason (`str`): Why generation stopped (``"stop"``, ``"length"``, ``"tool_calls"``).
            usage (`CompletionUsage`, *optional*): Token usage statistics.
            tool_calls (`list[ChatCompletionMessageToolCall]`, *optional*): Parsed tool calls, if any.

        Returns:
            `dict`: Serialized ``ChatCompletion`` ready for JSON response.
        """
        message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls)
        result = ChatCompletion(
            id=request_id,
            created=int(time.time()),
            object="chat.completion",
            model=model_id,
            choices=[
                Choice(
                    index=0,
                    message=message,
                    finish_reason=finish_reason,
                )
            ],
            usage=usage,
        )
        return result.model_dump(exclude_none=True)

    def _build_chunk_sse(
        self,
        request_id: str = "",
        content: str | None = None,
        model: str | None = None,
        role: str | None = None,
        finish_reason: str | None = None,
        tool_calls: list[ChoiceDeltaToolCall] | None = None,
        usage: CompletionUsage | None = None,
    ) -> str:
        """Build a streaming ``ChatCompletionChunk`` and format it as an SSE ``data:`` line.

        Args:
            request_id (`str`): Unique request identifier.
            content (`str`, *optional*): Text content delta.
            model (`str`, *optional*): Model ID.
            role (`str`, *optional*): Role (only sent in the first chunk).
            finish_reason (`str`, *optional*): Set on the final chunk.
            tool_calls (`list[ChoiceDeltaToolCall]`, *optional*): Tool call deltas.
            usage (`CompletionUsage`, *optional*): Token usage (sent with the final chunk).

        Returns:
            `str`: A formatted SSE event string.
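
        Example output (illustrative; the exact serialization is produced by ``chunk_to_sse``,
        assumed here to be a standard SSE ``data:`` line):

        ```
        data: {"id": "req-123", "object": "chat.completion.chunk", "choices": [{"index": 0, "delta": {"content": "Hello"}}], ...}
        ```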
        """
        chunk = ChatCompletionChunk(
            id=request_id,
            created=int(time.time()),
            model=model,
            choices=[
                ChoiceChunk(
                    delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls),
                    index=0,
                    finish_reason=finish_reason,
                )
            ],
            usage=usage,
            system_fingerprint="",
            object="chat.completion.chunk",
        )
        return self.chunk_to_sse(chunk)
