from __future__ import annotations

import re
from dataclasses import dataclass
from enum import Enum

from common_logging import get_logger

# Module-level logger; TaxReferenceExtractor.extract_relations logs the number of
# relations extracted per document at DEBUG level.
logger = get_logger(__name__)

# Bracket characters that may wrap the issuing year of a document number, e.g.
# 财税〔2016〕36号, 国税发[2009]3号, 国税发【2009】3号, 财税（2016）36号.
# ASCII and full-width forms are accepted, and every opening form has its closing
# counterpart (a full-width "（" must be closable by a full-width "）").
_CN_BRACKET = r"[\[\]〔〕【】［］（）()]"
_OPEN_BRACKET = r"[\[〔【［（(]"
_CLOSE_BRACKET = r"[\]〕】］）)]"

# "<open>YYYY<close><serial>号": the year/serial core of a bracketed document number.
_BRACKETED_CORE = _OPEN_BRACKET + r"[12]\d{3}" + _CLOSE_BRACKET + r"[^号]{1,6}号"
# Numbered announcements, e.g. 国家税务总局公告2018年第28号 or
# 财政部 税务总局 海关总署公告2019年第39号. The issuer-to-"公告" gap is bounded and
# may not cross clause punctuation or a newline, so a match cannot begin at an
# earlier, unrelated mention of the issuer and swallow the sentences in between.
_ANNOUNCEMENT_NUMBER = (
    r"(?:国家税务总局|财政部|海关总署)[^公，。；\n]{0,30}公告[12]\d{3}年第\d{1,4}号"
)
# Numbered decrees, e.g. 财政部令第65号.
_DECREE_NUMBER = r"(?:国务院|财政部|税务总局)令第\d{1,4}号"

# Individual document-number patterns (2-6 / 2-8 character issuer prefix,
# announcements, decrees).
DOC_NUMBER_PATTERNS = [
    re.compile(r"[\u4e00-\u9fa5]{2,6}(?:" + _BRACKETED_CORE + ")"),
    re.compile(r"[\u4e00-\u9fa5]{2,8}第?(?:" + _BRACKETED_CORE + ")"),
    re.compile(_ANNOUNCEMENT_NUMBER),
    re.compile(_DECREE_NUMBER),
]
# Single alternation used by the extractor. It deliberately contains NO capturing
# groups, so its .pattern can be embedded in larger regexes without shifting their
# group numbers.
_DOC_NUMBER_RE = re.compile(
    r"(?:[\u4e00-\u9fa5]{2,8}第?"
    + _BRACKETED_CORE
    + "|"
    + _ANNOUNCEMENT_NUMBER
    + "|"
    + _DECREE_NUMBER
    + ")"
)
# Article (条) with an optional paragraph (款), e.g. 第三条, 第一百零五条第二款.
# 零/〇/千 are included so numerals such as 一百零五 are matched. No capturing groups.
_ARTICLE_RE = re.compile(
    r"第[一二三四五六七八九十百千零〇\d]+条(?:第[一二三四五六七八九十百千零〇\d]+款)?"
)
# Maximum length, in characters, of a relation's context snippet.
_CTX_MAX = 100


class RelationType(str, Enum):
    """Kind of directed relation from a source document to a target.

    Members mix in ``str``, so each compares equal to its literal value
    (e.g. ``RelationType.ANNULS == "ANNULS"``).
    """

    # Source cites the target document (optionally a specific article).
    REFERENCES = "REFERENCES"
    # Source repeals the target / declares it no longer to be executed.
    ANNULS = "ANNULS"
    # Source modifies, adjusts or supplements the target.
    AMENDS = "AMENDS"
    # Source carries an attachment ("附件"); the extractor stores its title as target.
    HAS_ANNEX = "HAS_ANNEX"
    # NOTE(review): the members below are not produced by TaxReferenceExtractor in
    # this module; presumably populated by other components — confirm.
    ISSUED_BY = "ISSUED_BY"
    SUPERSEDED_BY = "SUPERSEDED_BY"
    EFFECTIVE_FROM = "EFFECTIVE_FROM"


