import re
import uuid
from dataclasses import dataclass, field
from typing import Any

from common_logging import get_logger

logger = get_logger(__name__)


@dataclass
class PropositionChunk:
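    """One chunk of a split document, with hierarchy and sibling links.

    ``level`` records the structural granularity ("chapter", "article",
    "clause", "paragraph", ...). ``prev_chunk_id`` / ``next_chunk_id`` are
    filled in by ``AdaptivePropositionSplitter._link_chunks``, and
    ``char_count`` defaults to ``len(text)`` via ``__post_init__``.
    """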
    chunk_id: str
    text: str
    parent_id: str | None = None
    level: str = "sentence"
    article_number: str | None = None
    chapter_number: str | None = None
    char_count: int = 0
    prev_chunk_id: str | None = None
    next_chunk_id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        if self.char_count == 0:
            self.char_count = len(self.text)


class AdaptivePropositionSplitter:
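    """Structure-aware splitter for Chinese legal and regulatory texts.

    Law documents are split hierarchically along chapter, article, and
    clause markers; other document types fall back to paragraph splitting
    with adaptive merging. HTML tables are shielded behind placeholders
    during splitting and restored afterwards.
    """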
    CHAPTER_PATTERN = re.compile(r"第[一二三四五六七八九十百千万]+章")
    ARTICLE_PATTERN = re.compile(r"第[一二三四五六七八九十百千万]+条")
    CLAUSE_PATTERN = re.compile(
        r"第[一二三四五六七八九十百千万]+款|（[一二三四五六七八九十百千万]+）|\([一二三四五六七八九十百千万]+\)"
    )
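    # The patterns match Chinese-numeral structural markers: 第…章 (chapter),
    # 第…条 (article), and 第…款 or parenthesized numerals such as （一） / (一)
    # (clause / item).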

    def __init__(self, min_size: int = 200, max_size: int = 2000, target_size: int = 800):
        self.min_size = min_size
        self.max_size = max_size
        self.target_size = target_size

    def split(self, text: str, doc_type: str = "law") -> list[PropositionChunk]:
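        """Split ``text`` into linked chunks according to ``doc_type``.

        "law" uses hierarchical chapter/article/clause splitting,
        "announcement" plain paragraph splitting, and "case" / "practice"
        section-level splitting; any other type falls back to paragraphs.
        """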
        if not text or not text.strip():
            return []
        text, table_placeholders = self._extract_tables(text)
        if doc_type == "law":
            chunks = self._split_law_document(text)
        elif doc_type == "announcement":
            chunks = self._split_by_paragraph(text)
        elif doc_type in ("case", "practice"):
            chunks = self._split_by_heading(text)
        else:
            chunks = self._split_by_paragraph(text)
        if table_placeholders:
            chunks = self._restore_tables(chunks, table_placeholders)
        self._link_chunks(chunks)
        logger.bind(doc_type=doc_type, chunk_count=len(chunks)).info("document split completed")
        return chunks

    def _extract_tables(self, text: str):
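        """Replace each ``<table>…</table>`` block with a ``__TABLE_n__`` placeholder.

        Returns the rewritten text plus a mapping from placeholder key to the
        original HTML and the nearest preceding chapter/article heading, used
        later as context when the table is restored.
        """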
        placeholders = {}
        counter = [0]
        heading_pattern = re.compile(
            r"(第[一二三四五六七八九十百千万\d]+[章条][\s\u3000]*[^\n]{0,30})"
        )

        def replace_table(m):
            idx = counter[0]
            counter[0] += 1
            placeholder = f"\n\n__TABLE_{idx}__\n\n"
            prefix_text = text[: m.start()]
            heading_matches = list(heading_pattern.finditer(prefix_text))
            context_prefix = heading_matches[-1].group(0).strip() if heading_matches else ""
            placeholders[f"__TABLE_{idx}__"] = {
                "html": m.group(0),
                "context_prefix": context_prefix,
            }
            return placeholder

        processed = re.sub(r"<table[\s\S]*?</table>", replace_table, text, flags=re.IGNORECASE)
        return (processed, placeholders)

    def _restore_tables(
        self, chunks: list[PropositionChunk], placeholders: dict
    ) -> list[PropositionChunk]:
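        """Swap ``__TABLE_n__`` placeholders back to their original HTML.

        A chunk consisting solely of a placeholder becomes a dedicated table
        chunk, prefixed with its heading context and flagged via
        ``metadata["is_table"]``; placeholders embedded in larger chunks are
        replaced inline.
        """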
        result = []
        for chunk in chunks:
            m = re.fullmatch(r"\s*(__TABLE_\d+__)\s*", chunk.text)
            if m:
                key = m.group(1)
                info = placeholders.get(key)
                if info:
                    prefix = f"【{info['context_prefix']}】\n\n" if info["context_prefix"] else ""
                    chunk.text = prefix + info["html"]
                    chunk.char_count = len(chunk.text)
                    chunk.metadata["is_table"] = True
                else:
                    # Defensive: keep an orphan placeholder chunk instead of
                    # silently dropping it when no table was recorded for it.
                    logger.bind(placeholder=key).warning("no stored table for placeholder")
                result.append(chunk)
                continue

            def inline_replace(m2):
                key = m2.group(1)
                info = placeholders.get(key, {})
                return "\n\n" + info.get("html", "") + "\n\n"

            new_text = re.sub(r"(__TABLE_\d+__)", inline_replace, chunk.text)
            if new_text != chunk.text:
                chunk.text = new_text
                chunk.char_count = len(chunk.text)
            result.append(chunk)
        return result

    def _split_law_document(self, text: str) -> list[PropositionChunk]:
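        """Split by chapters when present, else by articles, else by paragraphs."""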
        chunks = []
        chapter_matches = list(self.CHAPTER_PATTERN.finditer(text))
        if chapter_matches:
            chunks = self._split_by_chapter(text, chapter_matches)
        else:
            article_matches = list(self.ARTICLE_PATTERN.finditer(text))
            if article_matches:
                chunks = self._split_by_article(text, article_matches)
            else:
                chunks = self._split_by_paragraph(text)
        return chunks

    def _split_by_chapter(
        self, text: str, chapter_matches: list[re.Match]
    ) -> list[PropositionChunk]:
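        """Split along chapter markers.

        Text before the first chapter becomes a "preamble" chunk; chapters
        exceeding ``max_size`` are further split by article.
        """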
        chunks = []
        if chapter_matches[0].start() > 0:
            preamble = text[: chapter_matches[0].start()].strip()
            if preamble:
                chunk_id = str(uuid.uuid4())
                chunks.append(
                    PropositionChunk(
                        chunk_id=chunk_id,
                        text=preamble,
                        level="preamble",
                        metadata={"type": "preamble"},
                    )
                )
        for i, match in enumerate(chapter_matches):
            start = match.start()
            end = chapter_matches[i + 1].start() if i + 1 < len(chapter_matches) else len(text)
            chapter_text = text[start:end].strip()
            chapter_number = match.group()
            if len(chapter_text) > self.max_size:
                sub_chunks = self._split_chapter_by_article(chapter_text, chapter_number)
                chunks.extend(sub_chunks)
            else:
                chunk_id = str(uuid.uuid4())
                chunks.append(
                    PropositionChunk(
                        chunk_id=chunk_id,
                        text=chapter_text,
                        level="chapter",
                        chapter_number=chapter_number,
                        metadata={"chapter": chapter_number},
                    )
                )
        return chunks

    def _split_chapter_by_article(
        self, chapter_text: str, chapter_number: str
    ) -> list[PropositionChunk]:
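        """Split one oversized chapter along article markers.

        A fresh UUID serves as the shared ``parent_id`` of all sub-chunks
        (or as the chunk id itself when the chapter contains no articles);
        articles exceeding ``max_size`` are further split by clause.
        """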
        chunks = []
        parent_id = str(uuid.uuid4())
        article_matches = list(self.ARTICLE_PATTERN.finditer(chapter_text))
        if not article_matches:
            chunks.append(
                PropositionChunk(
                    chunk_id=parent_id,
                    text=chapter_text,
                    level="chapter",
                    chapter_number=chapter_number,
                    metadata={"chapter": chapter_number},
                )
            )
            return chunks
        if article_matches[0].start() > 0:
            header = chapter_text[: article_matches[0].start()].strip()
            if header:
                chunk_id = str(uuid.uuid4())
                chunks.append(
                    PropositionChunk(
                        chunk_id=chunk_id,
                        text=header,
                        parent_id=parent_id,
                        level="chapter_header",
                        chapter_number=chapter_number,
                        metadata={"chapter": chapter_number, "type": "header"},
                    )
                )
        for i, match in enumerate(article_matches):
            start = match.start()
            end = (
                article_matches[i + 1].start()
                if i + 1 < len(article_matches)
                else len(chapter_text)
            )
            article_text = chapter_text[start:end].strip()
            article_number = match.group()
            if len(article_text) > self.max_size:
                sub_chunks = self._split_article_by_clause(
                    article_text, article_number, chapter_number, parent_id
                )
                chunks.extend(sub_chunks)
            else:
                chunk_id = str(uuid.uuid4())
                chunks.append(
                    PropositionChunk(
                        chunk_id=chunk_id,
                        text=article_text,
                        parent_id=parent_id,
                        level="article",
                        article_number=article_number,
                        chapter_number=chapter_number,
                        metadata={"article": article_number, "chapter": chapter_number},
                    )
                )
        return chunks

    def _split_by_article(
        self, text: str, article_matches: list[re.Match]
    ) -> list[PropositionChunk]:
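        """Split a chapter-less document along article markers.

        Text before the first article becomes a "preamble" chunk; oversized
        articles are further split by clause.
        """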
        chunks = []
        if article_matches[0].start() > 0:
            preamble = text[: article_matches[0].start()].strip()
            if preamble:
                chunk_id = str(uuid.uuid4())
                chunks.append(
                    PropositionChunk(
                        chunk_id=chunk_id,
                        text=preamble,
                        level="preamble",
                        metadata={"type": "preamble"},
                    )
                )
        for i, match in enumerate(article_matches):
            start = match.start()
            end = article_matches[i + 1].start() if i + 1 < len(article_matches) else len(text)
            article_text = text[start:end].strip()
            article_number = match.group()
            if len(article_text) > self.max_size:
                sub_chunks = self._split_article_by_clause(article_text, article_number, None, None)
                chunks.extend(sub_chunks)
            else:
                chunk_id = str(uuid.uuid4())
                chunks.append(
                    PropositionChunk(
                        chunk_id=chunk_id,
                        text=article_text,
                        level="article",
                        article_number=article_number,
                        metadata={"article": article_number},
                    )
                )
        return chunks

    def _split_article_by_clause(
        self,
        article_text: str,
        article_number: str,
        chapter_number: str | None,
        parent_id: str | None,
    ) -> list[PropositionChunk]:
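        """Split one oversized article along clause markers.

        Without clause markers the article falls back to raw paragraph
        splitting; otherwise text before the first clause becomes an
        "article_header" chunk, followed by one chunk per clause.
        """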
        chunks = []
        clause_matches = list(self.CLAUSE_PATTERN.finditer(article_text))
        if not clause_matches:
            paragraphs = self._split_by_paragraph_raw(article_text)
            for para in paragraphs:
                chunk_id = str(uuid.uuid4())
                chunks.append(
                    PropositionChunk(
                        chunk_id=chunk_id,
                        text=para,
                        parent_id=parent_id,
                        level="article",
                        article_number=article_number,
                        chapter_number=chapter_number,
                        metadata={"article": article_number, "chapter": chapter_number},
                    )
                )
            return chunks
        if clause_matches[0].start() > 0:
            header = article_text[: clause_matches[0].start()].strip()
            if header:
                chunk_id = str(uuid.uuid4())
                chunks.append(
                    PropositionChunk(
                        chunk_id=chunk_id,
                        text=header,
                        parent_id=parent_id,
                        level="article_header",
                        article_number=article_number,
                        chapter_number=chapter_number,
                        metadata={
                            "article": article_number,
                            "chapter": chapter_number,
                            "type": "header",
                        },
                    )
                )
        for i, match in enumerate(clause_matches):
            start = match.start()
            end = (
                clause_matches[i + 1].start() if i + 1 < len(clause_matches) else len(article_text)
            )
            clause_text = article_text[start:end].strip()
            chunk_id = str(uuid.uuid4())
            chunks.append(
                PropositionChunk(
                    chunk_id=chunk_id,
                    text=clause_text,
                    parent_id=parent_id,
                    level="clause",
                    article_number=article_number,
                    chapter_number=chapter_number,
                    metadata={"article": article_number, "chapter": chapter_number},
                )
            )
        return chunks

    def _split_by_paragraph(self, text: str) -> list[PropositionChunk]:
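        """Fallback splitter: raw paragraphs adaptively merged into "paragraph" chunks."""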
        paragraphs = self._split_by_paragraph_raw(text)
        merged = self._adaptive_merge(paragraphs)
        chunks = []
        for para in merged:
            chunk_id = str(uuid.uuid4())
            chunks.append(
                PropositionChunk(
                    chunk_id=chunk_id, text=para, level="paragraph", metadata={"type": "paragraph"}
                )
            )
        return chunks

    def _split_by_paragraph_raw(self, text: str) -> list[str]:
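        """Split on runs of newlines, dropping empty paragraphs."""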
        paragraphs = re.split(r"\n\n+|\n", text)
        return [p.strip() for p in paragraphs if p.strip()]

    def _split_by_heading(self, text: str) -> list[PropositionChunk]:
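        """Split case/practice documents into "section"-level chunks.

        Note: despite the name, this currently reuses paragraph splitting
        plus adaptive merging; no heading detection is performed yet.
        """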
        chunks = []
        paragraphs = self._split_by_paragraph_raw(text)
        merged = self._adaptive_merge(paragraphs)
        for para in merged:
            chunk_id = str(uuid.uuid4())
            chunks.append(
                PropositionChunk(
                    chunk_id=chunk_id, text=para, level="section", metadata={"type": "section"}
                )
            )
        return chunks

    def _adaptive_merge(self, chunks: list[str]) -> list[str]:
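        """Greedily merge consecutive paragraphs toward ``target_size``.

        The buffer is flushed once appending the next paragraph would exceed
        ``target_size``, but only after reaching ``min_size`` (or when nothing
        has been emitted yet); a trailing buffer below ``min_size`` is folded
        into the previous chunk.
        """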
        if not chunks:
            return []
        merged = []
        current = ""
        for chunk in chunks:
            if not current:
                current = chunk
            elif len(current) + len(chunk) + 1 <= self.target_size:
                current += "\n" + chunk
            elif len(current) >= self.min_size or not merged:
                merged.append(current)
                current = chunk
            else:
                current += "\n" + chunk
        if current:
            if merged and len(current) < self.min_size:
                merged[-1] += "\n" + current
            else:
                merged.append(current)
        return merged

    def _link_chunks(self, chunks: list[PropositionChunk]) -> None:
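        """Populate ``prev_chunk_id`` / ``next_chunk_id`` on consecutive chunks in place."""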
        for i in range(len(chunks)):
            if i > 0:
                chunks[i].prev_chunk_id = chunks[i - 1].chunk_id
            if i < len(chunks) - 1:
                chunks[i].next_chunk_id = chunks[i + 1].chunk_id


if __name__ == "__main__":
    splitter = AdaptivePropositionSplitter()
    test_text = (
        "第一章 总则\n"
        "第一条 本法适用于中华人民共和国境内的企业和其他取得收入的组织。\n"
        "第二条 本法所称企业，是指依法在中国境内成立的企业、事业单位、社会团体以及其他取得收入的组织。\n"
        "第二章 税率\n"
        "第三条 企业所得税的税率为25%。\n"
        "第四条 符合条件的小型微利企业，减按20%的税率征收企业所得税。"
    )
    chunks = splitter.split(test_text, doc_type="law")
    logger.bind(count=len(chunks)).info("split result")
    for i, c in enumerate(chunks, 1):
        logger.bind(
            index=i, level=c.level, chunk_id=c.chunk_id, char_count=c.char_count
        ).debug(c.text[:50])
