from collections import OrderedDict
from contextvars import Token
from dataclasses import dataclass, field
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Type,
    Union,
    cast,
)
from uuid import UUID

import pydantic
from opentelemetry import context, trace
from opentelemetry.util._decorator import _AgnosticContextManager

from langfuse import propagate_attributes
from langfuse._client.attributes import LangfuseOtelSpanAttributes
from langfuse._client.client import Langfuse
from langfuse._client.get_client import get_client
from langfuse._client.propagation import _detach_context_token_safely
from langfuse._client.span import (
    LangfuseAgent,
    LangfuseChain,
    LangfuseGeneration,
    LangfuseRetriever,
    LangfuseSpan,
    LangfuseTool,
)
from langfuse._utils import _get_timestamp
from langfuse.langchain.utils import _extract_model_name
from langfuse.logger import langfuse_logger
from langfuse.types import TraceContext

try:
    import langchain

    if langchain.__version__.startswith("1"):
        # Langchain v1
        from langchain_core.agents import AgentAction, AgentFinish
        from langchain_core.callbacks import (
            BaseCallbackHandler as LangchainBaseCallbackHandler,
        )
        from langchain_core.documents import Document
        from langchain_core.messages import (
            AIMessage,
            BaseMessage,
            ChatMessage,
            FunctionMessage,
            HumanMessage,
            SystemMessage,
            ToolMessage,
        )
        from langchain_core.outputs import ChatGeneration, LLMResult

    else:
        # Langchain v0
        from langchain.callbacks.base import (  # type: ignore
            BaseCallbackHandler as LangchainBaseCallbackHandler,
        )
        from langchain.schema.agent import AgentAction, AgentFinish  # type: ignore
        from langchain.schema.document import Document  # type: ignore
        from langchain_core.messages import (
            AIMessage,
            BaseMessage,
            ChatMessage,
            FunctionMessage,
            HumanMessage,
            SystemMessage,
            ToolMessage,
        )
        from langchain_core.outputs import (
            ChatGeneration,
            LLMResult,
        )

except ImportError:
    raise ModuleNotFoundError(
        "Please install langchain to use the Langfuse langchain integration: 'pip install langchain'"
    )

# Tag LangSmith uses to mark hidden runs; observations carrying it are recorded
# at DEBUG level (see on_chain_start / on_tool_start).
LANGSMITH_TAG_HIDDEN: str = "langsmith:hidden"
# Exception types that signal expected control flow rather than failure. Filled
# below with LangGraph's GraphBubbleUp when langgraph is installed.
CONTROL_FLOW_EXCEPTION_TYPES: Set[Type[BaseException]] = set()
# LangGraph's Command type, used to detect explicit resumes; stays None when
# langgraph is not installed.
LANGGRAPH_COMMAND_TYPE: Optional[Type[Any]] = None
# Maximum number of pending resume trace contexts each handler keeps; the oldest
# entry is evicted once this limit is exceeded.
MAX_PENDING_RESUME_TRACE_CONTEXTS = 1024

try:
    from langgraph.errors import GraphBubbleUp

    CONTROL_FLOW_EXCEPTION_TYPES.add(GraphBubbleUp)
except ImportError:
    # langgraph is optional; without it no exception is treated as control flow.
    pass

try:
    from langgraph.types import Command as LangGraphCommand

    LANGGRAPH_COMMAND_TYPE = LangGraphCommand
except ImportError:
    # langgraph is optional; without it LangGraph resume detection is disabled.
    pass


@dataclass
class _RunState:
    """Bookkeeping for one LangChain run tracked by the callback handler."""

    # Run that spawned this one; None when this run is a root run.
    parent_run_id: Optional[UUID]
    # Root run of the run tree this run belongs to.
    root_run_id: UUID


@dataclass
class _RootRunState:
    """Bookkeeping shared by every run that belongs to one root LangChain run."""

    # Ids of all runs in this tree, the root run included.
    run_ids: Set[UUID] = field(default_factory=set)
    # LangGraph thread key used to link a later resume back to this trace.
    resume_key: Optional[str] = None
    # Attribute-propagation context manager entered when the root chain starts.
    propagation_context_manager: Optional[_AgnosticContextManager] = None


class _PendingResumeTraceContextStore:
    """Bounded map from resume keys to trace contexts, ordered by recency.

    Entries are kept from least to most recently stored. Once more than
    ``max_size`` entries are held, the least recently stored one is dropped.
    """

    def __init__(self, max_size: int) -> None:
        self._contexts: OrderedDict[str, TraceContext] = OrderedDict()
        self._max_size = max_size

    def store(self, *, resume_key: str, trace_context: TraceContext) -> None:
        """Save ``trace_context`` under ``resume_key`` as the newest entry."""
        # Removing first and re-inserting places the key at the end, which is
        # the same as assigning and then calling move_to_end.
        self._contexts.pop(resume_key, None)
        self._contexts[resume_key] = trace_context

        over_capacity = len(self._contexts) > self._max_size
        if over_capacity:
            self._contexts.popitem(last=False)

    def take(self, resume_key: str) -> Optional[TraceContext]:
        """Remove and return the context stored for ``resume_key`` (None if absent)."""
        return self._contexts.pop(resume_key, None)

    def __contains__(self, resume_key: str) -> bool:
        return resume_key in self._contexts

    def __len__(self) -> int:
        return len(self._contexts)

    def keys(self) -> List[str]:
        """Return the stored resume keys, oldest first."""
        return list(self._contexts)


