from contextlib import contextmanager

from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import Pool

from app.config import settings
from common_logging import get_logger

logger = get_logger(__name__)

# Primary application engine shared by SessionLocal.
engine = create_engine(
    settings.DATABASE_URL,
    echo=False,
    # Connection pool sizing: 20 persistent connections plus up to 30 overflow.
    pool_size=20,
    max_overflow=30,
    # Verify a pooled connection is alive before handing it out.
    pool_pre_ping=True,
    # Replace connections that have been open for more than one hour.
    pool_recycle=3600,
    # libpq / PostgreSQL client options — DATABASE_URL is assumed to point at
    # PostgreSQL (NOTE(review): these kwargs would be rejected by other drivers).
    connect_args={
        "client_encoding": "utf8",
        "options": "-c client_encoding=utf8",
    },
)


@event.listens_for(Pool, "connect")
def set_search_path(dbapi_conn, connection_record):
    """Pin ``search_path`` to ``public`` on each newly opened DBAPI connection.

    The listener is attached to the ``Pool`` class, so it fires for connections
    opened by every pool in the process (including the data-center engine's),
    not only ``engine``. SQLite connections are skipped because SQLite has no
    ``search_path``. Any failure is logged as a warning and swallowed, so this
    best-effort setup never rejects a connection.

    Args:
        dbapi_conn: The raw DBAPI connection that was just opened.
        connection_record: The pool's connection record (unused).
    """
    driver = getattr(dbapi_conn, "driver", None)
    if driver is not None and "sqlite" in str(driver).lower():
        logger.debug("Skipping search_path for SQLite connection")
        return

    # Check the defining module as well as the class name: the stdlib driver's
    # connection class is ``sqlite3.Connection``, whose bare name
    # ("Connection") does not contain "sqlite".
    conn_class = dbapi_conn.__class__
    module_name = getattr(conn_class, "__module__", None) or ""
    class_ident = f"{module_name}.{conn_class.__name__}".lower()
    if "sqlite" in class_ident:
        logger.debug("Skipping search_path for SQLite connection (detected via class name)")
        return

    # Issue the SET outside of any transaction. With autocommit off, drivers
    # such as psycopg2 open an implicit transaction on execute(), and the
    # pool's reset-on-return rollback would then silently undo the SET.
    # This mirrors the recipe in SQLAlchemy's PostgreSQL search_path docs.
    toggled_autocommit = False
    previous_autocommit = None
    try:
        if hasattr(dbapi_conn, "autocommit"):
            previous_autocommit = dbapi_conn.autocommit
            dbapi_conn.autocommit = True
            toggled_autocommit = True
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute("SET search_path TO public")
        finally:
            # Close the cursor even when execute() fails.
            cursor.close()
        logger.debug("Set search_path to public for PostgreSQL connection")
    except Exception as e:
        logger.warning(f"Failed to set search_path: {e}")
    finally:
        if toggled_autocommit:
            try:
                dbapi_conn.autocommit = previous_autocommit
            except Exception as e:
                logger.warning(f"Failed to restore autocommit after setting search_path: {e}")


# Session factory bound to the primary engine; flushing and committing are
# left to the caller (autoflush/autocommit disabled).
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
# Lazily created data-center engine; stays None until _get_dc_engine() runs
# with DATA_CENTER_DATABASE_URL configured.
_dc_engine = None


def _get_dc_engine():
    """Return the cached data-center engine, creating it on first use.

    The engine is built from ``settings.DATA_CENTER_DATABASE_URL`` if that
    setting exists and is truthy; otherwise ``None`` is returned and the lookup
    is retried on the next call. Initialization is not guarded by a lock, so
    concurrent first calls could each create an engine.
    """
    global _dc_engine
    if _dc_engine is not None:
        return _dc_engine

    configured_url = getattr(settings, "DATA_CENTER_DATABASE_URL", None)
    if configured_url:
        _dc_engine = create_engine(
            configured_url,
            pool_size=5,
            max_overflow=5,
            pool_pre_ping=True,
        )
    return _dc_engine


