from pathlib import Path
from typing import Any

from langchain_community.document_loaders import (

    PyPDFLoader,
    TextLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredWordDocumentLoader,
)
from langchain_core.documents import Document as LangChainDocument
from langchain_text_splitters import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
    TokenTextSplitter,
)

from common_logging import get_logger

logger = get_logger(__name__)



class DocumentProcessor:
    """Load documents, split them into chunks, and emit chunk dictionaries.

    Loading is delegated to LangChain loaders chosen by file extension
    (see ``LOADER_MAPPING``). Splitting uses either the built-in LangChain
    splitters (``recursive``, ``character``, ``token`` — see
    ``SPLITTER_CONFIGS``) or, for the ``tax_*`` strategies, the project's
    ``TextSplitterService`` (imported lazily to avoid a circular import).
    """

    # Maps a lowercased file extension to the LangChain loader class for it.
    LOADER_MAPPING = {
        ".txt": TextLoader,
        ".pdf": PyPDFLoader,
        ".docx": UnstructuredWordDocumentLoader,
        ".doc": UnstructuredWordDocumentLoader,
        ".md": UnstructuredMarkdownLoader,
        ".html": UnstructuredHTMLLoader,
        ".htm": UnstructuredHTMLLoader,
    }

    # Built-in chunking strategies: splitter class plus its default parameters.
    # Callers may override chunk_size / chunk_overlap and pass extra kwargs.
    SPLITTER_CONFIGS = {
        "recursive": {
            "class": RecursiveCharacterTextSplitter,
            "default_params": {
                "chunk_size": 1000,
                "chunk_overlap": 200,
                # CJK and Latin sentence terminators so splits prefer sentence
                # boundaries in both scripts before falling back to whitespace.
                "separators": ["\n\n", "\n", "。", "！", "？", ".", "!", "?", " ", ""],
            },
        },
        "character": {
            "class": CharacterTextSplitter,
            "default_params": {"chunk_size": 1000, "chunk_overlap": 200, "separator": "\n"},
        },
        "token": {
            "class": TokenTextSplitter,
            "default_params": {"chunk_size": 500, "chunk_overlap": 50},
        },
    }

    def __init__(self):
        """Initialize with the default chunking configuration."""
        self.default_splitter = "recursive"
        self.default_chunk_size = 1000
        self.default_chunk_overlap = 200

    def load_document(self, file_path: str) -> list[LangChainDocument]:
        """Load a file into LangChain documents using the extension-mapped loader.

        Args:
            file_path: Path to the file on disk.

        Returns:
            The documents produced by the loader (one per page/paragraph,
            depending on the loader).

        Raises:
            ValueError: If the file extension has no registered loader.
            Exception: Any loader failure is logged and re-raised.
        """
        try:
            path = Path(file_path)
            file_extension = path.suffix.lower()
            if file_extension not in self.LOADER_MAPPING:
                raise ValueError(f"Unsupported file type: {file_extension}")
            loader = self.LOADER_MAPPING[file_extension](file_path)
            documents = loader.load()
            logger.info(
                f"Successfully loaded document: {file_path}, pages/paragraphs: {len(documents)}"
            )
            return documents
        except Exception as e:
            logger.error(f"Failed to load document {file_path}: {e}")
            raise

    def load_text(self, text: str, metadata: dict | None = None) -> list[LangChainDocument]:
        """Wrap raw text in a single LangChain document with optional metadata."""
        return [LangChainDocument(page_content=text, metadata=metadata or {})]

    def split_documents(
        self,
        documents: list[LangChainDocument],
        strategy: str = "recursive",
        chunk_size: int | None = None,
        chunk_overlap: int | None = None,
        **kwargs,
    ) -> list[LangChainDocument]:
        """Split documents into chunks using the named strategy.

        Args:
            documents: Documents to split.
            strategy: One of ``SPLITTER_CONFIGS`` keys, or a tax-specific
                strategy (``tax_article``/``tax_clause``/``tax_chapter``).
                Unknown strategies fall back to ``self.default_splitter``
                with a warning.
            chunk_size: Overrides the strategy's default chunk size.
            chunk_overlap: Overrides the strategy's default overlap.
            **kwargs: Extra parameters forwarded to the splitter constructor.

        Returns:
            The resulting chunk documents.

        Raises:
            Exception: Any splitter failure is logged and re-raised.
        """
        try:
            if strategy in ("tax_article", "tax_clause", "tax_chapter"):
                return self._split_tax_documents(documents, strategy, chunk_size, chunk_overlap)
            if strategy not in self.SPLITTER_CONFIGS:
                logger.warning(
                    f"Unknown chunking strategy: {strategy}, using default strategy: {self.default_splitter}"
                )
                strategy = self.default_splitter
            config = self.SPLITTER_CONFIGS[strategy]
            # Start from the strategy defaults, then layer caller overrides on top.
            params = config["default_params"].copy()
            if chunk_size is not None:
                params["chunk_size"] = chunk_size
            if chunk_overlap is not None:
                params["chunk_overlap"] = chunk_overlap
            params.update(kwargs)
            splitter = config["class"](**params)
            chunks = splitter.split_documents(documents)
            logger.info(
                f"Document chunking completed: {len(documents)} documents -> {len(chunks)} chunks"
            )
            return chunks
        except Exception as e:
            logger.error(f"Document chunking failed: {e}")
            raise

    def _split_tax_documents(
        self,
        documents: list[LangChainDocument],
        strategy: str,
        chunk_size: int | None,
        chunk_overlap: int | None,
    ) -> list[LangChainDocument]:
        """Split documents with the project's tax-law-aware splitter service.

        Each source document's metadata is copied onto every chunk produced
        from it. ``strategy`` must be one of the three ``tax_*`` strategies.
        """
        # Imported lazily: the splitter service lives in the app layer and a
        # top-level import would be a project-layer dependency/import cycle.
        from app.services.knowledge.text_splitter_service import (
            SplitterType,
            TextSplitterService,
        )

        splitter_type = {
            "tax_article": SplitterType.TAX_ARTICLE,
            "tax_clause": SplitterType.TAX_CLAUSE,
            "tax_chapter": SplitterType.TAX_CHAPTER,
        }[strategy]
        result_chunks: list[LangChainDocument] = []
        for doc in documents:
            chunks = TextSplitterService.split_text(
                text=doc.page_content,
                chunk_size=chunk_size or 1000,
                chunk_overlap=chunk_overlap or 200,
                splitter_type=splitter_type,
                document_title=doc.metadata.get("title"),
            )
            result_chunks.extend(
                LangChainDocument(page_content=chunk_text, metadata=doc.metadata.copy())
                for chunk_text in chunks
            )
        logger.info(
            f"Tax document splitting completed: {len(documents)} docs -> {len(result_chunks)} chunks"
        )
        return result_chunks

    @staticmethod
    def _chunks_to_dicts(chunks: list[LangChainDocument]) -> list[dict[str, Any]]:
        """Convert chunk documents to plain dicts with their positional index."""
        return [
            {"text": chunk.page_content, "metadata": chunk.metadata, "chunk_index": idx}
            for idx, chunk in enumerate(chunks)
        ]

    def process_file(
        self,
        file_path: str,
        strategy: str = "recursive",
        chunk_size: int | None = None,
        chunk_overlap: int | None = None,
        metadata: dict | None = None,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Load a file, merge extra metadata, split it, and return chunk dicts.

        Args:
            file_path: Path to the file to process.
            strategy: Chunking strategy (see :meth:`split_documents`).
            chunk_size: Optional chunk-size override.
            chunk_overlap: Optional overlap override.
            metadata: Extra metadata merged into every loaded document.
            **kwargs: Forwarded to the splitter.

        Returns:
            A list of ``{"text", "metadata", "chunk_index"}`` dicts.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            documents = self.load_document(file_path)
            if metadata:
                for doc in documents:
                    doc.metadata.update(metadata)
            chunks = self.split_documents(
                documents,
                strategy=strategy,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                **kwargs,
            )
            return self._chunks_to_dicts(chunks)
        except Exception as e:
            logger.error(f"Failed to process file {file_path}: {e}")
            raise

    def process_text(
        self,
        text: str,
        strategy: str = "recursive",
        chunk_size: int | None = None,
        chunk_overlap: int | None = None,
        metadata: dict | None = None,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Split raw text into chunks and return chunk dicts.

        Same contract as :meth:`process_file`, but for an in-memory string.
        """
        try:
            documents = self.load_text(text, metadata)
            chunks = self.split_documents(
                documents,
                strategy=strategy,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                **kwargs,
            )
            return self._chunks_to_dicts(chunks)
        except Exception as e:
            logger.error(f"Failed to process text: {e}")
            raise

    def preview_chunking(
        self,
        text: str,
        strategy: str = "recursive",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        **kwargs,
    ) -> dict[str, Any]:
        """Chunk *text* and return summary statistics plus a sample.

        Returns:
            A dict with chunk count, min/avg/max chunk length, the first
            five chunks, and the parameters used. Length stats are 0 when
            no chunks were produced.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            chunks = self.process_text(
                text,
                strategy=strategy,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                **kwargs,
            )
            chunk_lengths = [len(chunk["text"]) for chunk in chunks]
            return {
                "total_chunks": len(chunks),
                "avg_chunk_length": sum(chunk_lengths) / len(chunk_lengths) if chunks else 0,
                "min_chunk_length": min(chunk_lengths) if chunks else 0,
                "max_chunk_length": max(chunk_lengths) if chunks else 0,
                "chunks": chunks[:5],
                "strategy": strategy,
                "chunk_size": chunk_size,
                "chunk_overlap": chunk_overlap,
            }
        except Exception as e:
            logger.error(f"Failed to preview chunking: {e}")
            raise


# Module-level singleton instance, created lazily by get_document_processor().
_document_processor: DocumentProcessor | None = None


def get_document_processor() -> DocumentProcessor:
    """Return the shared DocumentProcessor, creating it on first call.

    Note: not thread-safe; concurrent first calls may each construct an
    instance (construction is cheap and stateless, so this is benign).
    """
    global _document_processor
    if _document_processor is None:
        _document_processor = DocumentProcessor()
    return _document_processor