@dataclass
class DocumentRelation:
    """One directed relation from a source document to a target found in its text."""

    # Identifier of the document whose text was scanned.
    source_doc_id: str
    # Kind of relation (see RelationType).
    relation_type: RelationType
    # Matched text the relation was derived from; TaxReferenceExtractor caps it at
    # _CTX_MAX characters.
    context: str
    # Heuristic score; TaxReferenceExtractor assigns fixed per-pattern values
    # between 0.7 and 0.95.
    confidence: float
    # Target identifier when known; for HAS_ANNEX the extractor stores the annex title.
    target_doc_id: str | None = None
    # Official number of the target document (e.g. 财税〔2016〕36号), if matched.
    target_doc_number: str | None = None
    # Cited article of the target (e.g. 第五条 or 第五条第二款), if any.
    article_number: str | None = None


class TaxReferenceExtractor:
    """Rule-based extractor of cross-document relations from Chinese tax texts.

    Each ``_extract_*`` pass scans the raw text with regular expressions built on
    the module-level ``_DOC_NUMBER_RE`` / ``_ARTICLE_RE`` patterns. Each pass emits
    ``DocumentRelation`` objects with a fixed, pass-specific confidence.
    ``extract_relations`` merges and de-duplicates the passes. Instances hold no state.
    """

    # Upper bound (characters) of text scanned after an annulment-list header.
    _LIST_ANNUL_MAX = 500
    # Header of an enumerated annulment list, e.g. "同时废止以下文件：".
    _LIST_ANNUL_HEAD_RE = re.compile(
        "(?:同时废止|一并废止|予以废止)(?:以下)?(?:文件|规定|通知|公告)?[：:]", re.UNICODE
    )
    # End of such a list: a blank line or the document's own closing clause ("本通知…").
    _LIST_ANNUL_END_RE = re.compile("\\n\\n|本(?:通知|公告|办法)", re.UNICODE)
    # "附件 [N] ：<title>". N may span several characters (附件10, 附件十一).
    _ANNEX_RE = re.compile(
        "附件[\\s]*[1234567890一二三四五六七八九十]*[\\s]*[：:]([^\\n]{1,80})", re.UNICODE
    )
    # Annex "titles" that are only a list index, such as "1." or "二、".
    _ANNEX_SKIP_RE = re.compile("^[\\d一二三四五六七八九十]+[\\.、）\\)]*\\s*$")

    def extract_relations(
        self, text: str, doc_id: str, doc_number: str = ""
    ) -> list[DocumentRelation]:
        """Extract reference, annulment, amendment and annex relations from *text*.

        Args:
            text: Body text of the source document.
            doc_id: Stored as ``source_doc_id`` on every relation.
            doc_number: The source document's own number. When non-empty, relations
                whose ``target_doc_number`` equals it (self-references) are dropped.

        Returns:
            De-duplicated relations. Two relations are duplicates when they share
            relation type, target document number, target document id and article.
            The most confident one is kept, at the position of the first occurrence.
        """
        results: list[DocumentRelation] = [
            *self._extract_references(text, doc_id),
            *self._extract_annuls(text, doc_id),
            *self._extract_amendments(text, doc_id),
            *self._extract_annexes(text, doc_id),
        ]
        if doc_number:
            results = [r for r in results if r.target_doc_number != doc_number]
        # target_doc_id must be part of the key: every HAS_ANNEX relation has
        # target_doc_number=None and is distinguished only by its title.
        slot_of: dict[tuple, int] = {}
        deduped: list[DocumentRelation] = []
        for r in results:
            key = (r.relation_type, r.target_doc_number, r.target_doc_id, r.article_number)
            slot = slot_of.get(key)
            if slot is None:
                slot_of[key] = len(deduped)
                deduped.append(r)
            elif r.confidence > deduped[slot].confidence:
                deduped[slot] = r
        logger.debug("reference_extractor: doc_id=%s extracted %d relations", doc_id, len(deduped))
        return deduped

    def _extract_references(self, text: str, doc_id: str) -> list[DocumentRelation]:
        """Find REFERENCES relations, emitted in three passes.

        * Explicit "依据/根据/按照/依照/遵照" triggers (confidence 0.9).
        * Bare document numbers anywhere in the text (0.7). A bare match is skipped
          when a trigger already produced the same number. It is also skipped when
          it overlaps a number captured by a trigger or "see" match: the bare
          pattern's Chinese-character prefix would otherwise glue the trigger verb
          on (e.g. "根据财税〔2016〕36号") and emit a bogus duplicate.
        * "参见/详见/见" triggers (0.85).
        """
        ref_trigger = re.compile(
            "(?:依据|根据|按照|依照|遵照)([^，。；\\n]{0,60}?)("
            + _DOC_NUMBER_RE.pattern
            + ")"
            + self._article_tail(),
            re.UNICODE,
        )
        see_trigger = re.compile(
            "(?:参见|详见|见)([^，。；\\n]{0,10}?)("
            + _DOC_NUMBER_RE.pattern
            + ")"
            + self._article_tail(),
            re.UNICODE,
        )
        # In both trigger regexes, group 2 is the doc number and group 4 the article.
        trigger_hits = list(ref_trigger.finditer(text))
        see_hits = list(see_trigger.finditer(text))
        claimed_spans = [m.span(2) for m in trigger_hits] + [m.span(2) for m in see_hits]

        relations: list[DocumentRelation] = []
        for m in trigger_hits:
            relations.append(
                DocumentRelation(
                    source_doc_id=doc_id,
                    relation_type=RelationType.REFERENCES,
                    target_doc_number=m.group(2),
                    article_number=m.group(4),
                    context=_snippet(text, m.start(), m.end()),
                    confidence=0.9,
                )
            )
        triggered_numbers = {r.target_doc_number for r in relations}

        for m in _DOC_NUMBER_RE.finditer(text):
            doc_num = m.group(0)
            start, end = m.span()
            if doc_num in triggered_numbers or any(
                start < c_end and c_start < end for c_start, c_end in claimed_spans
            ):
                continue
            # Look for an article citation shortly after the bare number.
            art_m = _ARTICLE_RE.search(text[end : end + 30])
            relations.append(
                DocumentRelation(
                    source_doc_id=doc_id,
                    relation_type=RelationType.REFERENCES,
                    target_doc_number=doc_num,
                    article_number=art_m.group(0) if art_m else None,
                    context=_snippet(text, start, end + (art_m.end() if art_m else 0)),
                    confidence=0.7,
                )
            )

        for m in see_hits:
            relations.append(
                DocumentRelation(
                    source_doc_id=doc_id,
                    relation_type=RelationType.REFERENCES,
                    target_doc_number=m.group(2),
                    article_number=m.group(4),
                    context=_snippet(text, m.start(), m.end()),
                    confidence=0.85,
                )
            )
        return relations

    def _extract_annuls(self, text: str, doc_id: str) -> list[DocumentRelation]:
        """Find ANNULS relations.

        Inline "废止/宣布失效/不再执行/停止执行/废除 <number>" phrases get confidence 0.92.
        Numbers listed after a "同时/一并/予以废止…：" header get 0.88. A list is read up
        to the first blank line or "本通知/本公告/本办法", capped at
        ``_LIST_ANNUL_MAX`` characters. A list longer than the cap is truncated,
        not ignored.
        """
        annul_re = re.compile(
            "(?:废止|宣布失效|不再执行|停止执行|废除)[^，。；\\n]{0,20}?("
            + _DOC_NUMBER_RE.pattern
            + ")",
            re.UNICODE,
        )
        relations: list[DocumentRelation] = []
        for m in annul_re.finditer(text):
            relations.append(
                DocumentRelation(
                    source_doc_id=doc_id,
                    relation_type=RelationType.ANNULS,
                    target_doc_number=m.group(1),
                    context=_snippet(text, m.start(), m.end()),
                    confidence=0.92,
                )
            )
        for head in self._LIST_ANNUL_HEAD_RE.finditer(text):
            start = head.end()
            end_m = self._LIST_ANNUL_END_RE.search(text, start)
            stop = min(end_m.start() if end_m else len(text), start + self._LIST_ANNUL_MAX)
            block = text[start:stop]
            for num_m in _DOC_NUMBER_RE.finditer(block):
                relations.append(
                    DocumentRelation(
                        source_doc_id=doc_id,
                        relation_type=RelationType.ANNULS,
                        target_doc_number=num_m.group(0),
                        context=_snippet(block, num_m.start(), num_m.end()),
                        confidence=0.88,
                    )
                )
        return relations

    def _extract_amendments(self, text: str, doc_id: str) -> list[DocumentRelation]:
        """Find AMENDS relations ("修改/调整/修订/补充修改 <number> [<article>]", confidence 0.88)."""
        amend_re = re.compile(
            "(?:修改|调整|修订|补充修改)[^，。；\\n]{0,20}?("
            + _DOC_NUMBER_RE.pattern
            + ")"
            + self._article_tail(),
            re.UNICODE,
        )
        # Group 1 is the doc number, group 3 the article (or None).
        return [
            DocumentRelation(
                source_doc_id=doc_id,
                relation_type=RelationType.AMENDS,
                target_doc_number=m.group(1),
                article_number=m.group(3),
                context=_snippet(text, m.start(), m.end()),
                confidence=0.88,
            )
            for m in amend_re.finditer(text)
        ]

    def _extract_annexes(self, text: str, doc_id: str) -> list[DocumentRelation]:
        """Find HAS_ANNEX relations from "附件N：<title>" lines (confidence 0.95).

        The stripped title is stored as ``target_doc_id``. Titles shorter than
        3 characters, or consisting only of a list index, are skipped.
        """
        relations: list[DocumentRelation] = []
        for m in self._ANNEX_RE.finditer(text):
            annex_title = m.group(1).strip()
            if len(annex_title) < 3 or self._ANNEX_SKIP_RE.match(annex_title):
                continue
            relations.append(
                DocumentRelation(
                    source_doc_id=doc_id,
                    relation_type=RelationType.HAS_ANNEX,
                    target_doc_number=None,
                    target_doc_id=annex_title,
                    context=_snippet(text, m.start(), m.end()),
                    confidence=0.95,
                )
            )
        return relations

    def _extract_doc_numbers(self, text: str) -> list[str]:
        """Return every distinct document number in *text*, in order of first appearance."""
        return list(dict.fromkeys(m.group(0) for m in _DOC_NUMBER_RE.finditer(text)))

    @staticmethod
    def _article_tail() -> str:
        """Return an optional ``<gap><article>`` regex suffix that adds two capture groups.

        The gap and the article share ONE optional group, so the lazy gap really
        expands (up to 30 characters) until an article matches. If none matches,
        the group is skipped and both captures are ``None``. The gap may not cross
        clause punctuation, a newline or "号". Crossing "号" would mean it skipped
        over another document number and stole that number's article.
        """
        return "(?:([^，。；\\n号]{0,30}?)(" + _ARTICLE_RE.pattern + "))?"


def _snippet(text: str, start: int, end: int) -> str:
    """Return ``text[start:end]`` as a context snippet.

    ``start`` is floored at 0 and ``end`` capped at ``len(text)``. The result is
    truncated to at most ``_CTX_MAX`` characters.
    """
    lo = start if start > 0 else 0
    hi = end if end < len(text) else len(text)
    return text[lo:hi][:_CTX_MAX]


# Shared module-level instance. TaxReferenceExtractor defines no __init__ and its
# methods do not assign instance attributes, so a single instance can be reused.
tax_reference_extractor = TaxReferenceExtractor()
