import json
from pathlib import Path
from typing import Any

from pymilvus import Collection, connections
from sqlalchemy.orm import Session

from app.config import settings

from common_logging import get_logger

# Module-level logger, named after this module, used by VectorExportService.
logger = get_logger(__name__)




class VectorExportService:
    """Export a tenant's document vectors from Milvus to a JSON file.

    The Milvus connection is attempted once, at construction time. If it
    fails, later exports are skipped rather than raised (best-effort).
    """

    def __init__(self, db: Session):
        """Store the DB session and try to connect to Milvus.

        Args:
            db: SQLAlchemy session used to look up the tenant's knowledge bases.
        """
        self.db = db
        # True when the "default" Milvus connection was established.
        self.milvus_available = self._connect_milvus()

    def _connect_milvus(self) -> bool:
        """Open the "default" Milvus connection; return False on any failure."""
        try:
            connections.connect(
                alias="default", host=settings.MILVUS_HOST, port=settings.MILVUS_PORT
            )
            return True
        except Exception as e:
            # Deliberately best-effort: log and let exports report "skipped".
            logger.warning(f"Milvus not available: {e}")
            return False

    def export_tenant_vectors(self, tenant_id: int, output_dir: Path) -> dict[str, Any]:
        """Dump every vector belonging to the tenant's knowledge bases.

        Entities are read from the ``document_vectors`` collection, one query
        per knowledge base (filtered on ``knowledge_base_id``), and written to
        ``output_dir / "vectors.json"``.

        Args:
            tenant_id: Tenant id; selects the ``tenant_<id>`` database schema.
            output_dir: Directory for the output file; created if missing.

        Returns:
            A dict that always carries a ``status`` key:
            - ``"skipped"`` plus ``reason`` when Milvus is unavailable;
            - ``"completed"`` plus ``stats``. ``vectors_file`` is also included
              when the tenant has at least one knowledge base. If there are none,
              no file is written.

        Raises:
            ValueError: If ``tenant_id`` does not render as plain ASCII digits.
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        if not self.milvus_available:
            logger.warning("Milvus not available, skipping vector export")
            return {"status": "skipped", "reason": "Milvus not available"}

        kb_ids = self._get_tenant_knowledge_bases(tenant_id)
        if not kb_ids:
            return {
                "status": "completed",
                "stats": {"knowledge_bases": 0, "vectors": 0},
            }

        collection = Collection("document_vectors")
        collection.load()
        vectors_data: list[Any] = []
        for kb_id in kb_ids:
            # kb_id comes from _get_tenant_knowledge_bases (typed list[int]).
            # NOTE(review): Milvus caps the number of entities a single query
            # may return; very large knowledge bases may need query_iterator.
            # Confirm against the deployed Milvus version.
            results = collection.query(
                expr=f"knowledge_base_id == {kb_id}", output_fields=["*"]
            )
            vectors_data.extend(results)

        vectors_file = output_dir / "vectors.json"
        # NOTE(review): assumes every returned field is JSON-serializable.
        # Binary/float16 vector fields (returned as bytes) would fail here.
        # Confirm the collection schema.
        with open(vectors_file, "w", encoding="utf-8") as f:
            json.dump(vectors_data, f, indent=2, ensure_ascii=False)
        return {
            "status": "completed",
            "vectors_file": str(vectors_file),
            "stats": {"knowledge_bases": len(kb_ids), "vectors": len(vectors_data)},
        }

    def _get_tenant_knowledge_bases(self, tenant_id: int) -> list[int]:
        """Return ids of the tenant's knowledge bases where ``is_deleted`` is false.

        Args:
            tenant_id: Tenant id; selects the ``tenant_<id>`` database schema.

        Raises:
            ValueError: If ``tenant_id`` does not render as plain ASCII digits.
        """
        from sqlalchemy import text

        # A schema name is an identifier and cannot be a bound parameter, so it
        # must be interpolated. Allow only ASCII digits so an unvalidated
        # tenant_id cannot inject SQL.
        tenant_str = str(tenant_id)
        if not (tenant_str.isascii() and tenant_str.isdigit()):
            raise ValueError(f"Invalid tenant_id: {tenant_id!r}")
        schema_name = f"tenant_{tenant_str}"
        query = text(f"SELECT id FROM {schema_name}.knowledge_bases WHERE is_deleted = false")
        result = self.db.execute(query)
        return [row[0] for row in result]
