from typing import Any

from langchain_core.embeddings import Embeddings
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
from sqlalchemy.orm import Session

from app.models import KnowledgeDocument
from app.services.llm.backends.embedding_backend_factory import get_embedding_factory

from common_logging import get_logger

logger = get_logger(__name__)




class CustomEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by the project's embedding factory.

    Every input text is embedded through the factory; when the backend
    returns a falsy result, a zero vector of ``_FALLBACK_DIM`` floats is
    substituted so callers always receive one vector per input.
    """

    # Dimension of the zero-vector fallback emitted when the backend
    # returns nothing. NOTE(review): 1024 may not match the collection's
    # actual dimension (the store defaults to 2560) — confirm upstream.
    _FALLBACK_DIM = 1024

    def __init__(self, db: Session, model_id: int | None = None):
        """Bind the adapter to a DB session and an optional embedding model id."""
        self.db = db
        self.model_id = model_id
        self.embedding_factory = get_embedding_factory()

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed each text; delegates to :meth:`embed_query` per item
        so the fallback logic lives in exactly one place."""
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        """Embed a single text, falling back to a zero vector on failure."""
        embedding = self.embedding_factory.generate_embedding(
            text=text, db=self.db, model_id=self.model_id
        )
        return embedding if embedding else [0.0] * self._FALLBACK_DIM


