import re

from common_logging import get_logger

logger = get_logger(__name__)


class TaxDocumentSplitter:
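    """Split Chinese tax/legal documents into retrieval-friendly chunks.

    Splitting follows the document's own structure (chapter 章, article 条,
    or clause 款); any appendix becomes its own chunk, a heuristic merge step
    avoids cutting Markdown tables apart, and each chunk can optionally carry
    a breadcrumb context label.
    """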
    # Chapter headings: 第X章 with Chinese numerals
    CHAPTER_PATTERN = re.compile(r'第[一二三四五六七八九十百千]+章')
    # Article headings: 第X条
    ARTICLE_PATTERN = re.compile(r'第[一二三四五六七八九十百千]+条')
    # Clause headings: 第X款, or item numbers in full-width （一） / half-width (一) parentheses
    CLAUSE_PATTERN = re.compile(r'第[一二三四五六七八九十百千]+款|（[一二三四五六七八九十]）|\([一二三四五六七八九十]\)')
    # Appendix markers (附录/附件/附表) followed by a half- or full-width colon
    APPENDIX_PATTERN = re.compile(r'(附录|附件|附表)[:：]')
    # Markdown table separator rows, e.g. |---|:---:|
    TABLE_SEPARATOR_PATTERN = re.compile(r'\|[\s\-:]+\|')

    def __init__(
        self,
        granularity: str = 'article',
        add_context: bool = True,
        document_title: str | None = None,
        chunk_size: int = 2000,
        chunk_overlap: int = 100,
    ):
        self.granularity = granularity
        self.add_context = add_context
        self.document_title = document_title or '未命名文档'  # default: "untitled document"
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
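        # Note: chunk_overlap is currently stored but not applied when
        # oversized chunks are re-split.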
        if granularity == 'chapter':
            self.primary_pattern = self.CHAPTER_PATTERN
        elif granularity == 'article':
            self.primary_pattern = self.ARTICLE_PATTERN
        elif granularity == 'clause':
            self.primary_pattern = self.CLAUSE_PATTERN
        else:
            logger.warning(f"Unknown granularity: {granularity}, using 'article' as default")
            self.primary_pattern = self.ARTICLE_PATTERN

    def split_text(self, text: str) -> list[str]:
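        """Split text into chunks: extract the appendix, split on structural
        headings, optionally prefix context labels, then re-split oversized
        chunks. Falls back to the whole text as one chunk on any error."""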
        try:
            main_text, appendix = self._extract_appendix(text)
            chunks = self._split_by_pattern(main_text)
            if appendix:
                chunks.append(appendix)
            if self.add_context:
                chunks = self._add_context_info(chunks)
            chunks = self._handle_oversized_chunks(chunks)
            logger.info(f'Tax document splitting completed: {len(text)} characters -> {len(chunks)} chunks (granularity: {self.granularity})')
            return chunks
        except Exception as e:
            logger.error(f'Tax document splitting failed: {e}')
            return [text]

    def split_texts_batch(self, texts: list[str]) -> list[list[str]]:
        """Split each text independently; returns one chunk list per input."""
        return [self.split_text(text) for text in texts]

    def _extract_appendix(self, text: str) -> tuple[str, str]:
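        """Split the text at the first appendix marker; returns
        (main_text, appendix), with appendix == '' when none is found."""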
        match = self.APPENDIX_PATTERN.search(text)
        if match:
            split_pos = match.start()
            return (text[:split_pos].strip(), text[split_pos:].strip())
        return (text, '')

    def _split_by_pattern(self, text: str) -> list[str]:
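        """Split at each match of the primary pattern, keeping any preamble
        before the first heading as its own chunk."""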
        chunks = []
        matches = list(self.primary_pattern.finditer(text))
        if not matches:
            return [text.strip()] if text.strip() else []
        if matches[0].start() > 0:
            preamble = text[:matches[0].start()].strip()
            if preamble:
                chunks.append(preamble)
        for i, match in enumerate(matches):
            start = match.start()
            end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
            chunk = text[start:end].strip()
            if self._is_table_boundary_safe(chunk):
                if chunk:
                    chunks.append(chunk)
            elif chunks:
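                # Likely a table fragment (pipe rows without a separator row):
                # merge into the previous chunk rather than cutting the table apart.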
                chunks[-1] = chunks[-1] + '\n\n' + chunk
            else:
                chunks.append(chunk)
        return chunks

    def _is_table_boundary_safe(self, text: str) -> bool:
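        """Heuristic check that a chunk does not start mid-table
        (pipe-delimited rows present but no separator row)."""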
        lines = text.split('\n')
        table_lines = [line for line in lines if '|' in line]
        if not table_lines:
            return True
        separator_lines = [line for line in table_lines if self.TABLE_SEPARATOR_PATTERN.search(line)]
        # Several pipe-delimited rows but no |---| separator row: the split
        # likely cut off a table's header.
        if len(table_lines) > 2 and not separator_lines:
            return False
        return True

    def _add_context_info(self, chunks: list[str]) -> list[str]:
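        """Prefix each chunk with a breadcrumb label (title ᐳ chapter ᐳ article),
        tracking the most recent headings across chunks."""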
        result = []
        current_chapter = None
        current_article = None
        for chunk in chunks:
            chapter_match = self.CHAPTER_PATTERN.search(chunk)
            if chapter_match:
                current_chapter = chapter_match.group()
            article_match = self.ARTICLE_PATTERN.search(chunk)
            if article_match:
                current_article = article_match.group()
            context_parts = [self.document_title]
            if current_chapter:
                context_parts.append(current_chapter)
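            # Article breadcrumbs only add signal when splitting below article level.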
            if current_article and self.granularity == 'clause':
                context_parts.append(current_article)
            context_label = f"🏷️ 来源：{' ᐳ '.join(context_parts)}\n\n"
            result.append(context_label + chunk)
        return result

    def _handle_oversized_chunks(self, chunks: list[str]) -> list[str]:
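        """Pass chunks within chunk_size through unchanged; re-split larger
        ones along paragraph boundaries."""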
        result = []
        for chunk in chunks:
            if len(chunk) <= self.chunk_size:
                result.append(chunk)
            else:
                sub_chunks = self._split_large_chunk(chunk)
                result.extend(sub_chunks)
        return result

    def _split_large_chunk(self, chunk: str) -> list[str]:
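        """Re-split an oversized chunk on blank-line paragraph boundaries,
        re-attaching any breadcrumb label to each sub-chunk."""
        # Peel off the breadcrumb label (if present) so it can be re-attached below.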
        context_label = ''
        content = chunk
        if chunk.startswith('🏷️ 来源：'):
            parts = chunk.split('\n\n', 1)
            if len(parts) == 2:
                context_label = parts[0] + '\n\n'
                content = parts[1]
        paragraphs = content.split('\n\n')
        sub_chunks = []
        current_chunk = ''
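        # Greedy paragraph packing. A single paragraph longer than chunk_size
        # passes through intact (soft limit), and the re-attached label does
        # not count toward chunk_size.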
        for para in paragraphs:
            if len(current_chunk) + len(para) + 2 <= self.chunk_size:
                current_chunk += para + '\n\n'
            else:
                if current_chunk:
                    sub_chunks.append(context_label + current_chunk.strip())
                current_chunk = para + '\n\n'
        if current_chunk:
            sub_chunks.append(context_label + current_chunk.strip())
        return sub_chunks if sub_chunks else [chunk]
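
# A minimal usage sketch (assumes common_logging is importable and that the
# sample text uses the 第X条 / 附件 conventions this splitter targets; the
# sample strings below are illustrative only):
if __name__ == '__main__':
    sample = (
        '前言内容。\n\n'
        '第一条 纳税人应当依法申报纳税。\n\n'
        '第二条 税率按照下列规定执行。\n\n'
        '附件：税率表'
    )
    splitter = TaxDocumentSplitter(granularity='article', document_title='示例税收文件')
    for i, chunk in enumerate(splitter.split_text(sample)):
        print(f'--- chunk {i} ---')
        print(chunk)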
