import uuid
from datetime import datetime, timedelta, timezone

from common_logging import get_logger

# Module-level logger. The code below calls its ``bind(...)`` method to attach
# structured context (e.g. ``conversation_id``) to individual log records.
logger = get_logger(__name__)


class ConversationManager:

    def __init__(self):
        self.conversations: dict[str, dict] = {}
        self.expiration_hours = 24

    def create_conversation(self) -> str:
        conversation_id = str(uuid.uuid4())
        self.conversations[conversation_id] = {
            "messages": [],
            "created_at": datetime.utcnow(),
            "updated_at": datetime.utcnow(),
        }
        logger.bind(conversation_id=conversation_id).info("conversation created")
        return conversation_id

    def add_message(self, conversation_id: str, role: str, content: str) -> None:
        if conversation_id not in self.conversations:
            self.conversations[conversation_id] = {
                "messages": [],
                "created_at": datetime.utcnow(),
                "updated_at": datetime.utcnow(),
            }
        self.conversations[conversation_id]["messages"].append({"role": role, "content": content})
        self.conversations[conversation_id]["updated_at"] = datetime.utcnow()

    def get_messages(
        self, conversation_id: str, max_messages: int | None = None
    ) -> list[dict[str, str]]:
        if conversation_id not in self.conversations:
            return []
        messages = self.conversations[conversation_id]["messages"]
        if max_messages:
            return messages[-max_messages:]
        return messages

    def clear_conversation(self, conversation_id: str) -> bool:
        if conversation_id in self.conversations:
            del self.conversations[conversation_id]
            logger.bind(conversation_id=conversation_id).info("conversation cleared")
            return True
        return False

    def cleanup_expired_conversations(self) -> int:
        now = datetime.utcnow()
        expired_ids = []
        for conv_id, conv_data in self.conversations.items():
            updated_at = conv_data["updated_at"]
            if now - updated_at > timedelta(hours=self.expiration_hours):
                expired_ids.append(conv_id)
        for conv_id in expired_ids:
            del self.conversations[conv_id]
        return len(expired_ids)

    def get_conversation_count(self) -> int:
        return len(self.conversations)

    def conversation_exists(self, conversation_id: str) -> bool:
        return conversation_id in self.conversations


# Shared module-level instance. Its conversations are held only in this
# instance's in-memory dict; nothing in ConversationManager persists them.
# NOTE(review): ConversationManager performs no locking — confirm callers do not
# mutate this instance concurrently from multiple threads.
conversation_manager = ConversationManager()
