from __future__ import annotations

import hashlib

from sqlalchemy.orm import Session

from app.models.knowledge_base import DocumentVector, KnowledgeDocument
from app.services.knowledge.vectorization_service import VectorizationService

from common_logging import get_logger

logger = get_logger(__name__)


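# Chunks below this UTF-8 byte length carry too little signal to embed and are skipped.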
_MIN_CHUNK_BYTES = 8


def _chunk_hash(text: str) -> str:
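    """Return a stable 16-hex-char fingerprint (truncated SHA-256) of chunk text."""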
    return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]


class VersionDiffService:
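    """Incrementally re-vectorize a knowledge document when its content changes.

    The new content is chunked and each chunk is fingerprinted; chunks whose
    hash already exists among the stored vectors keep their embedding (only
    the chunk index is refreshed), new or changed chunks are re-embedded, and
    vectors whose hash no longer appears are deleted from Milvus and the
    relational store.

    Usage sketch (illustrative; assumes an open SQLAlchemy session `db`):

        svc = VersionDiffService(db)
        result = svc.update_document_vectors(document_id=42, new_content=text)
        if result["success"]:
            print(result["unchanged"], result["added"], result["removed"])
    """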

    def __init__(self, db: Session):
        self.db = db

    def update_document_vectors(
        self,
        document_id: int,
        new_content: str,
        chunk_strategy: str = "tax_adaptive",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        model_id: int | None = None,
        tenant_id: int | None = None,
        user_id: int | None = None,
    ) -> dict:
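        """Diff `new_content` against the stored vectors and update incrementally.

        Returns a dict with `success` and, on success, the `unchanged`,
        `changed`, `added` and `removed` chunk counts plus `total_new`
        (the number of new chunks); on failure, an `error` message.
        """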
        doc = self.db.query(KnowledgeDocument).filter(KnowledgeDocument.id == document_id).first()
        if not doc:
            return {"success": False, "error": f"文档不存在: {document_id}"}
        logger.info(f"[VersionDiff] 开始增量更新 document_id={document_id}")
        new_chunks = self._generate_chunks(
            document=doc,
            content=new_content,
            chunk_strategy=chunk_strategy,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        if new_chunks is None:
            return {"success": False, "error": "分块生成失败"}
        old_vectors: list[DocumentVector] = (
            self.db.query(DocumentVector)
            .filter(DocumentVector.document_id == document_id)
            .order_by(DocumentVector.chunk_index)
            .all()
        )
        old_hash_map: dict[str, DocumentVector] = {
            _chunk_hash(v.chunk_text): v for v in old_vectors if v.chunk_text
        }
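        # Diff by content hash: a chunk whose hash matches a stored vector keeps
        # that vector (only chunk_index is refreshed); everything else is queued
        # for re-embedding.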
        new_hash_set = set()
        unchanged_count = 0
        to_vectorize: list[tuple[int, dict]] = []
        for idx, chunk in enumerate(new_chunks):
            text = chunk.get("text", "")
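            # Skip chunks too small to embed; count them as unchanged so the
            # reported counts still sum to the number of new chunks.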
            if len(text.encode("utf-8")) < _MIN_CHUNK_BYTES:
                unchanged_count += 1
                continue
            h = _chunk_hash(text)
            new_hash_set.add(h)
            if h in old_hash_map:
                old_vec = old_hash_map[h]
                old_vec.chunk_index = idx
                unchanged_count += 1
            else:
                to_vectorize.append((idx, chunk))
        to_delete_hashes = set(old_hash_map.keys()) - new_hash_set
        to_delete_vectors = [old_hash_map[h] for h in to_delete_hashes]
        removed_count = len(to_delete_vectors)
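        # Heuristic split of the re-embed queue: any net growth in chunk count
        # is reported as "added"; the remainder is reported as "changed".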
        changed_count = max(0, len(to_vectorize) - max(0, len(new_chunks) - len(old_vectors)))
        added_count = len(to_vectorize) - changed_count
        logger.info(
            f"[VersionDiff] document_id={document_id}: unchanged={unchanged_count}, to_revectorize={len(to_vectorize)}, to_delete={removed_count}"
        )
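        # Nothing to embed or delete: commit only the refreshed chunk_index values.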
        if not to_vectorize and not to_delete_vectors:
            self.db.commit()
            return {
                "success": True,
                "document_id": document_id,
                "unchanged": unchanged_count,
                "changed": 0,
                "added": 0,
                "removed": 0,
                "total_new": len(new_chunks),
            }
        if to_delete_vectors:
            self._delete_vectors(doc, to_delete_vectors, tenant_id)
        if to_vectorize:
            ok = self._vectorize_chunks(
                document=doc,
                chunks=[c for _, c in to_vectorize],
                chunk_indices=[idx for idx, _ in to_vectorize],
                model_id=model_id,
                tenant_id=tenant_id,
                user_id=user_id,
            )
            if not ok:
                return {"success": False, "error": "部分分块向量化失败", "document_id": document_id}
        doc.content = new_content
        doc.content_hash = hashlib.sha256(new_content.encode("utf-8")).hexdigest()
        doc.version_number = (doc.version_number or 1) + 1
        self.db.commit()
        logger.info(
            f"[VersionDiff] Incremental update complete: document_id={document_id}, version={doc.version_number}"
        )
        return {
            "success": True,
            "document_id": document_id,
            "unchanged": unchanged_count,
            "changed": changed_count,
            "added": added_count,
            "removed": removed_count,
            "total_new": len(new_chunks),
        }

    def _generate_chunks(
        self,
        document: KnowledgeDocument,
        content: str,
        chunk_strategy: str,
        chunk_size: int,
        chunk_overlap: int,
    ) -> list[dict] | None:
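        """Split `content` into chunk dicts via the shared document processor.

        Returns None on failure so the caller can abort the update.
        """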
        try:
            svc = VectorizationService(self.db)
            chunks = svc.doc_processor.process_text(
                text=content,
                strategy=chunk_strategy,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                metadata={"document_id": document.id, "title": document.title},
            )
            return chunks
        except Exception as e:
            logger.error(f"[VersionDiff] 分块生成失败: {e}".opt(exception=True))
            return None

    def _delete_vectors(
        self, document: KnowledgeDocument, vectors: list[DocumentVector], tenant_id: int | None
    ) -> None:
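        """Delete stale vectors from Milvus (best-effort) and from the database.

        Milvus errors are logged and swallowed; the DocumentVector rows are
        removed regardless so the bookkeeping stays consistent.
        """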
        from app.services.storage.vector_store import get_vector_store

        knowledge_base_id: int | None = None
        if document.category_id:
            from app.models.knowledge_base import KnowledgeCategory

            cat = (
                self.db.query(KnowledgeCategory)
                .filter(KnowledgeCategory.id == document.category_id)
                .first()
            )
            if cat:
                knowledge_base_id = cat.knowledge_base_id
        milvus_ids = [v.milvus_id for v in vectors if v.milvus_id]
        if milvus_ids and knowledge_base_id:
            try:
                vector_store = get_vector_store(
                    self.db, knowledge_base_id=knowledge_base_id, tenant_id=tenant_id
                )
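                # Milvus boolean expression deleting by primary key; the ids are
                # quoted on the assumption that pk is a VARCHAR field.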
                ids_expr = ", ".join(f'"{mid}"' for mid in milvus_ids)
                vector_store.collection.delete(f"pk in [{ids_expr}]")
                vector_store.collection.flush()
                logger.info(f"[VersionDiff] Milvus 删除 {len(milvus_ids)} 条过期向量")
            except Exception as e:
                logger.warning(f"[VersionDiff] Milvus delete failed (non-fatal): {e}")
        ids_to_delete = [v.id for v in vectors if v.id]
        if ids_to_delete:
            self.db.query(DocumentVector).filter(DocumentVector.id.in_(ids_to_delete)).delete(
                synchronize_session=False
            )

    def _vectorize_chunks(
        self,
        document: KnowledgeDocument,
        chunks: list[dict],
        chunk_indices: list[int],
        model_id: int | None,
        tenant_id: int | None,
        user_id: int | None,
    ) -> bool:
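        """Embed `chunks`, write them to Milvus, and persist DocumentVector rows.

        `chunk_indices` carries each chunk's position within the full new chunk
        list so ordering survives a partial update. Returns False on any error.
        """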
        try:
            from app.services.storage.vector_store import get_vector_store

            svc = VectorizationService(self.db)
            model_name = "default"
            if model_id:
                from app.models.provider import Model

                m = self.db.query(Model).filter(Model.id == model_id).first()
                if m:
                    model_name = m.name
            chunk_texts = [c["text"] for c in chunks]
            embeddings = svc.embedding_factory.generate_embeddings_batch(
                texts=chunk_texts, db=self.db, model_id=model_id
            )
            knowledge_base_id: int | None = None
            if document.category_id:
                from app.models.knowledge_base import KnowledgeCategory

                cat = (
                    self.db.query(KnowledgeCategory)
                    .filter(KnowledgeCategory.id == document.category_id)
                    .first()
                )
                if cat:
                    knowledge_base_id = cat.knowledge_base_id
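            # knowledge_base_id may remain None for uncategorized documents;
            # it is passed to get_vector_store as-is.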
            vector_store = get_vector_store(
                self.db, knowledge_base_id=knowledge_base_id, tenant_id=tenant_id
            )
            vector_ids = vector_store.add_documents(
                document_id=document.id, chunks=chunks, embeddings=embeddings, model_name=model_name
            )
            for local_i, (orig_idx, chunk) in enumerate(zip(chunk_indices, chunks, strict=False)):
                dv = DocumentVector(
                    document_id=document.id,
                    chunk_index=orig_idx,
                    chunk_text=chunk["text"],
                    milvus_id=str(vector_ids[local_i]) if local_i < len(vector_ids) else None,
                    model_name=model_name,
                    parent_chunk_id=chunk.get("parent_chunk_id"),
                    is_parent=chunk.get("is_parent", False),
                    chunk_level=chunk.get("chunk_level"),
                    doc_type=document.doc_type,
                    doc_number=document.doc_number,
                    issuing_authority=document.issuing_authority,
                    references=chunk.get("references"),
                )
                self.db.add(dv)
            self.db.flush()
            logger.info(f"[VersionDiff] 新增/变化分块向量化完成: {len(chunks)} 条")
            return True
        except Exception as e:
            logger.exception(f"[VersionDiff] Chunk vectorization failed: {e}")
            return False
