from pathlib import Path
from typing import Any

from sqlalchemy.orm import Session

from common_logging import get_logger

logger = get_logger(__name__)


class TenantValidationService:
    """Validate tenant export directories and compare per-tenant row counts.

    All queries run on the SQLAlchemy session passed to the constructor.
    Per-tenant data is looked up in a schema named ``tenant_<id>`` and in
    ``public.users`` rows filtered by ``tenant_id``.
    """

    # Entries that must exist directly under an export directory.
    REQUIRED_EXPORT_ITEMS = ("manifest.json", "postgresql", "neo4j", "milvus", "files")

    def __init__(self, db: Session):
        self.db = db

    def validate_export(self, tenant_id: int, export_dir: Path) -> dict[str, Any]:
        """Check that an export directory is complete and collect PostgreSQL stats.

        Args:
            tenant_id: Tenant whose PostgreSQL data is counted.
            export_dir: Root directory of the export to inspect.

        Returns:
            A dict with ``valid`` (False if any required entry is missing),
            ``errors`` (one ``"Missing: <item>"`` message per missing entry),
            ``warnings`` (not populated by this method) and ``postgresql``
            (the ``tables``/``users`` counts from ``_validate_postgresql``).
        """
        results: dict[str, Any] = {"valid": True, "errors": [], "warnings": []}
        for item in self.REQUIRED_EXPORT_ITEMS:
            if not (export_dir / item).exists():
                results["errors"].append(f"Missing: {item}")
                results["valid"] = False
        results["postgresql"] = self._validate_postgresql(tenant_id, export_dir)
        return results

    def _validate_postgresql(self, tenant_id: int, export_dir: Path) -> dict[str, Any]:
        """Count the tenant schema's tables and the tenant's users.

        ``export_dir`` is accepted for interface compatibility but is not
        read: both counts come from the database the session is connected
        to, not from the exported dump.
        """
        table_count = self._scalar_count(
            "SELECT COUNT(*) FROM information_schema.tables "
            "WHERE table_schema = :schema_name",
            schema_name=self._schema_name(tenant_id),
        )
        return {"tables": table_count, "users": self._count_users(tenant_id)}

    def compare_tenants(self, tenant_id_1: int, tenant_id_2: int) -> dict[str, Any]:
        """Collect stats for two tenants and report whether they are identical.

        Returns:
            ``{"tenant_1": stats, "tenant_2": stats, "match": bool}`` where
            ``match`` is True only if every counted value is equal.
        """
        stats_1 = self._get_tenant_stats(tenant_id_1)
        stats_2 = self._get_tenant_stats(tenant_id_2)
        return {"tenant_1": stats_1, "tenant_2": stats_2, "match": stats_1 == stats_2}

    def _get_tenant_stats(self, tenant_id: int) -> dict[str, int]:
        """Return user, knowledge-base and document counts for one tenant.

        Knowledge-base and document counts exclude rows with
        ``is_deleted = true``. They fall back to 0 when the tenant's tables
        cannot be queried (e.g. the schema does not exist).
        """
        return {
            "users": self._count_users(tenant_id),
            "knowledge_bases": self._count_active_rows(tenant_id, "knowledge_bases"),
            "documents": self._count_active_rows(tenant_id, "knowledge_documents"),
        }

    @staticmethod
    def _schema_name(tenant_id: int) -> str:
        """Return the per-tenant schema name, e.g. ``tenant_42``."""
        return f"tenant_{tenant_id}"

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Quote ``name`` as a SQL identifier.

        Identifiers cannot be passed as bound parameters. Wrapping the name
        in double quotes and doubling any embedded quote keeps a crafted
        tenant id from escaping the identifier. Lower-case names such as
        ``tenant_42`` resolve to the same object quoted or unquoted.
        """
        return '"' + name.replace('"', '""') + '"'

    def _scalar_count(self, sql: str, **params: Any) -> int:
        """Run a ``SELECT COUNT(*)`` query with bound ``params``; NULL maps to 0."""
        from sqlalchemy import text

        return self.db.execute(text(sql), params).scalar() or 0

    def _count_users(self, tenant_id: int) -> int:
        """Count rows in ``public.users`` belonging to ``tenant_id``."""
        return self._scalar_count(
            "SELECT COUNT(*) FROM public.users WHERE tenant_id = :tenant_id",
            tenant_id=tenant_id,
        )

    def _count_active_rows(self, tenant_id: int, table: str) -> int:
        """Best-effort count of non-deleted rows in ``tenant_<id>.<table>``.

        The query runs inside a SAVEPOINT. If it fails, only the savepoint
        is rolled back. Without it, a PostgreSQL error would abort the
        session's whole transaction and every later query (including the
        next tenant in ``compare_tenants``) would fail. Database errors are
        logged and reported as 0.
        """
        from sqlalchemy.exc import SQLAlchemyError

        qualified = (
            f"{self._quote_identifier(self._schema_name(tenant_id))}."
            f"{self._quote_identifier(table)}"
        )
        try:
            with self.db.begin_nested():
                return self._scalar_count(
                    f"SELECT COUNT(*) FROM {qualified} WHERE is_deleted = false"
                )
        except SQLAlchemyError as exc:
            logger.warning(f"Could not count {qualified} for tenant {tenant_id}: {exc}")
            return 0
