from enum import Enum

from common_logging import get_logger

logger = get_logger(__name__)




class SplitterType(str, Enum):
    """Supported text-splitting strategies.

    Inherits from ``str`` so values compare/serialize as plain strings
    (e.g. from API payloads or config files).
    """

    RECURSIVE = "recursive"      # hierarchical separator-based splitting (default)
    CHARACTER = "character"      # split on a single fixed separator
    TOKEN = "token"              # split by token count rather than characters
    MARKDOWN = "markdown"        # markdown-structure-aware splitting
    TAX_ADAPTIVE = "tax_adaptive"  # domain-specific splitter returning dict chunks


class TextSplitterService:
    """Stateless service that splits document text into chunks.

    All methods are static. Each strategy imports its backend lazily inside
    the method so unused splitter libraries are never loaded at module
    import time.

    Note: most strategies return ``list[str]``, but ``TAX_ADAPTIVE`` returns
    ``list[dict]`` chunk records that already carry their own metadata.
    """

    @staticmethod
    def split_text(
        text: str,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        splitter_type: SplitterType = SplitterType.RECURSIVE,
        separator: str = "\n\n",
        document_title: str | None = None,
        document_id: str | None = None,
        document_number: str | None = None,
    ) -> list[str] | list[dict]:
        """Split ``text`` using the strategy selected by ``splitter_type``.

        Args:
            text: Raw document text to split.
            chunk_size: Target chunk size (characters; tokens for ``TOKEN``).
            chunk_overlap: Overlap between consecutive chunks.
            splitter_type: Strategy to apply; unknown values fall back to
                ``RECURSIVE`` with a warning.
            separator: Separator used only by the ``CHARACTER`` strategy.
            document_title: Optional title; ``TAX_ADAPTIVE`` uses it to guess
                the document type.
            document_id: Optional id; ``TAX_ADAPTIVE`` substitutes "unknown"
                when missing.
            document_number: Optional document number for ``TAX_ADAPTIVE``.

        Returns:
            ``list[str]`` of chunk texts for most strategies, or
            ``list[dict]`` chunk records for ``TAX_ADAPTIVE``. On any error
            the whole ``text`` is returned as a single chunk — a deliberate
            best-effort fallback so ingestion never hard-fails on splitting.
        """
        try:
            if splitter_type == SplitterType.RECURSIVE:
                return TextSplitterService._split_recursive(text, chunk_size, chunk_overlap)
            elif splitter_type == SplitterType.CHARACTER:
                return TextSplitterService._split_character(
                    text, chunk_size, chunk_overlap, separator
                )
            elif splitter_type == SplitterType.TOKEN:
                return TextSplitterService._split_token(text, chunk_size, chunk_overlap)
            elif splitter_type == SplitterType.MARKDOWN:
                return TextSplitterService._split_markdown(text, chunk_size, chunk_overlap)
            elif splitter_type == SplitterType.TAX_ADAPTIVE:
                return TextSplitterService._split_tax_adaptive(
                    text, document_id or "unknown", document_number or "", document_title
                )
            else:
                logger.warning(
                    f"Unknown splitter type: {splitter_type}, using default recursive splitter"
                )
                return TextSplitterService._split_recursive(text, chunk_size, chunk_overlap)
        except Exception:
            # Best-effort: degrade to a single un-split chunk rather than
            # failing the caller; logger.exception preserves the traceback.
            logger.exception("Text splitting failed")
            return [text]

    @staticmethod
    def _split_recursive(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
        """Split using LangChain's recursive splitter with CJK-aware separators."""
        from langchain_text_splitters import RecursiveCharacterTextSplitter

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            # Paragraphs first, then CJK and Latin sentence enders, then words.
            separators=["\n\n", "\n", "。", "！", "？", ".", "!", "?", " ", ""],
        )
        chunks = splitter.split_text(text)
        logger.info(
            "Recursive splitting completed: %d characters -> %d chunks", len(text), len(chunks)
        )
        return chunks

    @staticmethod
    def _split_character(
        text: str, chunk_size: int, chunk_overlap: int, separator: str
    ) -> list[str]:
        """Split on a single fixed ``separator`` using LangChain's character splitter."""
        from langchain_text_splitters import CharacterTextSplitter

        splitter = CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            length_function=len,
        )
        chunks = splitter.split_text(text)
        logger.info(
            "Character splitting completed: %d characters -> %d chunks", len(text), len(chunks)
        )
        return chunks

    @staticmethod
    def _split_token(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
        """Split by token count (``chunk_size``/``chunk_overlap`` are in tokens)."""
        from langchain_text_splitters import TokenTextSplitter

        splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        chunks = splitter.split_text(text)
        logger.info(
            "Token splitting completed: %d characters -> %d chunks", len(text), len(chunks)
        )
        return chunks

    @staticmethod
    def _split_markdown(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
        """Split markdown text along its structural elements."""
        from langchain_text_splitters import MarkdownTextSplitter

        splitter = MarkdownTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        chunks = splitter.split_text(text)
        logger.info(
            "Markdown splitting completed: %d characters -> %d chunks", len(text), len(chunks)
        )
        return chunks

    @staticmethod
    def _split_tax_adaptive(
        text: str, document_id: str, document_number: str, document_title: str | None
    ) -> list[dict]:
        """Split tax documents into parent/child chunk records.

        The document type is inferred from keywords in ``document_title``
        (公告/通知 -> announcement, 案例/判决 -> case, otherwise law).
        Returns dicts mirroring the TaxAdaptiveSplitter chunk attributes.
        """
        from app.services.knowledge.tax_adaptive_splitter import TaxAdaptiveSplitter

        splitter = TaxAdaptiveSplitter()
        doc_type = "law"
        if document_title:
            if "公告" in document_title or "通知" in document_title:
                doc_type = "announcement"
            elif "案例" in document_title or "判决" in document_title:
                doc_type = "case"
        tax_chunks = splitter.split(
            text=text, document_id=document_id, document_number=document_number, doc_type=doc_type
        )
        chunks = [
            {
                "text": chunk.text,
                "chunk_id": chunk.chunk_id,
                "is_parent": chunk.is_parent,
                "parent_chunk_id": chunk.parent_chunk_id,
                "chunk_level": chunk.chunk_level,
                "chunk_index": chunk.chunk_index,
                "references": chunk.references,
                "metadata": chunk.metadata,
            }
            for chunk in tax_chunks
        ]
        # Single pass for the parent/child breakdown instead of two sums.
        parent_count = sum(1 for c in chunks if c["is_parent"])
        logger.info(
            "Tax adaptive splitting completed: %d characters -> %d chunks (%d parents + %d children)",
            len(text),
            len(chunks),
            parent_count,
            len(chunks) - parent_count,
        )
        return chunks

    @staticmethod
    def _split_tax(
        text: str,
        granularity: str,
        chunk_size: int,
        chunk_overlap: int,
        document_title: str | None,
    ) -> list[str]:
        """Split a tax document at the given ``granularity`` with added context.

        Not dispatched by :meth:`split_text`; presumably called directly by
        other modules — TODO confirm before removing.
        """
        from app.services.knowledge.tax_document_splitter import TaxDocumentSplitter

        splitter = TaxDocumentSplitter(
            granularity=granularity,
            add_context=True,
            document_title=document_title,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        chunks = splitter.split_text(text)
        logger.info(
            "Tax document splitting completed: %d characters -> %d chunks (granularity: %s)",
            len(text),
            len(chunks),
            granularity,
        )
        return chunks

    @staticmethod
    def create_chunks_with_metadata(
        text: str,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        splitter_type: SplitterType = SplitterType.RECURSIVE,
        document_id: int | None = None,
        document_title: str | None = None,
        document_number: str | None = None,
        window_size: int | None = None,
    ) -> list[dict]:
        """Split ``text`` and wrap each chunk in a metadata dict.

        For ``TAX_ADAPTIVE`` the splitter's own dict records are returned
        unchanged; for all other strategies each string chunk is wrapped
        with index/size/document metadata.

        ``window_size`` is accepted for interface compatibility but is not
        used by any current strategy — TODO confirm intended use.
        """
        chunks = TextSplitterService.split_text(
            text=text,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            splitter_type=splitter_type,
            document_title=document_title,
            # `is not None` so a legitimate document_id of 0 is not dropped.
            document_id=str(document_id) if document_id is not None else None,
            document_number=document_number,
        )
        if splitter_type == SplitterType.TAX_ADAPTIVE:
            # Tax-adaptive chunks already carry their own metadata.
            return chunks
        return [
            {
                "text": chunk,
                "chunk_index": i,
                "total_chunks": len(chunks),
                "chunk_size": len(chunk),
                "metadata": {
                    "document_id": document_id,
                    "document_title": document_title,
                    "chunk_index": i,
                    "total_chunks": len(chunks),
                },
            }
            for i, chunk in enumerate(chunks)
        ]


# Module-level cache for the shared service instance.
_text_splitter_service = None


def get_text_splitter_service() -> TextSplitterService:
    """Return the process-wide TextSplitterService, creating it on first use."""
    global _text_splitter_service
    service = _text_splitter_service
    if service is None:
        service = TextSplitterService()
        _text_splitter_service = service
    return service
