from threading import Lock
from typing import Optional

from pymilvus import Collection, connections, utility

from common_logging import get_logger

logger = get_logger(__name__)



class MilvusConnectionManager:
    """Process-wide singleton owning one pymilvus connection alias.

    ``__new__`` always hands back the same instance (creation is guarded by a
    class-level lock), and ``__init__`` only populates state on the first call,
    so ``MilvusConnectionManager()`` can be invoked from anywhere without
    resetting an existing connection.

    State attributes (set on first construction):
        connection_alias: pymilvus alias every call in this class uses ("default").
        connected: whether this manager believes the alias is connected.
        host / port / use_lite: parameters of the last successful ``connect``.
    """

    _instance: Optional["MilvusConnectionManager"] = None
    _lock = Lock()

    def __new__(cls):
        # Double-checked locking: the unlocked test keeps the common path cheap,
        # the locked re-test prevents two threads from both creating an instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every MilvusConnectionManager() call because __new__
        # returns the existing instance; only the first call may set state.
        if not hasattr(self, "initialized"):
            self.initialized = True
            self.connection_alias = "default"
            self.connected = False
            self.host = None
            self.port = None
            self.use_lite = False
            logger.info("Milvus connection manager initialized")

    def connect(self, host: str = "localhost", port: str = "19530", use_lite: bool = False):
        """Open (or reuse) the connection registered under ``connection_alias``.

        If already connected with identical ``host``/``port``/``use_lite`` this is
        a no-op; with different parameters the current connection is dropped first.

        Args:
            host: Milvus server host. Not used to open the connection when
                ``use_lite`` is true, but still recorded and compared.
            port: Milvus server port (same caveat as ``host``).
            use_lite: Connect to the local Milvus Lite file
                ``./data/milvus_data.db`` instead of a server.

        Raises:
            Whatever ``connections.connect`` raises; ``connected`` stays False.
        """
        if self.connected:
            if self.host == host and self.port == port and (self.use_lite == use_lite):
                logger.debug("Using existing Milvus connection")
                return
            else:
                self.disconnect()
        try:
            if use_lite:
                connections.connect(alias=self.connection_alias, uri="./data/milvus_data.db")
                logger.info("Connected to Milvus Lite")
            else:
                connections.connect(alias=self.connection_alias, host=host, port=port)
                logger.info(f"Connected to Milvus at {host}:{port}")
            # Record parameters only after a successful connect so that
            # reconnect() replays a configuration that actually worked.
            self.connected = True
            self.host = host
            self.port = port
            self.use_lite = use_lite
        except Exception as e:
            logger.error(f"Failed to connect to Milvus: {e}")
            self.connected = False
            raise

    def disconnect(self):
        """Drop the alias if connected; errors are logged, not raised.

        On failure ``connected`` is left unchanged.
        """
        if self.connected:
            try:
                connections.disconnect(alias=self.connection_alias)
                self.connected = False
                logger.info("Disconnected from Milvus")
            except Exception as e:
                logger.error(f"Error disconnecting from Milvus: {e}")

    def get_collection(self, collection_name: str) -> Collection:
        """Return a ``Collection`` bound to this manager's alias, attempting to load it.

        Loading is best-effort: a ``load()`` failure is logged at debug level
        and the (possibly unloaded) collection is still returned.

        Raises:
            Exception: if ``connect()`` has not succeeded yet, or if constructing
                the ``Collection`` fails (re-raised after logging).
        """
        if not self.connected:
            raise Exception("Not connected to Milvus. Call connect() first.")
        try:
            # Bind to the manager's alias explicitly instead of relying on the
            # pymilvus default alias happening to match connection_alias.
            collection = Collection(collection_name, using=self.connection_alias)
            try:
                collection.load()
            except Exception as load_error:
                logger.debug(f"Collection {collection_name} load attempt: {load_error}")
            return collection
        except Exception as e:
            logger.error(f"Failed to get collection {collection_name}: {e}")
            raise

    def health_check(self) -> bool:
        """Return True if the server answers a request on ``connection_alias``.

        On failure the error is logged and ``connected`` is set to False.
        """
        if not self.connected:
            return False
        try:
            # A real server round-trip. ``connections.list_connections()`` only
            # reads local bookkeeping and never fails for a dead server.
            utility.get_server_version(using=self.connection_alias)
            return True
        except Exception as e:
            logger.error(f"Milvus health check failed: {e}")
            self.connected = False
            return False

    def reconnect(self):
        """Re-establish the last successful connection if the health check fails.

        Does nothing when the connection is healthy. When no successful
        ``connect()`` has been recorded, a warning is logged and nothing else
        happens. Errors from the new ``connect()`` propagate.
        """
        if self.health_check():
            return
        logger.info("Attempting to reconnect to Milvus...")
        self.connected = False
        if not (self.host and self.port):
            logger.warning("Cannot reconnect to Milvus: no previous connection parameters")
            return
        # pymilvus reuses an alias that is still registered with the same
        # config, so drop the stale one first to force a fresh channel.
        try:
            connections.disconnect(alias=self.connection_alias)
        except Exception as e:
            logger.debug(f"Ignoring error while dropping stale Milvus alias: {e}")
        self.connect(self.host, self.port, self.use_lite)


_milvus_manager: MilvusConnectionManager | None = None  # lazily filled by get_milvus_manager()


def get_milvus_manager() -> MilvusConnectionManager:
    """Return the module-wide manager, creating it the first time it is requested."""
    global _milvus_manager
    if _milvus_manager is not None:
        return _milvus_manager
    _milvus_manager = MilvusConnectionManager()
    return _milvus_manager


def init_milvus_connection(host: str = "localhost", port: str = "19530", use_lite: bool = False):
    """Connect the shared manager returned by ``get_milvus_manager()``.

    Arguments are forwarded unchanged to ``MilvusConnectionManager.connect``.
    """
    get_milvus_manager().connect(host=host, port=port, use_lite=use_lite)