class MilvusBasicStore:
    """Vector store for document chunks backed by a Milvus collection.

    Handles connection setup (Milvus Lite or a remote server), collection
    and index creation, chunk insertion, semantic / hybrid search, and
    per-document vector deletion.
    """

    def __init__(
        self,
        db: Session,
        knowledge_base_id: int | None = None,
        collection_name: str = "document_vectors",
        host: str = "localhost",
        port: str = "19530",
        use_lite: bool = True,
    ):
        """Connect to Milvus and ensure the target collection exists and is loaded.

        Args:
            db: SQLAlchemy session used for metadata lookups.
            knowledge_base_id: Optional KB id; scopes searches and resolves
                the embedding dimension.
            collection_name: Milvus collection to use or create.
            host/port: Remote Milvus endpoint (ignored when ``use_lite``).
            use_lite: When True, use Milvus Lite with local file storage.

        Raises:
            Exception: re-raised if connecting or collection setup fails.
        """
        self.db = db
        self.knowledge_base_id = knowledge_base_id
        self.collection_name = collection_name
        self.use_lite = use_lite
        # Vector dimension comes from the KB's configured model (default 2560).
        self.dim = self._get_embedding_dimension(db, knowledge_base_id)
        try:
            # FIX: list_connections() yields (alias, handler) tuples, so the
            # old `"default" in connections.list_connections()` check could
            # never match; has_connection() is the public API for this.
            if connections.has_connection("default"):
                logger.info("使用现有的 Milvus 连接")
            elif use_lite:
                connections.connect(alias="default", uri="./data/milvus_data.db")
                logger.info("成功连接到 Milvus Lite (本地存储)")
            else:
                connections.connect(alias="default", host=host, port=port)
                logger.info(f"成功连接到Milvus: {host}:{port}")
        except Exception as e:
            logger.error(f"连接Milvus失败: {e}")
            raise
        self._ensure_collection()

    def _get_embedding_dimension(self, db: Session, knowledge_base_id: int | None) -> int:
        """Resolve the vector dimension for this knowledge base.

        Looks up the KB's model — by primary key when ``kb.code`` is numeric,
        otherwise by model code — and returns its configured dimension.
        Returns the default 2560 when the lookup fails or yields nothing.
        """
        # NOTE: this method was previously defined twice (identical bodies);
        # the duplicate has been removed.
        try:
            if knowledge_base_id:
                from app.models import KnowledgeBase, Model

                kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == knowledge_base_id).first()
                if kb and kb.code:
                    if kb.code.isdigit():
                        model = db.query(Model).filter(Model.id == int(kb.code)).first()
                    else:
                        model = db.query(Model).filter(Model.code == kb.code).first()
                    if model and model.dimension:
                        logger.info(
                            f"Using dimension {model.dimension} for knowledge base {knowledge_base_id} (model: {model.name})"
                        )
                        return model.dimension
            logger.info(f"Using default dimension 2560 for knowledge base {knowledge_base_id}")
            return 2560
        except Exception as e:
            logger.error(f"Failed to get embedding dimension: {e}")
            return 2560

    def _ensure_collection(self):
        """Open the collection if it exists; otherwise create it with the
        chunk schema plus an IVF_FLAT/COSINE index, then load it into memory.

        Raises:
            Exception: re-raised after logging when Milvus operations fail.
        """
        try:
            if utility.has_collection(self.collection_name):
                self.collection = Collection(self.collection_name)
                logger.info(f"使用现有集合: {self.collection_name}")
            else:
                fields = [
                    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                    FieldSchema(name="document_id", dtype=DataType.INT64),
                    FieldSchema(name="chunk_index", dtype=DataType.INT64),
                    FieldSchema(name="chunk_text", dtype=DataType.VARCHAR, max_length=65535),
                    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.dim),
                    FieldSchema(name="model_name", dtype=DataType.VARCHAR, max_length=100),
                    FieldSchema(name="knowledge_base_id", dtype=DataType.INT64),
                ]
                schema = CollectionSchema(
                    fields=fields, description="Document vectors for knowledge base"
                )
                self.collection = Collection(name=self.collection_name, schema=schema)
                index_params = {
                    "metric_type": "COSINE",
                    "index_type": "IVF_FLAT",
                    "params": {"nlist": 1024},
                }
                self.collection.create_index(field_name="vector", index_params=index_params)
                logger.info(f"创建新集合: {self.collection_name}")
            self.collection.load()
        except Exception as e:
            logger.error(f"确保集合存在失败: {e}")
            raise

    def add_documents(
        self,
        document_id: int,
        chunks: list[dict[str, Any]],
        embeddings: list[list[float]],
        model_name: str,
    ) -> list[int]:
        """Insert one vector per chunk for ``document_id`` and flush.

        Args:
            document_id: Owning document's id, repeated for every chunk.
            chunks: Dicts with at least ``chunk_index`` and ``text`` keys.
            embeddings: One vector per chunk, in the same order.
            model_name: Name of the model that produced the embeddings.

        Returns:
            The auto-generated Milvus primary keys of the inserted rows.

        Raises:
            Exception: re-raised after logging when the insert fails.
        """
        try:
            # Column-ordered entities matching the schema (id is auto_id).
            entities = [
                [document_id] * len(chunks),
                [chunk["chunk_index"] for chunk in chunks],
                [chunk["text"] for chunk in chunks],
                embeddings,
                [model_name] * len(chunks),
                [self.knowledge_base_id or 0] * len(chunks),
            ]
            insert_result = self.collection.insert(entities)
            self.collection.flush()
            logger.info(f"成功添加 {len(chunks)} 个向量到文档 {document_id}")
            return insert_result.primary_keys
        except Exception as e:
            logger.error(f"添加文档向量失败: {e}")
            raise

    def similarity_search(
        self,
        query_embedding: list[float],
        k: int = 5,
        threshold: float = 0.7,
        filter_dict: dict | None = None,
    ) -> list[tuple[dict, float]]:
        """Search the collection for the top-``k`` chunks above ``threshold``.

        Args:
            query_embedding: Query vector (must match the collection dim).
            k: Maximum number of results to return.
            threshold: Minimum COSINE score for a hit to be kept.
            filter_dict: Optional extra equality filters ANDed into the expr.

        Returns:
            ``(chunk_dict, score)`` tuples; each chunk dict is enriched with
            the owning document's title and category from the SQL database.

        Raises:
            Exception: re-raised after logging when the search fails.
        """
        try:
            # Build the boolean filter expression from equality clauses.
            # NOTE(review): values are interpolated directly into the expr;
            # only pass trusted keys/values here.
            clauses = []
            if self.knowledge_base_id:
                clauses.append(f"knowledge_base_id == {self.knowledge_base_id}")
            if filter_dict:
                for key, value in filter_dict.items():
                    if isinstance(value, str):
                        clauses.append(f'{key} == "{value}"')
                    else:
                        clauses.append(f"{key} == {value}")
            expr = " && ".join(clauses)
            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
            # Over-fetch (k * 2) so threshold filtering can still yield k hits.
            results = self.collection.search(
                data=[query_embedding],
                anns_field="vector",
                param=search_params,
                limit=k * 2,
                expr=expr if expr else None,
                output_fields=["document_id", "chunk_index", "chunk_text", "model_name"],
            )
            formatted_results = []
            for hits in results:
                for hit in hits:
                    score = float(hit.score)
                    if score >= threshold:
                        doc = (
                            self.db.query(KnowledgeDocument)
                            .filter(KnowledgeDocument.id == hit.entity.get("document_id"))
                            .first()
                        )
                        doc_dict = {
                            "id": hit.id,
                            "document_id": hit.entity.get("document_id"),
                            "chunk_index": hit.entity.get("chunk_index"),
                            "text": hit.entity.get("chunk_text"),
                            "model_name": hit.entity.get("model_name"),
                            "title": doc.title if doc else "Unknown",
                            "category_id": doc.category_id if doc else None,
                        }
                        formatted_results.append((doc_dict, score))
                    if len(formatted_results) >= k:
                        break
                if len(formatted_results) >= k:
                    break
            logger.info(f"相似度搜索完成，找到 {len(formatted_results)} 个结果")
            return formatted_results[:k]
        except Exception as e:
            logger.error(f"相似度搜索失败: {e}")
            raise

    def hybrid_search(
        self,
        query: str,
        query_embedding: list[float],
        k: int = 5,
        threshold: float = 0.7,
        keyword_weight: float = 0.3,
        semantic_weight: float = 0.7,
    ) -> list[tuple[dict, float]]:
        """Combine semantic similarity with PostgreSQL full-text ranking.

        Semantic candidates (fetched with a relaxed ``threshold * 0.8``) are
        re-scored as ``semantic_weight * score + keyword_weight * rank``;
        keyword-only matches that were not semantic hits are discarded and
        can only boost existing candidates.

        Raises:
            Exception: re-raised after logging when either search fails.
        """
        try:
            semantic_results = self.similarity_search(
                query_embedding=query_embedding, k=k * 2, threshold=threshold * 0.8
            )
            from sqlalchemy import text as sql_text

            # Full-text ranking via PostgreSQL ts_rank (simple dictionary).
            keyword_query = sql_text(
                "\n                SELECT\n                    id,\n                    title,\n                    content,\n                    category_id,\n                    ts_rank(to_tsvector('simple', content), plainto_tsquery('simple', :query)) as rank\n                FROM knowledge_documents\n                WHERE to_tsvector('simple', content) @@ plainto_tsquery('simple', :query)\n                ORDER BY rank DESC\n                LIMIT :k\n            "
            )
            keyword_result = self.db.execute(keyword_query, {"query": query, "k": k * 2})
            keyword_results = {}
            for row in keyword_result:
                keyword_results[row.id] = {"score": float(row.rank)}
            # Seed the combined map with semantic hits only.
            combined_results = {}
            for doc, score in semantic_results:
                doc_id = doc["document_id"]
                combined_results[doc_id] = {
                    "doc": doc,
                    "semantic_score": score,
                    "keyword_score": 0.0,
                }
            for doc_id, data in keyword_results.items():
                if doc_id in combined_results:
                    combined_results[doc_id]["keyword_score"] = data["score"]
            final_results = []
            for _doc_id, data in combined_results.items():
                semantic_norm = data["semantic_score"]
                # ts_rank scores are small; normalize against 0.1 and cap at 1.
                keyword_norm = (
                    min(data["keyword_score"] / 0.1, 1.0) if data["keyword_score"] > 0 else 0.0
                )
                combined_score = semantic_weight * semantic_norm + keyword_weight * keyword_norm
                final_results.append((data["doc"], combined_score))
            final_results.sort(key=lambda x: x[1], reverse=True)
            logger.info(f"混合搜索完成，找到 {len(final_results[:k])} 个结果")
            return final_results[:k]
        except Exception as e:
            logger.error(f"混合搜索失败: {e}")
            raise

    def delete_document_vectors(self, document_id: int) -> int:
        """Delete every vector belonging to ``document_id`` and flush.

        Returns:
            Always 1 on success (legacy contract — not the deleted count).

        Raises:
            Exception: re-raised after logging when the delete fails.
        """
        try:
            expr = f"document_id == {document_id}"
            self.collection.delete(expr)
            self.collection.flush()
            logger.info(f"删除文档 {document_id} 的向量")
            return 1
        except Exception as e:
            logger.error(f"删除文档向量失败: {e}")
            raise

    def get_collection_stats(self) -> dict[str, Any]:
        """Return ``{"total_vectors", "collection_name"}``; zero on failure."""
        try:
            stats = self.collection.num_entities
            return {"total_vectors": stats, "collection_name": self.collection_name}
        except Exception as e:
            logger.error(f"获取统计信息失败: {e}")
            return {"total_vectors": 0, "collection_name": self.collection_name}


def get_vector_store(
    db: Session,
    knowledge_base_id: int | None = None,
    host: str = "localhost",
    port: str = "19530",
    use_lite: bool = True,
) -> MilvusBasicStore:
    """Factory returning a :class:`MilvusBasicStore` bound to *db*.

    Uses the default collection name; all connection options are forwarded.
    """
    store = MilvusBasicStore(
        db,
        knowledge_base_id,
        host=host,
        port=port,
        use_lite=use_lite,
    )
    return store