@contextmanager
def get_data_center_db():
    """Context manager yielding a session on the data-center database.

    Yields ``None`` when no data-center engine is configured, so callers must
    handle that case. Otherwise the yielded session is closed on exit.
    """
    dc_engine = _get_dc_engine()
    if dc_engine:
        session_factory = sessionmaker(bind=dc_engine)
        session = session_factory()
        try:
            yield session
        finally:
            session.close()
    else:
        yield None


def get_db():
    """Yield a ``SessionLocal`` session scoped to the current tenant's schema.

    A generator meant to be driven as a dependency (presumably a framework
    ``Depends``-style dependency — confirm with callers). Before yielding, the
    session's search_path is set for the tenant returned by
    ``get_current_tenant_id()``, or for ``None`` (public) when there is no
    tenant. On exit the search_path is reset to ``None`` (failures are only
    logged) and the session is closed.

    NOTE(review): whether the reset survives ``close()`` depends on how
    ``TenantSchemaManager.set_search_path`` issues it (transactional SET vs.
    SET LOCAL / commit) — confirm it cannot leak a tenant path to the pool.
    """
    from app.core.tenant_context import get_current_tenant_id
    from app.db.tenant_schema import TenantSchemaManager

    session = SessionLocal()
    try:
        tenant_id = get_current_tenant_id()
        TenantSchemaManager.set_search_path(session, tenant_id if tenant_id else None)
        if not tenant_id:
            logger.debug("Database session created for public schema")
        else:
            logger.debug(f"Database session created for tenant_id={tenant_id}")
        yield session
    finally:
        try:
            TenantSchemaManager.set_search_path(session, None)
        except Exception as reset_error:
            logger.warning(f"Error resetting search_path: {reset_error}")
        session.close()


def get_db_without_tenant():
    """Yield a ``SessionLocal`` session whose search_path is the public schema.

    Ignores any tenant context. The session is closed when the generator
    finishes; unlike ``get_db``, no search_path reset is attempted on exit.
    """
    from app.db.tenant_schema import TenantSchemaManager

    session = SessionLocal()
    try:
        TenantSchemaManager.set_search_path(session, None)
        yield session
    finally:
        session.close()


@contextmanager
def get_db_with_transaction():
    """Provide a tenant-scoped session that commits on success and rolls back on error.

    The session's search_path is set for the tenant returned by
    ``get_current_tenant_id()`` (or ``None``/public when there is no tenant).
    When the ``with`` body finishes normally the session is committed. Any
    ``Exception`` raised during setup, inside the ``with`` body, or by the
    commit triggers a rollback and is re-raised as ``DatabaseError`` with the
    original exception chained as ``__cause__``. In every case the search_path
    is reset to ``None`` (failures only logged) and the session is closed.

    Yields:
        A ``SessionLocal`` session.

    Raises:
        DatabaseError: wrapping the exception that aborted the transaction; a
            ``DatabaseError`` raised inside the block is re-raised unchanged.
    """
    from app.core.exceptions import DatabaseError
    from app.core.tenant_context import get_current_tenant_id
    from app.db.tenant_schema import TenantSchemaManager

    db = SessionLocal()
    try:
        tenant_id = get_current_tenant_id()
        TenantSchemaManager.set_search_path(db, tenant_id if tenant_id else None)
        yield db
        db.commit()
    except Exception as e:
        try:
            db.rollback()
        except Exception as rollback_error:
            # A failed rollback (e.g. a dropped connection) must not mask the
            # error that actually aborted the transaction.
            logger.warning(f"Rollback failed: {rollback_error}")
        logger.error(f"Transaction failed, rolling back: {str(e)}")
        if isinstance(e, DatabaseError):
            # Already the domain error; re-wrapping would only nest the message.
            raise
        # Chain (rather than suppress) the cause so its traceback reaches logs.
        raise DatabaseError(f"Database transaction failed: {str(e)}") from e
    finally:
        try:
            TenantSchemaManager.set_search_path(db, None)
        except Exception as e:
            logger.warning(f"Error resetting search_path: {e}")
        db.close()