class LangchainCallbackHandler(LangchainBaseCallbackHandler):
    def __init__(
        self,
        *,
        public_key: Optional[str] = None,
        trace_context: Optional[TraceContext] = None,
    ) -> None:
        """Create a LangChain callback handler that reports runs to Langfuse.

        Args:
            public_key: Optional Langfuse public key selecting the client. When it
                is omitted, the default client configuration is used.
            trace_context: Optional `TraceContext` dict used to attach the root
                LangChain run to an existing trace (distributed tracing) or to give
                it a custom trace id, e.g. `{"trace_id": "<trace_id>"}`, optionally
                with `{"parent_span_id": "<span_id>"}`, to link the trace to an
                upstream system.

        Example:
            Use a custom trace id without context managers:

            ```python
            from langfuse.langchain import CallbackHandler

            handler = CallbackHandler(trace_context={"trace_id": "my-trace-id"})
            ```
        """
        self._langfuse_client = get_client(public_key=public_key)
        self._trace_context = trace_context

        # Live observations and the OpenTelemetry context tokens attached for
        # them, both keyed by LangChain run id.
        self._runs: Dict[
            UUID,
            Union[
                LangfuseSpan,
                LangfuseGeneration,
                LangfuseAgent,
                LangfuseChain,
                LangfuseTool,
                LangfuseRetriever,
            ],
        ] = {}
        self._context_tokens: Dict[UUID, Token] = {}

        # Langfuse prompts registered per parent run so generations can link them.
        self._prompt_to_parent_run_map: Dict[UUID, Any] = {}
        # Generation runs whose completion start time has already been recorded.
        self._updated_completion_start_time_memo: Set[UUID] = set()

        # Run-tree bookkeeping plus trace contexts waiting for a LangGraph resume.
        self._run_states: Dict[UUID, _RunState] = {}
        self._root_run_states: Dict[UUID, _RootRunState] = {}
        self._pending_resume_trace_contexts = _PendingResumeTraceContextStore(
            MAX_PENDING_RESUME_TRACE_CONTEXTS
        )

        # Last trace id recorded by the handler (updated when a chain starts).
        self.last_trace_id: Optional[str] = None

    def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on new LLM token. Only available when streaming is enabled.

        The first streamed token of a tracked generation stamps its
        ``completion_start_time``. Later tokens for the same run are ignored.
        """
        langfuse_logger.debug(
            f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}"
        )

        if run_id in self._updated_completion_start_time_memo:
            return

        generation = self._runs.get(run_id)
        if not isinstance(generation, LangfuseGeneration):
            return

        generation.update(completion_start_time=_get_timestamp())
        self._updated_completion_start_time_memo.add(run_id)

    def _get_langgraph_resume_key(
        self, metadata: Optional[Dict[str, Any]]
    ) -> Optional[str]:
        """Return the LangGraph ``thread_id`` from ``metadata`` as a string.

        Returns None when metadata is missing or empty, or when it has no
        non-None ``thread_id``.
        """
        if not metadata:
            return None

        thread_id = metadata.get("thread_id")
        return None if thread_id is None else str(thread_id)

    def _track_run(
        self,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record ``run_id`` in the run tree. Does nothing if it is already tracked.

        A run without a parent opens a new root entry whose resume key comes from
        the LangGraph thread id in ``metadata``. A child inherits its parent's
        root. If the parent itself is unknown, the parent id is used as the root id.
        """
        if run_id in self._run_states:
            return

        if parent_run_id is not None:
            parent_state = self._run_states.get(parent_run_id)
            if parent_state is None:
                root_run_id = parent_run_id
            else:
                root_run_id = parent_state.root_run_id
            self._root_run_states.setdefault(
                root_run_id, _RootRunState()
            ).run_ids.add(run_id)
        else:
            root_run_id = run_id
            self._root_run_states[root_run_id] = _RootRunState(
                run_ids={run_id},
                resume_key=self._get_langgraph_resume_key(metadata),
            )

        self._run_states[run_id] = _RunState(
            parent_run_id=parent_run_id, root_run_id=root_run_id
        )

    def _get_run_state(self, run_id: UUID) -> Optional[_RunState]:
        """Return the tracked state for ``run_id``, or None if it is unknown."""
        state = self._run_states.get(run_id, None)
        return state

    def _get_root_run_state(self, run_id: UUID) -> Optional[_RootRunState]:
        """Return the root-run state of the tree containing ``run_id``, if any."""
        run_state = self._get_run_state(run_id)
        if run_state is not None:
            return self._root_run_states.get(run_state.root_run_id)
        return None

    def _pop_root_run_resume_key(self, run_id: UUID) -> Optional[str]:
        """Return the resume key of ``run_id``'s root run and clear it.

        Returns None when the run is not tracked or no resume key is set.
        """
        root_run_state = self._get_root_run_state(run_id)
        if root_run_state is None:
            return None

        resume_key, root_run_state.resume_key = root_run_state.resume_key, None
        return resume_key

    def _get_parent_run_id(self, run_id: UUID) -> Optional[UUID]:
        """Return the recorded parent of ``run_id`` (None for roots or unknown runs)."""
        run_state = self._get_run_state(run_id)
        if run_state is None:
            return None
        return run_state.parent_run_id

    def _is_langgraph_resume(self, inputs: Any) -> bool:
        """Tell whether ``inputs`` is a LangGraph ``Command`` with a non-None resume value."""
        command_type = LANGGRAPH_COMMAND_TYPE
        if command_type is None or not isinstance(inputs, command_type):
            return False
        return getattr(inputs, "resume", None) is not None

    def _store_resume_trace_context(
        self, *, resume_key: str, trace_context: TraceContext
    ) -> None:
        """Keep ``trace_context`` so a later resume on ``resume_key`` can reuse it."""
        pending = self._pending_resume_trace_contexts
        pending.store(resume_key=resume_key, trace_context=trace_context)

    def _take_root_trace_context(
        self, *, inputs: Any, metadata: Optional[Dict[str, Any]]
    ) -> tuple[Optional[str], Optional[TraceContext]]:
        """Choose the trace context for a new root observation.

        Returns a ``(resume_key, trace_context)`` pair:

        - A handler-level ``trace_context``, when configured, always wins and is
          returned without a resume key.
        - Otherwise a pending context is consumed only when there is no valid
          active span, ``inputs`` is an explicit LangGraph resume, and
          ``metadata`` carries a thread id. The returned context may still be
          None if nothing was stored for that key.
        - In every other case ``(None, None)`` is returned.
        """
        if self._trace_context is not None:
            return None, self._trace_context

        # A callback run with an active parent span nests normally. Only a
        # top-level, explicit LangGraph resume may reuse pending linkage.
        has_active_parent = trace.get_current_span().get_span_context().is_valid
        if has_active_parent or not self._is_langgraph_resume(inputs):
            return None, None

        resume_key = self._get_langgraph_resume_key(metadata)
        if resume_key is None:
            return None, None

        return resume_key, self._pending_resume_trace_contexts.take(resume_key)

    def _restore_root_trace_context(
        self, *, resume_key: Optional[str], trace_context: Optional[TraceContext]
    ) -> None:
        """Put a consumed pending trace context back after span creation failed.

        Does nothing when a handler-level trace context is configured, or when
        either ``resume_key`` or ``trace_context`` is missing.
        """
        if (
            self._trace_context is not None
            or resume_key is None
            or trace_context is None
        ):
            return

        # Re-store the linkage so the next retry resumes the interrupted trace.
        self._store_resume_trace_context(
            resume_key=resume_key, trace_context=trace_context
        )

    def _clear_root_run_resume_key(self, run_id: UUID) -> None:
        """Drop the resume key of ``run_id``'s root run without persisting anything.

        A trace context that is already pending stays stored until an explicit
        ``Command(resume=...)`` arrives. A separate root run on the same
        thread_id is not a resume.
        """
        _discarded = self._pop_root_run_resume_key(run_id)

    def _persist_resume_trace_context(self, *, run_id: UUID, observation: Any) -> None:
        """Store where a root run stopped so a later LangGraph resume can continue it.

        Consumes the root run's resume key and records the observation's
        ``trace_id`` and ``id`` as the pending trace context. This is skipped when
        a handler-level trace context is configured or no resume key is set.
        """
        if self._trace_context is not None:
            return

        resume_key = self._pop_root_run_resume_key(run_id)
        if resume_key is not None:
            pending: TraceContext = {
                "trace_id": observation.trace_id,
                "parent_span_id": observation.id,
            }
            self._store_resume_trace_context(
                resume_key=resume_key, trace_context=pending
            )

    def _get_error_level_and_status_message(
        self, error: BaseException
    ) -> tuple[Literal["DEFAULT", "ERROR"], str]:
        """Map a LangChain error to an observation level and status message.

        LangGraph signals expected control flow, such as interrupts and handoffs,
        with ``GraphBubbleUp`` subclasses. These are reported at DEFAULT level,
        with the class name as a fallback message, rather than as errors.
        """
        if isinstance(error, tuple(CONTROL_FLOW_EXCEPTION_TYPES)):
            return "DEFAULT", str(error) or type(error).__name__
        return "ERROR", str(error)

    def _get_observation_type_from_serialized(
        self, serialized: Optional[Dict[str, Any]], callback_type: str, **kwargs: Any
    ) -> Union[
        Literal["tool"],
        Literal["retriever"],
        Literal["generation"],
        Literal["agent"],
        Literal["chain"],
        Literal["span"],
    ]:
        """Map a LangChain callback to a Langfuse observation type.

        Args:
            serialized: LangChain's serialized component dict.
            callback_type: Callback family ("chain", "tool", "retriever" or "llm").
            **kwargs: Extra callback keyword arguments, used to resolve the run name.

        Returns:
            "tool", "retriever" or "generation" for those callback families.
            For chains, "agent" when the serialized class path or the run name
            mentions "agent", otherwise "chain". "span" for anything else.
        """
        if callback_type == "tool":
            return "tool"
        if callback_type == "retriever":
            return "retriever"
        if callback_type == "llm":
            return "generation"
        if callback_type != "chain":
            return "span"

        # A chain counts as an agent if any segment of its class path says so...
        class_path = serialized["id"] if serialized and "id" in serialized else ()
        if any("agent" in part.lower() for part in class_path):
            return "agent"

        # ...or if its resolved run name does.
        run_name = self.get_langchain_run_name(serialized, **kwargs)
        return "agent" if "agent" in run_name.lower() else "chain"

    def get_langchain_run_name(
        self, serialized: Optional[Dict[str, Any]], **kwargs: Any
    ) -> str:
        """Retrieve the name of a serialized LangChain runnable.

        The run name is resolved in this order:
        - The non-None value assigned to the "name" key in `kwargs`.
        - The non-None value assigned to the "name" key in `serialized`.
        - The last entry of the value assigned to the "id" key in `serialized`.
        - "<unknown>".

        Args:
            serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
            **kwargs (Any): Additional keyword arguments, potentially including the 'name' override.

        Returns:
            str: The determined name of the Langchain runnable. Malformed
            `serialized` data never raises; it falls through to "<unknown>".
        """
        override = kwargs.get("name")
        if override is not None:
            return str(override)

        if serialized is None:
            return "<unknown>"

        # A present-but-None name is skipped (like a None override) rather than
        # being rendered as the string "None".
        try:
            serialized_name = serialized["name"]
        except (KeyError, TypeError):
            serialized_name = None
        if serialized_name is not None:
            return str(serialized_name)

        # "id" is the class path. An empty list must not raise IndexError, because
        # that would abort the calling on_*_start handler entirely.
        try:
            return str(serialized["id"][-1])
        except (KeyError, TypeError, IndexError):
            pass

        return "<unknown>"

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever errors.

        Ends the retriever observation with the level and status message mapped
        from ``error``. For a root run, a control-flow error (DEFAULT level)
        persists the trace context for a later LangGraph resume, while a genuine
        error clears the root run's resume key. Root runs are always reset
        afterwards. Failures are logged, never raised.
        """
        is_root = parent_run_id is None
        try:
            self._log_debug_event(
                "on_retriever_error", run_id, parent_run_id, error=error
            )
            observation = self._detach_observation(run_id)
            if observation is None:
                return

            level, status_message = self._get_error_level_and_status_message(error)
            observation.update(
                level=cast(
                    Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
                    level,
                ),
                status_message=status_message,
                input=kwargs.get("inputs"),
                cost_details={"total": 0},
            ).end()

            if not is_root:
                return
            if level == "DEFAULT":
                self._persist_resume_trace_context(
                    run_id=run_id, observation=observation
                )
            else:
                self._clear_root_run_resume_key(run_id)

        except Exception as e:
            langfuse_logger.exception(e)
        finally:
            if is_root:
                self._reset(run_id)

    def _parse_langfuse_trace_attributes(
        self, *, metadata: Optional[Dict[str, Any]], tags: Optional[List[str]]
    ) -> Dict[str, Any]:
        """Extract Langfuse trace attributes from LangChain run metadata and tags.

        Recognized metadata keys:
        - ``langfuse_session_id``, ``langfuse_user_id`` and ``langfuse_trace_name``,
          used only when their values are strings.
        - ``langfuse_tags``, used only when its value is a list.

        Tags from ``langfuse_tags`` and ``tags`` are merged, converted to strings
        and de-duplicated in first-seen order (``langfuse_tags`` first). The
        remaining metadata, with Langfuse keys stripped, is returned under
        ``"metadata"``.

        Args:
            metadata: LangChain run metadata, if any.
            tags: LangChain run tags, if any.

        Returns:
            A dict with any of the keys ``session_id``, ``user_id``,
            ``trace_name``, ``tags`` and ``metadata``.
        """

        def _merge_tags(*tag_lists: Sequence[Any]) -> List[str]:
            # Stringify before de-duplicating so e.g. 1 and "1" collapse into one
            # tag (and unhashable tags cannot raise). dict.fromkeys keeps
            # first-seen order, so the result is deterministic, unlike a set.
            return list(
                dict.fromkeys(str(tag) for tag_list in tag_lists for tag in tag_list)
            )

        attributes: Dict[str, Any] = {}

        if metadata is None:
            if tags is not None:
                attributes["tags"] = _merge_tags(tags)
            return attributes

        for metadata_key, attribute_key in (
            ("langfuse_session_id", "session_id"),
            ("langfuse_user_id", "user_id"),
            ("langfuse_trace_name", "trace_name"),
        ):
            value = metadata.get(metadata_key)
            if isinstance(value, str):
                attributes[attribute_key] = value

        langfuse_tags = metadata.get("langfuse_tags")
        if not isinstance(langfuse_tags, list):
            langfuse_tags = None

        if tags is not None or langfuse_tags is not None:
            attributes["tags"] = _merge_tags(langfuse_tags or [], tags or [])

        attributes["metadata"] = _strip_langfuse_keys_from_dict(metadata, False)

        return attributes

    def _get_langchain_observation_metadata(
        self,
        *,
        parent_run_id: Optional[UUID],
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        keep_langfuse_trace_attributes: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """Build observation metadata from LangChain tags and metadata.

        Child runs get the joined tags/metadata as-is, which may be None. Root
        runs always get a new dict with ``is_langchain_root`` set to True, so the
        joined result is never mutated.
        """
        joined = self.__join_tags_and_metadata(
            tags=tags,
            metadata=metadata,
            keep_langfuse_trace_attributes=keep_langfuse_trace_attributes,
        )

        if parent_run_id is None:
            return {**(joined or {}), "is_langchain_root": True}

        return joined

    def on_chain_start(
        self,
        serialized: Optional[Dict[str, Any]],
        inputs: Any,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Start a Langfuse observation for a LangChain chain run.

        For a root run (``parent_run_id`` is None), this also enters the Langfuse
        attribute-propagation context built from ``metadata`` and ``tags``. When
        the run is an explicit LangGraph resume, it reuses the pending trace
        context of the interrupted run. If the observation cannot be created,
        any consumed pending trace context is restored and the root run's
        bookkeeping is torn down. Errors are logged, never raised.
        """
        self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)

        span = None
        resume_key = None
        trace_context = None

        try:
            self._log_debug_event(
                "on_chain_start", run_id, parent_run_id, inputs=inputs
            )
            self._register_langfuse_prompt(
                run_id=run_id, parent_run_id=parent_run_id, metadata=metadata
            )

            span_name = self.get_langchain_run_name(serialized, **kwargs)
            span_metadata = self._get_langchain_observation_metadata(
                parent_run_id=parent_run_id,
                tags=tags,
                metadata=metadata,
            )
            # Use Optional[...] rather than `Literal[...] | None`. The cast
            # argument is evaluated at runtime, and the `|` form raises TypeError
            # before Python 3.10, which would abort every chain start.
            span_level = cast(
                Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
                "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
            )

            observation_type = self._get_observation_type_from_serialized(
                serialized, "chain", **kwargs
            )

            # Handle trace attribute propagation at the root of the chain
            if parent_run_id is None:
                parsed_trace_attributes = self._parse_langfuse_trace_attributes(
                    metadata=metadata, tags=tags
                )

                propagation_context_manager = propagate_attributes(
                    user_id=parsed_trace_attributes.get("user_id", None),
                    session_id=parsed_trace_attributes.get("session_id", None),
                    tags=parsed_trace_attributes.get("tags", None),
                    metadata=parsed_trace_attributes.get("metadata", None),
                    trace_name=parsed_trace_attributes.get("trace_name", None),
                )

                # Keep a handle on the manager in the root state before entering
                # it, presumably so _exit_propagation_context can exit it later.
                root_run_state = self._get_root_run_state(run_id)
                if root_run_state is not None:
                    root_run_state.propagation_context_manager = (
                        propagation_context_manager
                    )

                propagation_context_manager.__enter__()

            obs = self._get_parent_observation(parent_run_id)
            if isinstance(obs, Langfuse):
                # New root observation: possibly continue an interrupted LangGraph run.
                resume_key, trace_context = self._take_root_trace_context(
                    inputs=inputs, metadata=metadata
                )
                span = obs.start_observation(
                    trace_context=trace_context,
                    name=span_name,
                    as_type=observation_type,
                    metadata=span_metadata,
                    input=inputs,
                    level=span_level,
                )
            else:
                span = obs.start_observation(
                    name=span_name,
                    as_type=observation_type,
                    metadata=span_metadata,
                    input=inputs,
                    level=span_level,
                )

            self._attach_observation(run_id, span)

            self.last_trace_id = self._runs[run_id].trace_id

        except Exception as e:
            if span is None:
                # The observation was never created: hand back any consumed
                # resume linkage and tear down the root run's bookkeeping.
                self._restore_root_trace_context(
                    resume_key=resume_key, trace_context=trace_context
                )
                if parent_run_id is None:
                    self._exit_propagation_context(run_id)
                    self._reset(run_id)
            langfuse_logger.exception(e)

    def _register_langfuse_prompt(
        self,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID],
        metadata: Optional[Dict[str, Any]],
    ) -> None:
        """Register a passed Langfuse prompt on ``parent_run_id`` so later generations can link to it.

        Root runs (no ``parent_run_id``) are skipped. No LLM invocation follows
        the prompt there; otherwise it would have been traced inside a parent run
        covering both the prompt-template formatting and the LLM invocation.

        When this run carries no prompt but its parent has a registered prompt
        that is not yet linked to a generation, the prompt is passed down to this
        run so that its children can link to it as well. Without this, only
        generations on the same level as the prompt rendering would be linked,
        not nested ones.
        """
        if not parent_run_id or not run_id:
            return

        langfuse_prompt = metadata and metadata.get("langfuse_prompt", None)
        if langfuse_prompt:
            self._prompt_to_parent_run_map[parent_run_id] = langfuse_prompt
            return

        prompt_map = self._prompt_to_parent_run_map
        if parent_run_id in prompt_map:
            prompt_map[run_id] = prompt_map[parent_run_id]

    def _deregister_langfuse_prompt(self, run_id: Optional[UUID]) -> None:
        """Forget any Langfuse prompt registered for ``run_id``; no-op if none."""
        if run_id is None:
            return
        self._prompt_to_parent_run_map.pop(run_id, None)

    def _get_parent_observation(
        self, parent_run_id: Optional[UUID]
    ) -> Union[
        Langfuse,
        LangfuseAgent,
        LangfuseChain,
        LangfuseGeneration,
        LangfuseRetriever,
        LangfuseSpan,
        LangfuseTool,
    ]:
        """Return the live observation of ``parent_run_id``, or the Langfuse client.

        The client is returned when there is no parent or the parent has no live
        observation. Callers then start a new root observation on it.
        """
        if not parent_run_id or parent_run_id not in self._runs:
            return self._langfuse_client
        return self._runs[parent_run_id]

    def _attach_observation(
        self,
        run_id: UUID,
        observation: Union[
            LangfuseAgent,
            LangfuseChain,
            LangfuseGeneration,
            LangfuseRetriever,
            LangfuseSpan,
            LangfuseTool,
        ],
    ) -> None:
        """Make ``observation`` the current OpenTelemetry span and track it under ``run_id``.

        The token returned by attaching the context is kept so that
        ``_detach_observation`` can detach it later.
        """
        token = context.attach(trace.set_span_in_context(observation._otel_span))
        self._runs[run_id] = observation
        self._context_tokens[run_id] = token

    def _detach_observation(
        self, run_id: UUID
    ) -> Optional[
        Union[
            LangfuseAgent,
            LangfuseChain,
            LangfuseGeneration,
            LangfuseRetriever,
            LangfuseSpan,
            LangfuseTool,
        ]
    ]:
        """Stop tracking ``run_id`` and return its observation, or None if untracked.

        Any OpenTelemetry context token attached for the run is detached through
        ``_detach_context_token_safely``.
        """
        token = self._context_tokens.pop(run_id, None)
        if token:
            _detach_context_token_safely(token)

        observation = self._runs.pop(run_id, None)
        return cast(
            Union[
                LangfuseAgent,
                LangfuseChain,
                LangfuseGeneration,
                LangfuseRetriever,
                LangfuseSpan,
                LangfuseTool,
            ],
            observation,
        )

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on agent action.

        Marks the run's observation as an agent and records ``action`` as its
        output, together with any ``inputs`` passed via ``kwargs``.
        """
        self._track_run(run_id=run_id, parent_run_id=parent_run_id)

        try:
            self._log_debug_event(
                "on_agent_action", run_id, parent_run_id, action=action
            )

            agent_run = self._runs.get(run_id)
            if agent_run is None:
                return

            agent_run._otel_span.set_attribute(
                LangfuseOtelSpanAttributes.OBSERVATION_TYPE, "agent"
            )
            agent_run.update(output=action, input=kwargs.get("inputs"))

        except Exception as e:
            langfuse_logger.exception(e)

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when an agent finishes.

        Marks the run's observation as an agent and records ``finish`` as its
        output. The observation is intentionally left open: LangChain reuses the
        same run id for agent finish and chain end, so cleanup happens in
        ``on_chain_end``.
        """
        try:
            self._log_debug_event(
                "on_agent_finish", run_id, parent_run_id, finish=finish
            )

            agent_run = self._runs.get(run_id)
            if agent_run is None:
                return

            agent_run._otel_span.set_attribute(
                LangfuseOtelSpanAttributes.OBSERVATION_TYPE, "agent"
            )
            agent_run.update(output=finish, input=kwargs.get("inputs"))

        except Exception as e:
            langfuse_logger.exception(e)

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """End the chain observation for ``run_id`` with ``outputs``.

        For a root run, the resume key is cleared and the propagation context is
        exited before the observation ends. The ``finally`` clause exits the
        context again and resets the run state, so cleanup happens even if
        ending the observation fails. Errors are logged, never raised.
        """
        is_root = parent_run_id is None
        try:
            self._log_debug_event(
                "on_chain_end", run_id, parent_run_id, outputs=outputs
            )

            observation = self._detach_observation(run_id)
            if observation is not None:
                observation.update(output=outputs, input=kwargs.get("inputs"))

                if is_root:
                    self._clear_root_run_resume_key(run_id)
                    self._exit_propagation_context(run_id)

                observation.end()
                self._deregister_langfuse_prompt(run_id)

        except Exception as e:
            langfuse_logger.exception(e)

        finally:
            if is_root:
                self._exit_propagation_context(run_id)
                self._reset(run_id)

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """End the chain observation for ``run_id`` after ``error``.

        The observation gets the level and status message mapped from ``error``.
        For a root run, a control-flow error (DEFAULT level) persists the trace
        context for a later LangGraph resume, while a genuine error clears the
        root run's resume key. The propagation context is exited before the
        observation ends and again in ``finally``, where the run state is also
        reset. Errors are logged, never raised.
        """
        is_root = parent_run_id is None
        try:
            self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
            level, status_message = self._get_error_level_and_status_message(error)

            observation = self._detach_observation(run_id)
            if observation is not None:
                observation.update(
                    level=cast(
                        Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
                        level,
                    ),
                    status_message=status_message,
                    input=kwargs.get("inputs"),
                    cost_details={"total": 0},
                )

                if is_root:
                    if level != "DEFAULT":
                        self._clear_root_run_resume_key(run_id)
                    else:
                        self._persist_resume_trace_context(
                            run_id=run_id, observation=observation
                        )
                    self._exit_propagation_context(run_id)

                observation.end()

        except Exception as e:
            langfuse_logger.exception(e)
        finally:
            if is_root:
                self._exit_propagation_context(run_id)
                self._reset(run_id)

    def on_chat_model_start(
        self,
        serialized: Optional[Dict[str, Any]],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Hand a chat model run to the shared LLM start logic.

        Each message batch is converted with ``_create_message_dicts``, and the
        batches are flattened into a single list first. Errors are logged,
        never raised.
        """
        self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)

        try:
            self._log_debug_event(
                "on_chat_model_start", run_id, parent_run_id, messages=messages
            )
            message_dicts = [self._create_message_dicts(batch) for batch in messages]
            llm_input = cast(List, _flatten_comprehension(message_dicts))
            self.__on_llm_action(
                serialized,
                run_id,
                llm_input,
                parent_run_id,
                tags=tags,
                metadata=metadata,
                **kwargs,
            )
        except Exception as e:
            langfuse_logger.exception(e)

    def on_llm_start(
        self,
        serialized: Optional[Dict[str, Any]],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Hand a completion-style LLM run to the shared LLM start logic.

        A single prompt is passed on by itself; several prompts are passed on as
        the list. Errors are logged, never raised.
        """
        self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)

        try:
            self._log_debug_event(
                "on_llm_start", run_id, parent_run_id, prompts=prompts
            )
            llm_input = prompts if len(prompts) != 1 else prompts[0]
            self.__on_llm_action(
                serialized,
                run_id,
                cast(List, llm_input),
                parent_run_id,
                tags=tags,
                metadata=metadata,
                **kwargs,
            )
        except Exception as e:
            langfuse_logger.exception(e)

    def on_tool_start(
        self,
        serialized: Optional[Dict[str, Any]],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Start a Langfuse observation for a LangChain tool run.

        The observation metadata combines the run's tags and metadata with every
        non-None extra keyword argument. A tool run without a live parent
        observation starts on the Langfuse client using the handler-level
        ``trace_context``. Errors are logged, never raised.
        """
        self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)

        try:
            self._log_debug_event(
                "on_tool_start", run_id, parent_run_id, input_str=input_str
            )

            meta = (
                self._get_langchain_observation_metadata(
                    parent_run_id=parent_run_id,
                    tags=tags,
                    metadata=metadata,
                )
                or {}
            )
            meta.update(
                {key: value for key, value in kwargs.items() if value is not None}
            )

            observation_type = self._get_observation_type_from_serialized(
                serialized, "tool", **kwargs
            )
            tool_name = self.get_langchain_run_name(serialized, **kwargs)
            tool_level: Optional[Literal["DEBUG"]] = (
                "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
            )

            parent_observation = self._get_parent_observation(parent_run_id)
            if isinstance(parent_observation, Langfuse):
                span = parent_observation.start_observation(
                    trace_context=self._trace_context,
                    name=tool_name,
                    as_type=observation_type,
                    input=input_str,
                    metadata=meta,
                    level=tool_level,
                )
            else:
                span = parent_observation.start_observation(
                    name=tool_name,
                    as_type=observation_type,
                    input=input_str,
                    metadata=meta,
                    level=tool_level,
                )

            self._attach_observation(run_id, span)

        except Exception as e:
            langfuse_logger.exception(e)

    def on_retriever_start(
        self,
        serialized: Optional[Dict[str, Any]],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Open an observation for a LangChain retriever run and attach it to ``run_id``.

        The query is recorded as the observation input. The observation is logged
        at DEBUG level when the run carries the LangSmith "hidden" tag. Errors are
        logged, never raised.
        """
        self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)

        try:
            self._log_debug_event(
                "on_retriever_start", run_id, parent_run_id, query=query
            )

            observation_args: Dict[str, Any] = {
                "name": self.get_langchain_run_name(serialized, **kwargs),
                "metadata": self._get_langchain_observation_metadata(
                    parent_run_id=parent_run_id,
                    tags=tags,
                    metadata=metadata,
                ),
                "level": "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
                "as_type": self._get_observation_type_from_serialized(
                    serialized, "retriever", **kwargs
                ),
                "input": query,
            }

            parent_observation = self._get_parent_observation(parent_run_id)
            # When the parent resolves to the Langfuse client itself, the handler's
            # trace context has to be passed explicitly.
            if isinstance(parent_observation, Langfuse):
                observation_args["trace_context"] = self._trace_context

            span = parent_observation.start_observation(**observation_args)  # type: ignore
            self._attach_observation(run_id, span)

        except Exception as e:
            langfuse_logger.exception(e)

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the retriever observation for ``run_id`` with the retrieved documents.

        For root runs (no ``parent_run_id``) the resume key is cleared when an
        observation was found, and the run state is always reset afterwards.
        Errors are logged, never raised.
        """
        try:
            self._log_debug_event(
                "on_retriever_end", run_id, parent_run_id, documents=documents
            )

            finished = self._detach_observation(run_id)
            if finished is None:
                return

            if parent_run_id is None:
                self._clear_root_run_resume_key(run_id)
            finished.update(output=documents, input=kwargs.get("inputs")).end()

        except Exception as e:
            langfuse_logger.exception(e)
        finally:
            if parent_run_id is None:
                self._reset(run_id)

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the tool observation for ``run_id`` with the tool's output.

        For root runs (no ``parent_run_id``) the resume key is cleared when an
        observation was found, and the run state is always reset afterwards.
        Errors are logged, never raised.
        """
        try:
            self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)

            finished = self._detach_observation(run_id)
            if finished is None:
                return

            if parent_run_id is None:
                self._clear_root_run_resume_key(run_id)
            finished.update(output=output, input=kwargs.get("inputs")).end()

        except Exception as e:
            langfuse_logger.exception(e)
        finally:
            if parent_run_id is None:
                self._reset(run_id)

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the tool observation for ``run_id`` as failed.

        The level and status message come from
        ``_get_error_level_and_status_message``, and the cost is recorded as zero.
        For root runs, a ``"DEFAULT"`` level persists the resume trace context,
        and any other level clears the resume key. Root run state is always reset
        afterwards. Errors raised here are logged, never propagated.
        """
        try:
            self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)

            failed = self._detach_observation(run_id)
            if failed is None:
                return

            level, status_message = self._get_error_level_and_status_message(error)
            failed.update(
                status_message=status_message,
                level=cast(
                    Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
                    level,
                ),
                input=kwargs.get("inputs"),
                cost_details={"total": 0},
            ).end()

            if parent_run_id is not None:
                return
            if level == "DEFAULT":
                self._persist_resume_trace_context(run_id=run_id, observation=failed)
            else:
                self._clear_root_run_resume_key(run_id)

        except Exception as e:
            langfuse_logger.exception(e)
        finally:
            if parent_run_id is None:
                self._reset(run_id)

    def __on_llm_action(
        self,
        serialized: Optional[Dict[str, Any]],
        run_id: UUID,
        prompts: List[Any],
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Start a generation observation for an LLM or chat-model run.

        Shared implementation behind ``on_llm_start`` and ``on_chat_model_start``.

        Args:
            serialized: Serialized LangChain model, used for name and model parsing.
            run_id: ID of the LLM run; the generation is attached under it.
            prompts: Input to record. Usually a list of prompts or message dicts,
                but ``on_llm_start`` passes a bare string for a single prompt.
                It is never modified.
            parent_run_id: ID of the parent run, if any.
            tags: LangChain tags for the run.
            metadata: LangChain metadata for the run.
            **kwargs: Remaining callback kwargs (e.g. ``invocation_params``).

        Errors are logged and swallowed.
        """
        self._track_run(run_id=run_id, parent_run_id=parent_run_id, metadata=metadata)

        try:
            # ``invocation_params`` may be missing or explicitly ``None``.
            invocation_params = kwargs.get("invocation_params") or {}

            tools = invocation_params.get("tools", None)
            if tools and isinstance(tools, list):
                tool_messages = [{"role": "tool", "content": tool} for tool in tools]
                # Build a new list rather than extending ``prompts`` in place. For
                # ``on_llm_start`` it is LangChain's own list, or a bare string that
                # has no ``extend`` at all.
                prompts = (
                    [*prompts, *tool_messages]
                    if isinstance(prompts, list)
                    else [prompts, *tool_messages]
                )

            model_name = self._parse_model_and_log_errors(
                serialized=serialized, metadata=metadata, kwargs=kwargs
            )

            registered_prompt = None
            current_parent_run_id = parent_run_id

            # Check all parents for registered prompt
            while current_parent_run_id is not None:
                registered_prompt = self._prompt_to_parent_run_map.get(
                    current_parent_run_id
                )

                if registered_prompt:
                    self._deregister_langfuse_prompt(current_parent_run_id)
                    break

                current_parent_run_id = self._get_parent_run_id(current_parent_run_id)

            content = {
                "name": self.get_langchain_run_name(serialized, **kwargs),
                "input": prompts,
                "metadata": self._get_langchain_observation_metadata(
                    parent_run_id=parent_run_id,
                    tags=tags,
                    metadata=metadata,
                    # If llm is run isolated and outside chain, keep trace attributes
                    keep_langfuse_trace_attributes=parent_run_id is None,
                ),
                "model": model_name,
                # Pass a normalized copy so a missing or ``None`` ``invocation_params``
                # does not abort the generation, and ``kwargs`` stays untouched.
                "model_parameters": self._parse_model_parameters(
                    {**kwargs, "invocation_params": invocation_params}
                ),
                "prompt": registered_prompt,
            }

            parent_observation = self._get_parent_observation(parent_run_id)
            if isinstance(parent_observation, Langfuse):
                generation = parent_observation.start_observation(
                    trace_context=self._trace_context,
                    as_type="generation",
                    **content,
                )  # type: ignore
            else:
                generation = parent_observation.start_observation(
                    as_type="generation", **content
                )  # type: ignore
            self._attach_observation(run_id, generation)

            self.last_trace_id = self._runs[run_id].trace_id

        except Exception as e:
            langfuse_logger.exception(e)

    @staticmethod
    def _parse_model_parameters(kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Parse the model parameters from the kwargs."""
        if kwargs["invocation_params"].get("_type") == "IBM watsonx.ai" and kwargs[
            "invocation_params"
        ].get("params"):
            kwargs["invocation_params"] = {
                **kwargs["invocation_params"],
                **kwargs["invocation_params"]["params"],
            }
            del kwargs["invocation_params"]["params"]
        return {
            key: value
            for key, value in {
                "temperature": kwargs["invocation_params"].get("temperature"),
                "max_tokens": kwargs["invocation_params"].get("max_tokens"),
                "max_completion_tokens": kwargs["invocation_params"].get(
                    "max_completion_tokens"
                ),
                "top_p": kwargs["invocation_params"].get("top_p"),
                "frequency_penalty": kwargs["invocation_params"].get(
                    "frequency_penalty"
                ),
                "presence_penalty": kwargs["invocation_params"].get("presence_penalty"),
                "request_timeout": kwargs["invocation_params"].get("request_timeout"),
                "decoding_method": kwargs["invocation_params"].get("decoding_method"),
                "min_new_tokens": kwargs["invocation_params"].get("min_new_tokens"),
                "max_new_tokens": kwargs["invocation_params"].get("max_new_tokens"),
                "stop_sequences": kwargs["invocation_params"].get("stop_sequences"),
            }.items()
            if value is not None
        }

    def _parse_model_and_log_errors(
        self,
        *,
        serialized: Optional[Dict[str, Any]],
        metadata: Optional[Dict[str, Any]],
        kwargs: Dict[str, Any],
    ) -> Optional[str]:
        """Resolve the LLM model name, or return ``None`` with a one-time warning.

        ``metadata["ls_model_name"]`` wins when it is truthy. Otherwise the name
        is extracted from the serialized model and callback kwargs. Extraction
        errors are logged with a traceback rather than raised.
        """
        try:
            resolved = _parse_model_name_from_metadata(metadata)
            if not resolved:
                resolved = _extract_model_name(serialized, **kwargs)
        except Exception as e:
            langfuse_logger.exception(e)
        else:
            if resolved:
                return resolved

        self._log_model_parse_warning()
        return None

    def _log_model_parse_warning(self) -> None:
        """Warn that the model name could not be parsed, at most once per instance."""
        if hasattr(self, "_model_parse_warning_logged"):
            return

        langfuse_logger.warning(
            "Langfuse was not able to parse the LLM model. The LLM call will be recorded without model name. Please create an issue: https://github.com/langfuse/langfuse/issues/new/choose"
        )
        self._model_parse_warning_logged = True

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the generation for ``run_id`` with the LLM output, usage and model.

        The output is the last generation of the last prompt, converted to a
        message dict for chat generations. A result without generations (e.g.
        ``[[]]``) is recorded without output instead of raising ``IndexError``,
        so the generation is still ended. Root runs have their resume key cleared
        and their state reset afterwards. Errors are logged, never raised.
        """
        try:
            self._log_debug_event(
                "on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs
            )

            generations = response.generations
            extracted_response: Any = None
            if generations and generations[-1]:
                response_generation = generations[-1][-1]
                extracted_response = (
                    self._convert_message_to_dict(response_generation.message)
                    if isinstance(response_generation, ChatGeneration)
                    else _extract_raw_response(response_generation)
                )

            llm_usage = _parse_usage(response)

            # e.g. azure returns the model name in the response
            model = _parse_model(response)

            generation = self._detach_observation(run_id)

            if generation is not None:
                generation.update(
                    output=extracted_response,
                    usage=llm_usage,
                    usage_details=llm_usage,
                    input=kwargs.get("inputs"),
                    model=model,
                ).end()

        except Exception as e:
            langfuse_logger.exception(e)

        finally:
            self._updated_completion_start_time_memo.discard(run_id)

            if parent_run_id is None:
                self._clear_root_run_resume_key(run_id)
                self._reset(run_id)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the generation for ``run_id`` as failed.

        The level and status message come from
        ``_get_error_level_and_status_message``, and the cost is recorded as zero.
        For root runs, a ``"DEFAULT"`` level persists the resume trace context,
        and any other level clears the resume key. Root run state is always reset
        afterwards. Errors raised here are logged, never propagated.
        """
        try:
            self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)

            failed = self._detach_observation(run_id)
            if failed is None:
                return

            level, status_message = self._get_error_level_and_status_message(error)
            failed.update(
                status_message=status_message,
                level=cast(
                    Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]],
                    level,
                ),
                input=kwargs.get("inputs"),
                cost_details={"total": 0},
            ).end()

            if parent_run_id is not None:
                return
            if level == "DEFAULT":
                self._persist_resume_trace_context(run_id=run_id, observation=failed)
            else:
                self._clear_root_run_resume_key(run_id)

        except Exception as e:
            langfuse_logger.exception(e)
        finally:
            if parent_run_id is None:
                self._reset(run_id)

    def _reset(self, root_run_id: UUID) -> None:
        run_state = self._get_run_state(root_run_id)
        if run_state is None:
            return

        root_run_state = self._root_run_states.pop(run_state.root_run_id, None)
        if root_run_state is None:
            self._run_states.pop(root_run_id, None)
            return

        for run_id in root_run_state.run_ids:
            self._run_states.pop(run_id, None)

    def _exit_propagation_context(self, run_id: UUID) -> None:
        root_run_state = self._get_root_run_state(run_id)

        if root_run_state is None:
            return

        manager = root_run_state.propagation_context_manager
        if manager is None:
            return

        root_run_state.propagation_context_manager = None
        manager.__exit__(None, None, None)

    def __join_tags_and_metadata(
        self,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        keep_langfuse_trace_attributes: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """Merge LangChain ``tags`` and ``metadata`` into one observation dict.

        Non-empty tags are stored under ``"tags"``. ``metadata`` is merged
        afterwards, so its own ``"tags"`` key, if any, wins. Langfuse-internal
        keys are then stripped. Returns ``None`` when there is nothing to record.
        """
        combined: Dict[str, Any] = {}
        if tags:
            combined["tags"] = tags
        if metadata is not None:
            combined.update(metadata)

        if not combined:
            return None
        return _strip_langfuse_keys_from_dict(combined, keep_langfuse_trace_attributes)

    def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
        """Convert a LangChain message into a ``{"role", "content", ...}`` dict.

        Assistant messages also carry non-empty ``tool_calls`` and
        ``invalid_tool_calls``. Tool messages carry ``tool_call_id``. A ``"name"``
        in ``additional_kwargs`` is promoted to the top level, and non-empty
        ``additional_kwargs`` are included as-is.

        Raises:
            ValueError: If ``message`` is not a supported message type.
        """
        if isinstance(message, HumanMessage):
            converted: Dict[str, Any] = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            converted = {"role": "assistant", "content": message.content}
            for call_attr in ("tool_calls", "invalid_tool_calls"):
                calls = getattr(message, call_attr, None)
                if calls is not None and len(calls) > 0:
                    converted[call_attr] = calls
        elif isinstance(message, SystemMessage):
            converted = {"role": "system", "content": message.content}
        elif isinstance(message, ToolMessage):
            converted = {
                "role": "tool",
                "content": message.content,
                "tool_call_id": message.tool_call_id,
            }
        elif isinstance(message, FunctionMessage):
            converted = {"role": "function", "content": message.content}
        elif isinstance(message, ChatMessage):
            converted = {"role": message.role, "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        extra_kwargs = message.additional_kwargs
        if "name" in extra_kwargs:
            converted["name"] = extra_kwargs["name"]
        if extra_kwargs:
            converted["additional_kwargs"] = extra_kwargs  # type: ignore

        return converted

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """Convert each message with ``_convert_message_to_dict``, preserving order."""
        return list(map(self._convert_message_to_dict, messages))

    def _log_debug_event(
        self,
        event_name: str,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Log a debug line identifying a LangChain callback event.

        Only the event name and run ids are logged. Extra keyword arguments
        (event payloads such as prompts or responses) are accepted but not
        included in the log line.
        """
        # Lazy %-style arguments: the message is only formatted when DEBUG is
        # enabled, which matters because this runs on every callback event.
        langfuse_logger.debug(
            "Event: %s, run_id: %s, parent_run_id: %s",
            event_name,
            run_id,
            parent_run_id,
        )


def _extract_raw_response(last_response: Any) -> Any:
    """Extract the response from the last response of the LLM call."""
    # We return the text of the response if not empty
    if last_response.text is not None and last_response.text.strip() != "":
        return last_response.text.strip()
    elif hasattr(last_response, "message"):
        # Additional kwargs contains the response in case of tool usage
        return last_response.message.additional_kwargs
    else:
        # Not tool usage, some LLM responses can be simply empty
        return ""


def _flatten_comprehension(matrix: Any) -> Any:
    return [item for row in matrix for item in row]


def _parse_usage_model(usage: Union[pydantic.BaseModel, dict]) -> Any:
    # maintains a list of key translations. For each key, the usage model is checked
    # and a new object will be created with the new key if the key exists in the usage model
    # All non matched keys will remain on the object.

    if hasattr(usage, "__dict__"):
        usage = usage.__dict__

    conversion_list = [
        # https://pypi.org/project/langchain-anthropic/ (works also for Bedrock-Anthropic)
        ("input_tokens", "input"),
        ("output_tokens", "output"),
        ("total_tokens", "total"),
        # ChatBedrock API follows a separate format compared to ChatBedrockConverse API
        ("prompt_tokens", "input"),
        ("completion_tokens", "output"),
        # https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/get-token-count
        ("prompt_token_count", "input"),
        ("candidates_token_count", "output"),
        ("total_token_count", "total"),
        # Bedrock: https://docs.aws.amazon.com/bedrock/latest/userguide/monitoring-cw.html#runtime-cloudwatch-metrics
        ("inputTokenCount", "input"),
        ("outputTokenCount", "output"),
        ("totalTokenCount", "total"),
        # langchain-ibm https://pypi.org/project/langchain-ibm/
        ("input_token_count", "input"),
        ("generated_token_count", "output"),
    ]

    usage_model = cast(Dict, usage.copy())  # Copy all existing key-value pairs

    # Skip OpenAI usage types as they are handled server side
    if (
        all(
            openai_key in usage_model
            for openai_key in [
                "prompt_tokens",
                "completion_tokens",
                "total_tokens",
                "prompt_tokens_details",
                "completion_tokens_details",
            ]
        )
        and len(usage_model.keys()) == 5
    ) or (
        all(
            openai_key in usage_model
            for openai_key in [
                "prompt_tokens",
                "completion_tokens",
                "total_tokens",
            ]
        )
        and len(usage_model.keys()) == 3
    ):
        return usage_model

    for model_key, langfuse_key in conversion_list:
        if model_key in usage_model:
            captured_count = usage_model.pop(model_key)
            final_count = (
                sum(captured_count)
                if isinstance(captured_count, list)
                else captured_count
            )  # For Bedrock, the token count is a list when streamed

            usage_model[langfuse_key] = final_count  # Translate key and keep the value

    if isinstance(usage_model, dict):
        if "input_token_details" in usage_model:
            input_token_details = usage_model.pop("input_token_details", {})

            for key, value in input_token_details.items():
                usage_model[f"input_{key}"] = value

                # Skip priority-tier keys as they are not exclusive sub-categories
                if key == "priority" or key.startswith("priority_"):
                    continue

                if "input" in usage_model:
                    usage_model["input"] = max(0, usage_model["input"] - value)

        if "output_token_details" in usage_model:
            output_token_details = usage_model.pop("output_token_details", {})

            for key, value in output_token_details.items():
                usage_model[f"output_{key}"] = value

                # Skip priority-tier keys as they are not exclusive sub-categories
                if key == "priority" or key.startswith("priority_"):
                    continue

                if "output" in usage_model:
                    usage_model["output"] = max(0, usage_model["output"] - value)

        # Vertex AI
        if "prompt_tokens_details" in usage_model and isinstance(
            usage_model["prompt_tokens_details"], list
        ):
            prompt_tokens_details = usage_model.pop("prompt_tokens_details")

            for item in prompt_tokens_details:
                if (
                    isinstance(item, dict)
                    and "modality" in item
                    and "token_count" in item
                ):
                    value = item["token_count"]
                    usage_model[f"input_modality_{item['modality']}"] = value

                    if "input" in usage_model:
                        usage_model["input"] = max(0, usage_model["input"] - value)

        # Vertex AI
        if "candidates_tokens_details" in usage_model and isinstance(
            usage_model["candidates_tokens_details"], list
        ):
            candidates_tokens_details = usage_model.pop("candidates_tokens_details")

            for item in candidates_tokens_details:
                if (
                    isinstance(item, dict)
                    and "modality" in item
                    and "token_count" in item
                ):
                    value = item["token_count"]
                    usage_model[f"output_modality_{item['modality']}"] = value

                    if "output" in usage_model:
                        usage_model["output"] = max(0, usage_model["output"] - value)

        # Vertex AI
        if "cache_tokens_details" in usage_model and isinstance(
            usage_model["cache_tokens_details"], list
        ):
            cache_tokens_details = usage_model.pop("cache_tokens_details")

            for item in cache_tokens_details:
                if (
                    isinstance(item, dict)
                    and "modality" in item
                    and "token_count" in item
                ):
                    value = item["token_count"]
                    usage_model[f"cached_modality_{item['modality']}"] = value

                    if "input" in usage_model:
                        usage_model["input"] = max(0, usage_model["input"] - value)

                    if f"input_modality_{item['modality']}" in usage_model:
                        usage_model[f"input_modality_{item['modality']}"] = max(
                            0, usage_model[f"input_modality_{item['modality']}"] - value
                        )

    usage_model = {k: v for k, v in usage_model.items() if isinstance(v, int)}

    return usage_model if usage_model else None


def _parse_usage(response: LLMResult) -> Any:
    """Extract token usage from an LLM result, normalized via ``_parse_usage_model``.

    Sources are checked in this order; a later match overwrites an earlier one:

    1. ``response.llm_output["token_usage"]`` or ``["usage"]``. The first key
       with a truthy value is used.
    2. For each generation chunk, ``generation_info["usage_metadata"]``. If that
       is absent or parses to ``None``, the chunk's message is checked instead:
       ``response_metadata["usage"]``, then
       ``response_metadata["amazon-bedrock-invocationMetrics"]``, then the
       message's ``usage_metadata``.

    Returns:
        The normalized usage dict, or ``None`` if nothing could be parsed.
    """
    # langchain-anthropic uses the usage field
    llm_usage_keys = ["token_usage", "usage"]
    llm_usage = None
    if response.llm_output is not None:
        for key in llm_usage_keys:
            if key in response.llm_output and response.llm_output[key]:
                llm_usage = _parse_usage_model(response.llm_output[key])
                break

    if hasattr(response, "generations"):
        # NOTE: the `break`s below only leave the inner loop over chunks. The outer
        # loop keeps going, so a later generation list can overwrite usage found for
        # an earlier one.
        for generation in response.generations:
            for generation_chunk in generation:
                if generation_chunk.generation_info and (
                    "usage_metadata" in generation_chunk.generation_info
                ):
                    # NOTE(review): if this parses to None it also discards usage
                    # found earlier (e.g. from llm_output) unless a message-level
                    # source below matches — confirm this is intended.
                    llm_usage = _parse_usage_model(
                        generation_chunk.generation_info["usage_metadata"]
                    )

                    if llm_usage is not None:
                        break

                # Fall back to usage attached to the chunk's message, if any.
                message_chunk = getattr(generation_chunk, "message", {})
                response_metadata = getattr(message_chunk, "response_metadata", {})

                chunk_usage = (
                    (
                        response_metadata.get("usage", None)  # for Bedrock-Anthropic
                        if isinstance(response_metadata, dict)
                        else None
                    )
                    or (
                        response_metadata.get(
                            "amazon-bedrock-invocationMetrics", None
                        )  # for Bedrock-Titan
                        if isinstance(response_metadata, dict)
                        else None
                    )
                    or getattr(message_chunk, "usage_metadata", None)  # for Ollama
                )

                if chunk_usage:
                    llm_usage = _parse_usage_model(chunk_usage)
                    break

    return llm_usage


def _parse_model(response: LLMResult) -> Any:
    """Return the model name reported in ``response.llm_output``, if any.

    Some providers (e.g. Azure) report the model in the response itself. Returns
    ``None`` when ``llm_output`` is missing or has no truthy ``"model_name"``.
    """
    llm_output = response.llm_output
    if llm_output is None or "model_name" not in llm_output:
        return None
    return llm_output["model_name"] or None


def _parse_model_name_from_metadata(metadata: Optional[Dict[str, Any]]) -> Any:
    if metadata is None or not isinstance(metadata, dict):
        return None

    return metadata.get("ls_model_name", None)


def _strip_langfuse_keys_from_dict(
    metadata: Optional[Dict[str, Any]], keep_langfuse_trace_attributes: bool
) -> Any:
    if metadata is None or not isinstance(metadata, dict):
        return metadata

    langfuse_metadata_keys = [
        "langfuse_prompt",
    ]

    langfuse_trace_attribute_keys = [
        "langfuse_session_id",
        "langfuse_user_id",
        "langfuse_tags",
        "langfuse_trace_name",
    ]

    metadata_copy = metadata.copy()

    for key in langfuse_metadata_keys:
        metadata_copy.pop(key, None)

    if not keep_langfuse_trace_attributes:
        for key in langfuse_trace_attribute_keys:
            metadata_copy.pop(key, None)

    return metadata_copy
