from sqlalchemy import text
from sqlalchemy.orm import Session

from common_logging import get_logger

logger = get_logger(__name__)


class TenantSchemaManager:
    """Create, inspect and drop the per-tenant PostgreSQL schemas (``tenant_<id>``).

    ``TENANT_TABLES`` names the ORM tables that ``create_tenant_tables`` materialises
    inside every tenant schema. ``PUBLIC_TABLES`` is not referenced in this class; by
    its name it presumably lists the tables shared by all tenants in ``public`` —
    confirm against its callers.

    All methods operate on the caller's SQLAlchemy ``Session``.
    - Mutating methods (``create_schema``, ``delete_schema``, ``create_tenant_tables``,
      ``clone_schema_structure``) commit on success, roll back on failure, and return
      a boolean instead of raising.
    - Query methods log errors and return a fallback value.
    - ``set_search_path`` is the one method that logs and re-raises.
    """

    PUBLIC_TABLES = [
        "tenants",
        "users",
        "alembic_version",
        "roles",
        "menus",
        "user_roles",
        "role_menus",
        "role_permissions",
        "casbin_rule",
        "audit_logs",
        "permission_audit_logs",
        "model_providers",
        "models",
    ]
    TENANT_TABLES = [
        "agents",
        "chat_messages",
        "knowledge_bases",
        "knowledge_categories",
        "knowledge_documents",
        "knowledge_tags",
        "document_tags",
        "document_vectors",
        "document_versions",
        "knowledge_qa",
        "knowledge_metadata_fields",
        "document_metadata_values",
        "data_models",
        "data_model_fields",
        "tag_categories",
        "tag_auto_rules",
        "tax_documents",
    ]

    # Prefix shared by every tenant schema name; see get_schema_name().
    _SCHEMA_PREFIX = "tenant_"

    @staticmethod
    def get_schema_name(tenant_id: int) -> str:
        """Return the schema name used for *tenant_id*, e.g. ``tenant_42``."""
        return f"{TenantSchemaManager._SCHEMA_PREFIX}{tenant_id}"

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier.

        Embedded double quotes are doubled. A malformed ``tenant_id``, or an unusual
        table name read from the catalog, therefore cannot break out of the identifier
        when it is interpolated into DDL. An ordinary name simply becomes ``"name"``.
        """
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def schema_exists(db: Session, schema_name: str) -> bool:
        """Return True if *schema_name* is listed in ``information_schema.schemata``.

        Errors are logged and reported as False. No rollback is issued here, so after
        a failed query the caller's transaction may be left in an aborted state.
        """
        try:
            result = db.execute(
                text(
                    "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema"
                ),
                {"schema": schema_name},
            )
            return result.fetchone() is not None
        except Exception as e:
            logger.error(f"Error checking schema existence: {e}")
            return False

    @staticmethod
    def create_schema(db: Session, tenant_id: int) -> bool:
        """Create the schema for *tenant_id* if it does not exist yet, and commit.

        Returns True if the schema already existed or was created. Returns False
        (after logging and rolling back) on any error.
        """
        schema_name = TenantSchemaManager.get_schema_name(tenant_id)
        try:
            if TenantSchemaManager.schema_exists(db, schema_name):
                logger.info(f"Schema {schema_name} already exists")
                return True
            # IF NOT EXISTS also covers a concurrent creator winning the race after the check.
            quoted = TenantSchemaManager._quote_ident(schema_name)
            db.execute(text(f"CREATE SCHEMA IF NOT EXISTS {quoted}"))
            db.commit()
            logger.info(f"Created schema: {schema_name}")
            return True
        except Exception as e:
            logger.error(f"Error creating schema {schema_name}: {e}")
            db.rollback()
            return False

    @staticmethod
    def delete_schema(db: Session, tenant_id: int, cascade: bool = True) -> bool:
        """Drop the schema for *tenant_id* and commit.

        ``cascade=True`` issues ``DROP SCHEMA ... CASCADE``, which also drops every
        object inside the schema. ``cascade=False`` uses ``RESTRICT``, which fails if
        the schema is not empty. Returns True if the schema was dropped or did not
        exist. Returns False (after logging and rolling back) on any error.
        """
        schema_name = TenantSchemaManager.get_schema_name(tenant_id)
        try:
            if not TenantSchemaManager.schema_exists(db, schema_name):
                logger.warning(f"Schema {schema_name} does not exist")
                return True
            cascade_clause = "CASCADE" if cascade else "RESTRICT"
            quoted = TenantSchemaManager._quote_ident(schema_name)
            db.execute(text(f"DROP SCHEMA IF EXISTS {quoted} {cascade_clause}"))
            db.commit()
            logger.info(f"Deleted schema: {schema_name}")
            return True
        except Exception as e:
            logger.error(f"Error deleting schema {schema_name}: {e}")
            db.rollback()
            return False

    @staticmethod
    def set_search_path(db: Session, tenant_id: int | None = None):
        """Point the session's ``search_path`` at a tenant schema (then ``public``).

        With ``tenant_id=None`` the search path is reset to ``public`` only. Any error
        is logged and re-raised.

        NOTE(review): plain ``SET`` (not ``SET LOCAL``) is session-scoped in
        PostgreSQL. With a connection pool the setting may outlive the current
        request; confirm the pool resets ``search_path`` on check-in.
        """
        try:
            # Compare against None: tenant id 0 maps to a real "tenant_0" schema
            # elsewhere in this class, so it must not silently fall back to public.
            if tenant_id is not None:
                schema_name = TenantSchemaManager.get_schema_name(tenant_id)
                quoted = TenantSchemaManager._quote_ident(schema_name)
                db.execute(text(f"SET search_path TO {quoted}, public"))
                logger.debug(f"Set search_path to: {schema_name}, public")
            else:
                db.execute(text("SET search_path TO public"))
                logger.debug("Set search_path to: public")
        except Exception as e:
            logger.error(f"Error setting search_path: {e}")
            raise

    @staticmethod
    def get_tenant_table_definitions():
        """Return the schema-less ``Base.metadata`` tables whose names are in ``TENANT_TABLES``.

        A name in ``TENANT_TABLES`` with no matching schema-less table is skipped and
        reported at DEBUG level.
        """
        from app.db.base import Base

        wanted = set(TenantSchemaManager.TENANT_TABLES)
        tenant_tables = [
            table
            for table in Base.metadata.tables.values()
            if table.name in wanted and table.schema is None
        ]
        missing_table_names = sorted(wanted - {table.name for table in tenant_tables})
        if missing_table_names:
            logger.debug(
                "Skipped tenant table definitions not registered as schema-less tables: %s",
                ", ".join(missing_table_names),
            )
        return tenant_tables

    @staticmethod
    def create_tenant_tables(db: Session, tenant_id: int) -> bool:
        """Create the tenant tables inside the tenant's schema and commit.

        The schema-less table definitions are redirected to ``tenant_<id>`` via
        ``schema_translate_map``. ``checkfirst=True`` skips tables that already exist.
        Returns False if no definitions are found, or (after logging and rolling back)
        on any error.
        """
        from app.db.base import Base

        schema_name = TenantSchemaManager.get_schema_name(tenant_id)
        tenant_tables = TenantSchemaManager.get_tenant_table_definitions()
        if not tenant_tables:
            logger.error("No tenant table definitions found for schema provisioning")
            return False
        try:
            # NOTE(review): in SQLAlchemy 2.0, Connection.execution_options() mutates the
            # session's connection in place instead of returning a copy. The translate map
            # therefore stays active on that connection until the session releases it
            # (the commit/rollback below) — confirm no caller keeps the connection bound.
            tenant_bind = db.connection().execution_options(
                schema_translate_map={None: schema_name}
            )
            Base.metadata.create_all(bind=tenant_bind, tables=tenant_tables, checkfirst=True)
            db.commit()
            logger.info(
                "Ensured %s tenant tables exist in schema %s", len(tenant_tables), schema_name
            )
            return True
        except Exception as e:
            logger.error(f"Error creating tenant tables in schema {schema_name}: {e}")
            db.rollback()
            return False

    @staticmethod
    def clone_schema_structure(db: Session, source_tenant_id: int, target_tenant_id: int) -> bool:
        """Copy every base-table definition from the source tenant's schema to the target's.

        Each table is created with ``CREATE TABLE ... (LIKE ... INCLUDING ALL)``; rows
        are not copied. The target schema is created (and committed) by
        ``create_schema`` before any table is copied, so it survives a later failure.

        Returns False if the source schema does not exist, or (after logging and
        rolling back) if any statement fails.
        """
        source_schema = TenantSchemaManager.get_schema_name(source_tenant_id)
        target_schema = TenantSchemaManager.get_schema_name(target_tenant_id)
        quote = TenantSchemaManager._quote_ident
        try:
            # Without this check a missing source yields zero tables and a misleading success.
            if not TenantSchemaManager.schema_exists(db, source_schema):
                logger.error(f"Source schema {source_schema} does not exist; nothing to clone")
                return False
            if not TenantSchemaManager.create_schema(db, target_tenant_id):
                return False
            result = db.execute(
                text(
                    "SELECT table_name FROM information_schema.tables "
                    "WHERE table_schema = :schema AND table_type = 'BASE TABLE'"
                ),
                {"schema": source_schema},
            )
            # Materialise the names before issuing more statements on the same session.
            tables = [row[0] for row in result]
            for table in tables:
                # NOTE(review): LIKE does not copy foreign keys, and copied column defaults
                # (e.g. nextval() of a serial column) may still reference the source
                # schema's sequences — confirm that is acceptable for clones.
                db.execute(
                    text(
                        f"CREATE TABLE {quote(target_schema)}.{quote(table)} "
                        f"(LIKE {quote(source_schema)}.{quote(table)} INCLUDING ALL)"
                    )
                )
            db.commit()
            logger.info(f"Cloned schema structure from {source_schema} to {target_schema}")
            return True
        except Exception as e:
            logger.error(f"Error cloning schema structure: {e}")
            db.rollback()
            return False

    @staticmethod
    def list_tenant_schemas(db: Session) -> list[dict]:
        """Return ``{"tenant_id", "schema_name"}`` dicts for every tenant schema, ordered by name.

        Only schemas whose name is exactly what ``get_schema_name`` produces for an
        integer id are included. Others (e.g. ``tenant_archive``) are skipped rather
        than aborting the whole listing. Database errors are logged and yield ``[]``.
        """
        prefix = TenantSchemaManager._SCHEMA_PREFIX
        try:
            result = db.execute(
                text(
                    "SELECT schema_name FROM information_schema.schemata "
                    "WHERE schema_name LIKE :pattern ORDER BY schema_name"
                ),
                # "_" is a LIKE wildcard; escape it with backslash (PostgreSQL's default
                # LIKE escape character) so only the literal prefix matches.
                {"pattern": prefix.replace("_", "\\_") + "%"},
            )
            schemas = []
            for row in result:
                schema_name = row[0]
                try:
                    tenant_id = int(schema_name.removeprefix(prefix))
                except ValueError:
                    logger.debug(f"Skipping non-tenant schema: {schema_name}")
                    continue
                # Reject forms int() accepts but get_schema_name never emits (e.g. "1_000").
                if TenantSchemaManager.get_schema_name(tenant_id) != schema_name:
                    logger.debug(f"Skipping non-tenant schema: {schema_name}")
                    continue
                schemas.append({"tenant_id": tenant_id, "schema_name": schema_name})
            return schemas
        except Exception as e:
            logger.error(f"Error listing tenant schemas: {e}")
            return []

    @staticmethod
    def get_schema_size(db: Session, tenant_id: int) -> int | None:
        """Return the total on-disk size in bytes of the tenant's tables.

        Sums ``pg_total_relation_size`` over the schema's entries in ``pg_tables``.
        Returns 0 when the schema has no tables, and None (after logging) on error.
        """
        schema_name = TenantSchemaManager.get_schema_name(tenant_id)
        try:
            result = db.execute(
                text(
                    "SELECT SUM(pg_total_relation_size(quote_ident(schemaname) || '.' || quote_ident(tablename))) "
                    "FROM pg_tables WHERE schemaname = :schema"
                ),
                {"schema": schema_name},
            )
            size = result.scalar()
            # SUM over bigint yields numeric (Decimal in Python); honour the int return type.
            return int(size) if size else 0
        except Exception as e:
            logger.error(f"Error getting schema size: {e}")
            return None


def provision_tenant_schema(db: Session, tenant_id: int) -> bool:
    """Create a tenant's schema and then its tables.

    Runs ``TenantSchemaManager.create_schema`` followed by
    ``TenantSchemaManager.create_tenant_tables``, stopping at the first step that
    reports failure. Returns True only when both steps succeed. Any unexpected
    exception is logged, the session is rolled back, and False is returned.

    ``create_schema`` commits on its own, so a failure in the table step can leave
    an empty schema behind.
    """
    try:
        for provisioning_step in (
            TenantSchemaManager.create_schema,
            TenantSchemaManager.create_tenant_tables,
        ):
            if not provisioning_step(db, tenant_id):
                return False
        logger.info(f"Provisioned tenant schema for tenant_id={tenant_id}")
        return True
    except Exception as exc:
        logger.error(f"Error provisioning tenant schema: {exc}")
        db.rollback()
        return False


def deprovision_tenant_schema(db: Session, tenant_id: int) -> bool:
    """Drop a tenant's schema together with every object inside it.

    Thin wrapper around ``TenantSchemaManager.delete_schema`` with cascading enabled;
    the boolean it returns is passed straight through.
    """
    manager = TenantSchemaManager
    return manager.delete_schema(db=db, tenant_id=tenant_id, cascade=True)
