import json
import uuid
from typing import Any

from langchain_core.embeddings import Embeddings
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    Partition,
    connections,
    utility,
)
from sqlalchemy.orm import Session

from app.core.exceptions import VectorStoreConnectionError
from app.models import KnowledgeDocument
from app.services.llm.backends.embedding_backend_factory import get_embedding_factory

from common_logging import get_logger

logger = get_logger(__name__)


class CustomEmbeddings(Embeddings):
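    """LangChain ``Embeddings`` adapter that delegates to the app's embedding factory.

    Failed embeddings are replaced with zero vectors so the output list stays
    aligned with the input texts.
    """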

    def __init__(self, db: Session, model_id: int | None = None):
        self.db = db
        self.model_id = model_id
        self.embedding_factory = get_embedding_factory()

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
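        """Embed a batch of texts, returning one vector per input."""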
        embeddings = []
        for text in texts:
            embedding = self.embedding_factory.generate_embedding(
                text=text, db=self.db, model_id=self.model_id
            )
            if embedding:
                embeddings.append(embedding)
            else:
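                # Embedding failed: pad with a zero vector matching the batch's
                # dimension (1536 if no embedding has succeeded yet) to keep
                # outputs aligned with inputs.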
                dim = len(embeddings[0]) if embeddings else 1536
                embeddings.append([0.0] * dim)
        return embeddings

    def embed_query(self, text: str) -> list[float]:
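        """Embed a single query; falls back to a 1536-dim zero vector on failure."""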
        embedding = self.embedding_factory.generate_embedding(
            text=text, db=self.db, model_id=self.model_id
        )
        return embedding if embedding else [0.0] * 1536


class MilvusVectorStore:
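    """Milvus-backed vector store with optional per-knowledge-base partitions.

    Illustrative usage (connection details and model name depend on deployment):

        store = MilvusVectorStore(db, knowledge_base_id=1, tenant_id=42)
        ids = store.add_documents(
            document_id=7,
            chunks=[{"chunk_index": 0, "text": "..."}],
            embeddings=[[0.1] * store.dim],
            model_name="example-embedding-model",
        )
        results = store.similarity_search(query_embedding=[0.1] * store.dim, k=5)
    """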

    def __init__(
        self,
        db: Session,
        knowledge_base_id: int | None = None,
        collection_name: str = "document_vectors",
        host: str = "localhost",
        port: str = "19530",
        use_partition: bool = True,
        tenant_id: int | None = None,
    ):
        self.db = db
        self.knowledge_base_id = knowledge_base_id
        self.collection_name = collection_name
        self.use_partition = use_partition
        self.tenant_id = tenant_id
        # Set by _ensure_partition when partitioning is enabled; default to
        # None so add_documents and the search paths never hit an AttributeError.
        self.partition_name: str | None = None
        self.dim = self._get_embedding_dimension(db, knowledge_base_id)
        try:
            # list_connections() yields (alias, handler) tuples, so a bare
            # membership test on the alias string never matches; use
            # has_connection instead.
            if connections.has_connection("default"):
                logger.info("Using existing Milvus connection")
            else:
                connections.connect(alias="default", host=host, port=port)
                logger.info(f"Successfully connected to Milvus: {host}:{port}")
        except Exception as e:
            logger.error(f"Failed to connect to Milvus: {e}")
            raise VectorStoreConnectionError() from e
        self._ensure_collection()
        if self.use_partition and self.knowledge_base_id:
            self._ensure_partition()

    def _get_embedding_dimension(self, db: Session, knowledge_base_id: int | None) -> int:
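        """Look up the vector dimension from the KB's model record; default to 1536."""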
        try:
            if not knowledge_base_id:
                logger.info("No knowledge_base_id provided, using default dimension: 1536")
                return 1536
            from app.models import KnowledgeBase
            from app.models.provider import Model

            kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == knowledge_base_id).first()
            if not kb or not kb.code:
                logger.warning(
                    f"Knowledge base {knowledge_base_id} not found or has no model code, using default dimension: 1536"
                )
                return 1536
            if kb.code.isdigit():
                model = db.query(Model).filter(Model.id == int(kb.code)).first()
            else:
                model = db.query(Model).filter(Model.code == kb.code).first()
            if model and model.dimension:
                logger.info(
                    f"Using model dimension from database: {model.dimension} (model: {model.name})"
                )
                return model.dimension
            else:
                logger.warning(
                    f"Model dimension not found for code {kb.code}, using default dimension: 1536"
                )
                return 1536
        except Exception as e:
            logger.error(f"Failed to get embedding dimension: {e}, using default: 1536")
            return 1536

    def _ensure_collection(self):
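        """Create the collection and its IVF_FLAT/COSINE index if missing, then load it."""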
        try:
            if utility.has_collection(self.collection_name):
                self.collection = Collection(self.collection_name)
                logger.info(f"Using existing collection: {self.collection_name}")
            else:
                fields = [
                    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                    FieldSchema(name="document_id", dtype=DataType.INT64),
                    FieldSchema(name="chunk_index", dtype=DataType.INT64),
                    FieldSchema(name="chunk_text", dtype=DataType.VARCHAR, max_length=65535),
                    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.dim),
                    FieldSchema(name="model_name", dtype=DataType.VARCHAR, max_length=100),
                    FieldSchema(name="knowledge_base_id", dtype=DataType.INT64),
                    FieldSchema(name="tenant_id", dtype=DataType.INT64),
                    FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=100),
                    FieldSchema(name="parent_chunk_id", dtype=DataType.VARCHAR, max_length=100),
                    FieldSchema(name="is_parent", dtype=DataType.BOOL),
                    FieldSchema(name="chunk_level", dtype=DataType.VARCHAR, max_length=20),
                    FieldSchema(name="references", dtype=DataType.VARCHAR, max_length=2000),
                    FieldSchema(name="doc_status", dtype=DataType.VARCHAR, max_length=20),
                    FieldSchema(name="issue_date_int", dtype=DataType.INT64),
                ]
                schema = CollectionSchema(
                    fields=fields,
                    description="Document vectors for knowledge base with partition support",
                )
                self.collection = Collection(name=self.collection_name, schema=schema)
                index_params = {
                    "metric_type": "COSINE",
                    "index_type": "IVF_FLAT",
                    "params": {"nlist": 1024},
                }
                self.collection.create_index(field_name="vector", index_params=index_params)
                logger.info(f"Created new collection: {self.collection_name}")
            self.collection.load()
        except Exception as e:
            logger.error(f"Failed to ensure collection exists: {e}")
            raise

    def _ensure_partition(self):
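        """Create or reuse the ``kb_<id>`` partition; on failure, disable partitioning."""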
        try:
            partition_name = f"kb_{self.knowledge_base_id}"
            if not self.collection.has_partition(partition_name):
                self.collection.create_partition(partition_name)
                logger.info(f"Created partition: {partition_name}")
            else:
                logger.info(f"Using existing partition: {partition_name}")
            self.partition_name = partition_name
        except Exception as e:
            logger.error(f"Failed to ensure partition exists: {e}")
            self.use_partition = False
            self.partition_name = None

    def drop_partition(self) -> bool:
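        """Drop this knowledge base's partition. Returns True only if one was dropped."""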
        if not self.knowledge_base_id:
            logger.warning("drop_partition called without knowledge_base_id, skipping")
            return False
        partition_name = f"kb_{self.knowledge_base_id}"
        try:
            if self.collection.has_partition(partition_name):
                # A loaded collection cannot drop a partition: release first,
                # then reload so later searches on the collection still work.
                self.collection.release()
                self.collection.drop_partition(partition_name)
                self.collection.load()
                logger.info(f"Dropped Milvus partition: {partition_name}")
                return True
            else:
                logger.info(f"Milvus partition {partition_name} does not exist, skipping drop")
                return False
        except Exception as e:
            logger.error(f"Failed to drop Milvus partition {partition_name}: {e}")
            raise

    def add_documents(
        self,
        document_id: int,
        chunks: list[dict[str, Any]],
        embeddings: list[list[float]],
        model_name: str,
        doc_status_list: list[str] | None = None,
        issue_date_int_list: list[int] | None = None,
    ) -> list[int]:
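        """Insert chunk vectors for a document and return the new primary keys.

        Each chunk dict must carry ``chunk_index`` and ``text``; ``chunk_id``,
        ``parent_chunk_id``, ``is_parent``, ``chunk_level``, and ``references``
        are optional and defaulted. Entity columns are positional and must stay
        in the collection's schema order.
        """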
        try:
            chunk_ids = []
            parent_chunk_ids = []
            is_parents = []
            chunk_levels = []
            references_list = []
            for chunk in chunks:
                chunk_id = chunk.get("chunk_id", str(uuid.uuid4()))
                parent_chunk_id = chunk.get("parent_chunk_id")
                if parent_chunk_id is None:
                    parent_chunk_id = ""
                is_parent = chunk.get("is_parent", False)
                chunk_level = chunk.get("chunk_level", "leaf")
                references = chunk.get("references", [])
                if isinstance(references, list):
                    references_str = json.dumps(references) if references else ""
                else:
                    references_str = str(references) if references else ""
                chunk_ids.append(chunk_id)
                parent_chunk_ids.append(parent_chunk_id)
                is_parents.append(is_parent)
                chunk_levels.append(chunk_level)
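                # Truncate on UTF-8 bytes so the value fits the 2000-length
                # VARCHAR field even when multi-byte characters are present.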
                references_list.append(
                    references_str.encode("utf-8")[:2000].decode("utf-8", errors="ignore")
                )
            existing_fields = {f.name for f in self.collection.schema.fields}
            entities = [
                [document_id] * len(chunks),
                [chunk["chunk_index"] for chunk in chunks],
                [
                    chunk["text"].encode("utf-8")[:65535].decode("utf-8", errors="ignore")
                    for chunk in chunks
                ],
                embeddings,
                [model_name] * len(chunks),
                [self.knowledge_base_id or 0] * len(chunks),
                [self.tenant_id or 0] * len(chunks),
                chunk_ids,
                parent_chunk_ids,
                is_parents,
                chunk_levels,
                references_list,
            ]
            if "doc_status" in existing_fields:
                entities.append(
                    doc_status_list if doc_status_list is not None else ["effective"] * len(chunks)
                )
            if "issue_date_int" in existing_fields:
                entities.append(
                    issue_date_int_list if issue_date_int_list is not None else [0] * len(chunks)
                )
            if self.use_partition and self.partition_name:
                insert_result = self.collection.insert(entities, partition_name=self.partition_name)
                logger.info(
                    f"Successfully added {len(chunks)} vectors to partition {self.partition_name}"
                )
            else:
                insert_result = self.collection.insert(entities)
                logger.info(f"Successfully added {len(chunks)} vectors to document {document_id}")
            self.collection.flush()
            return insert_result.primary_keys
        except Exception as e:
            logger.error(f"Failed to add document vectors: {e}")
            raise

    def similarity_search(
        self,
        query_embedding: list[float],
        k: int = 5,
        threshold: float = 0.7,
        filter_dict: dict | None = None,
    ) -> list[tuple[dict, float]]:
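        """Run a COSINE vector search scoped to the current tenant/knowledge base.

        ``filter_dict`` entries become equality filters in the Milvus expr,
        except ``issue_date_gte``/``issue_date_lte``, which become range
        filters on ``issue_date_int``. Hits are re-checked against the
        relational DB (published, vectorized, not obsolete) and returned as
        ``(chunk_dict, score)`` pairs.
        """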
        try:
            conditions = []
            if self.tenant_id:
                conditions.append(f"tenant_id == {self.tenant_id}")
            if self.knowledge_base_id:
                conditions.append(f"knowledge_base_id == {self.knowledge_base_id}")
            if filter_dict:
                for key, value in filter_dict.items():
                    if key in ("issue_date_gte", "issue_date_lte"):
                        continue
                    if isinstance(value, bool):
                        conditions.append(f"{key} == {str(value).lower()}")
                    elif isinstance(value, str):
                        # Escape embedded quotes so the Milvus expr stays well-formed.
                        escaped = value.replace('"', '\\"')
                        conditions.append(f'{key} == "{escaped}"')
                    else:
                        conditions.append(f"{key} == {value}")
            try:
                field_names = [f.name for f in self.collection.schema.fields]
                if "doc_status" in field_names:
                    conditions.append('doc_status not in ["obsolete", "expired"]')
            except Exception:
                pass
            if filter_dict:
                if filter_dict.get("issue_date_gte"):
                    conditions.append(f"issue_date_int >= {filter_dict['issue_date_gte']}")
                if filter_dict.get("issue_date_lte"):
                    conditions.append(f"issue_date_int <= {filter_dict['issue_date_lte']}")
            expr = " && ".join(conditions) if conditions else ""
            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
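            # Over-fetch for small k: the threshold and per-document DB checks
            # below discard hits, so request extra candidates to still fill k.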
            search_limit = k if k > 10 else k * 2
            output_fields = [
                "document_id",
                "chunk_index",
                "chunk_text",
                "model_name",
                "chunk_id",
                "parent_chunk_id",
                "is_parent",
                "chunk_level",
                "references",
            ]
            search_kwargs = {
                "data": [query_embedding],
                "anns_field": "vector",
                "param": search_params,
                "limit": search_limit,
                "expr": expr if expr else None,
                "output_fields": output_fields,
            }
            if self.use_partition and self.partition_name:
                search_kwargs["partition_names"] = [self.partition_name]
                logger.debug(f"Searching in partition {self.partition_name}")
            results = self.collection.search(**search_kwargs)
            formatted_results = []
            for hits in results:
                for hit in hits:
                    score = float(hit.score)
                    if score >= threshold:
                        doc = (
                            self.db.query(KnowledgeDocument)
                            .filter(
                                KnowledgeDocument.id == hit.entity.get("document_id"),
                                KnowledgeDocument.status == "published",
                                KnowledgeDocument.is_vectorized,
                                KnowledgeDocument.doc_status != "obsolete",
                            )
                            .first()
                        )
                        if doc is None:
                            continue
                        references_str = hit.entity.get("references") or ""
                        try:
                            references = json.loads(references_str) if references_str else []
                        except (json.JSONDecodeError, TypeError):
                            references = []
                        doc_dict = {
                            "id": hit.id,
                            "document_id": hit.entity.get("document_id"),
                            "chunk_index": hit.entity.get("chunk_index"),
                            "text": hit.entity.get("chunk_text"),
                            "model_name": hit.entity.get("model_name"),
                            "title": doc.title,
                            "category_id": doc.category_id,
                            "reference_url": doc.reference_url,
                            "chunk_id": hit.entity.get("chunk_id") or "",
                            "parent_chunk_id": hit.entity.get("parent_chunk_id") or "",
                            "is_parent": hit.entity.get("is_parent") or False,
                            "chunk_level": hit.entity.get("chunk_level") or "leaf",
                            "references": references,
                        }
                        formatted_results.append((doc_dict, score))
                    if len(formatted_results) >= k:
                        break
                if len(formatted_results) >= k:
                    break
            logger.info(f"Similarity search completed, found {len(formatted_results)} results")
            return formatted_results[:k]
        except Exception as e:
            logger.error(f"Similarity search failed: {e}")
            raise

    def hybrid_search(
        self,
        query: str,
        query_embedding: list[float],
        k: int = 5,
        threshold: float = 0.7,
        keyword_weight: float = 0.3,
        semantic_weight: float = 0.7,
    ) -> list[tuple[dict, float]]:
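        """Combine semantic and Postgres full-text keyword scores.

        Semantic candidates are fetched with a relaxed threshold (0.8x);
        keyword ranks are normalized by the best rank, and the final score is
        ``semantic_weight * semantic + keyword_weight * keyword_norm``.
        """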
        try:
            search_k = k if k > 10 else k * 2
            semantic_results = self.similarity_search(
                query_embedding=query_embedding, k=search_k, threshold=threshold * 0.8
            )
            from sqlalchemy import text as sql_text

            where_clauses = [
                "dv.id IN (SELECT id FROM document_vectors WHERE to_tsvector('simple', COALESCE(chunk_text, '')) @@ plainto_tsquery('simple', :query))"
            ]
            params = {"query": query, "k": search_k}
            if self.tenant_id:
                where_clauses.append("kd.tenant_id = :tenant_id")
                params["tenant_id"] = self.tenant_id
            if self.knowledge_base_id is not None:
                where_clauses.append("kc.knowledge_base_id = :knowledge_base_id")
                params["knowledge_base_id"] = self.knowledge_base_id
            where_clauses.append("kd.status = 'published'")
            where_clauses.append("kd.is_vectorized = true")
            where_clauses.append("(kd.doc_status IS NULL OR kd.doc_status != 'obsolete')")
            keyword_query = sql_text(
                f"""
                SELECT
                    dv.id,
                    dv.document_id,
                    dv.chunk_index,
                    dv.chunk_text,
                    dv.model_name,
                    kd.title,
                    kd.category_id,
                    kd.reference_url,
                    ts_rank(
                        setweight(to_tsvector('simple', COALESCE(kd.title, '')), 'A') ||
                        setweight(to_tsvector('simple', COALESCE(dv.chunk_text, '')), 'B'),
                        plainto_tsquery('simple', :query)
                    ) as rank
                FROM document_vectors dv
                JOIN knowledge_documents kd ON dv.document_id = kd.id
                LEFT JOIN knowledge_categories kc ON kd.category_id = kc.id
                WHERE {' AND '.join(where_clauses)}
                ORDER BY rank DESC
                LIMIT :k
            """
            )
            keyword_rows = list(self.db.execute(keyword_query, params))
            max_keyword_rank = max((float(row.rank) for row in keyword_rows), default=0.0)
            combined_results = {}
            for doc, score in semantic_results:
                result_key = (doc["document_id"], doc["chunk_index"])
                combined_results[result_key] = {
                    "doc": doc,
                    "semantic_score": score,
                    "keyword_score": 0.0,
                }
            for row in keyword_rows:
                result_key = (row.document_id, row.chunk_index)
                keyword_score = float(row.rank)
                if result_key not in combined_results:
                    combined_results[result_key] = {
                        "doc": {
                            "id": row.id,
                            "document_id": row.document_id,
                            "chunk_index": row.chunk_index,
                            "text": row.chunk_text,
                            "model_name": row.model_name,
                            "title": row.title,
                            "category_id": row.category_id,
                            "reference_url": row.reference_url,
                        },
                        "semantic_score": 0.0,
                        "keyword_score": keyword_score,
                    }
                else:
                    combined_results[result_key]["keyword_score"] = keyword_score
            final_results = []
            for data in combined_results.values():
                semantic_norm = data["semantic_score"]
                keyword_norm = (
                    data["keyword_score"] / max_keyword_rank
                    if max_keyword_rank > 0 and data["keyword_score"] > 0
                    else 0.0
                )
                combined_score = semantic_weight * semantic_norm + keyword_weight * keyword_norm
                final_results.append((data["doc"], combined_score))
            final_results.sort(key=lambda x: x[1], reverse=True)
            logger.info(f"Hybrid search completed, found {len(final_results[:k])} results")
            return final_results[:k]
        except Exception as e:
            logger.error(f"Hybrid search failed: {e}")
            raise

    def delete_document_vectors(self, document_id: int) -> int:
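        """Delete all vectors for a document; returns the reported delete count."""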
        try:
            expr = f"document_id == {document_id}"
            delete_result = self.collection.delete(expr)
            self.collection.flush()
            # MutationResult exposes delete_count in recent pymilvus releases;
            # fall back to 0 if the attribute is absent.
            deleted = getattr(delete_result, "delete_count", 0)
            logger.info(f"Deleted {deleted} vectors for document {document_id}")
            return deleted
        except Exception as e:
            logger.error(f"Failed to delete document vectors: {e}")
            raise

    def get_collection_stats(self) -> dict[str, Any]:
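        """Return entity counts for the collection and, when enabled, its partition."""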
        try:
            stats = {
                "total_vectors": self.collection.num_entities,
                "collection_name": self.collection_name,
                "use_partition": self.use_partition,
            }
            if self.use_partition and self.partition_name:
                partition = Partition(self.collection, self.partition_name)
                stats["partition_name"] = self.partition_name
                stats["partition_vectors"] = partition.num_entities
            return stats
        except Exception as e:
            logger.error(f"Failed to get statistics: {e}")
            return {
                "total_vectors": 0,
                "collection_name": self.collection_name,
                "use_partition": self.use_partition,
            }

    def list_partitions(self) -> list[str]:
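        """Return the names of all partitions in the collection."""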
        try:
            # collection.partitions yields Partition objects; return their names
            # to match the declared list[str] signature.
            return [p.name for p in self.collection.partitions]
        except Exception as e:
            logger.error(f"Failed to list partitions: {e}")
            return []

    def compact_partition(self):
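        """Trigger a collection-level compaction when partitioning is active."""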
        try:
            if self.use_partition and self.partition_name:
                # pymilvus compaction is collection-wide; there is no
                # per-partition compact, so this compacts the whole collection.
                self.collection.compact()
                logger.info(
                    f"Compaction triggered for collection {self.collection_name} "
                    f"(partition {self.partition_name})"
                )
        except Exception as e:
            logger.error(f"Failed to compact partition: {e}")


def get_vector_store(
    db: Session,
    knowledge_base_id: int | None = None,
    host: str = "localhost",
    port: str = "19530",
    use_partition: bool = True,
    tenant_id: int | None = None,
) -> MilvusVectorStore:
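    """Build a MilvusVectorStore, resolving tenant_id from context when omitted.

    Illustrative usage (host/port depend on the deployment):

        store = get_vector_store(db, knowledge_base_id=1)
        hits = store.similarity_search(query_embedding, k=5)
    """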
    if tenant_id is None:
        from app.core.tenant_context import get_current_tenant_id

        tenant_id = get_current_tenant_id()
    return MilvusVectorStore(
        db,
        knowledge_base_id,
        host=host,
        port=port,
        use_partition=use_partition,
        tenant_id=tenant_id,
    )
