import re
from dataclasses import dataclass, field
from enum import Enum

from common_logging import get_logger

logger = get_logger(__name__)



class DocumentType(str, Enum):
    """Category a tax document can be assigned to.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, so members can be used wherever that string is expected.
    """

    # Categories that have keyword rules in TaxDocumentClassifier._TYPE_RULES.
    ANNOUNCEMENT = "announcement"
    NOTICE = "notice"
    REGULATION = "regulation"
    LAW = "law"
    CASE = "case"
    PRACTICAL = "practical"

    # Fallback used when no keyword rule matches.
    UNKNOWN = "unknown"


@dataclass
class DocumentStructure:
    """Structural statistics of a document; every field defaults to "none found"."""

    # Occurrence counts of structural markers.
    chapters: int = 0
    sections: int = 0
    articles: int = 0
    clauses: int = 0

    # Presence flags.
    has_table: bool = False
    has_appendix: bool = False


@dataclass
class DocumentClassification:
    """Metadata bundle describing one classified document.

    A default-constructed instance has type ``UNKNOWN``, no document number,
    effective date or issuing authority, an all-zero structure and no keywords.
    """

    doc_type: DocumentType = DocumentType.UNKNOWN

    # Optional text fields; ``None`` means nothing was extracted.
    doc_number: str | None = None
    effective_date: str | None = None
    issuing_authority: str | None = None

    # Mutable defaults go through factories so instances never share them.
    structure: DocumentStructure = field(default_factory=DocumentStructure)
    keywords: list[str] = field(default_factory=list)


class TaxDocumentClassifier:
    """Rule-based metadata extractor for Chinese tax documents.

    Every rule is a keyword list or a precompiled regular expression stored as
    a class attribute; the methods read only those attributes and their
    arguments.
    """

    # Number of leading body characters that are scanned, together with the
    # filename, for type, document number, authority and keywords.
    _HEADER_CHARS = 3000

    # Upper bound on how many authorities are joined into the result string.
    _MAX_AUTHORITIES = 3

    # Document-number formats, tried in order; the first hit wins.
    # NOTE(review): the original bracket classes listed "[" and "]" twice,
    # which may be a lost full-width bracket variant — confirm with the
    # corpus before widening the classes.
    _DOC_NUMBER_PATTERNS = [
        re.compile(r"[\w\u4e00-\u9fa5]+[\[〔]\d{4}[\]〕]\d+号"),
        re.compile(r"[\u4e00-\u9fa5]{4,15}公告\d{4}年第\d+号"),
        re.compile(r"税总[\w\u4e00-\u9fa5]*[\[〔]\d{4}[\]〕]\d+号"),
    ]

    # Issuing bodies, reported in this order when present in the header.
    _AUTHORITY_KEYWORDS = [
        "财政部",
        "国家税务总局",
        "海关总署",
        "国务院",
        "全国人民代表大会",
        "全国人大",
        "省税务局",
        "市税务局",
    ]

    # (type, keywords) pairs in priority order: an earlier rule beats a later
    # one when both have a genuine hit (see _classify_type).
    _TYPE_RULES = [
        (DocumentType.ANNOUNCEMENT, ["公告", "通告"]),
        (DocumentType.NOTICE, ["通知", "函", "批复", "意见"]),
        (DocumentType.REGULATION, ["管理办法", "实施办法", "暂行办法", "细则", "规程", "规定"]),
        (DocumentType.LAW, ["法", "条例", "决定", "决议", "草案"]),
        (DocumentType.CASE, ["稽查", "案例", "违法", "处罚决定", "税务检查"]),
        (DocumentType.PRACTICAL, ["解读", "问答", "操作指南", "实务", "指引", "说明"]),
    ]

    # Effective-date phrasings; group 1 is always the date itself.
    _DATE_PATTERNS = [
        re.compile(r"自(\d{4}年\d{1,2}月\d{1,2}日)起(施行|执行|生效)"),
        re.compile(r"(\d{4}年\d{1,2}月\d{1,2}日)起(施行|执行|生效)"),
        re.compile(r"施行日期[：:]\s*(\d{4}年\d{1,2}月\d{1,2}日)"),
        re.compile(r"生效日期[：:]\s*(\d{4}年\d{1,2}月\d{1,2}日)"),
        re.compile(r"本[\u4e00-\u9fa5]{1,4}自(\d{4}年\d{1,2}月\d{1,2}日)"),
    ]

    # Structural markers written with Chinese numerals.
    _CHAPTER_PATTERN = re.compile("第[一二三四五六七八九十百千]+章")
    _SECTION_PATTERN = re.compile("第[一二三四五六七八九十百千]+节")
    _ARTICLE_PATTERN = re.compile("第[一二三四五六七八九十百千]+条")
    # The parenthesised forms accept one or more numerals so that items such
    # as （十一） are counted too, not only （一） through （十）.
    _CLAUSE_PATTERN = re.compile(
        r"第[一二三四五六七八九十百千]+款|（[一二三四五六七八九十]+）|\([一二三四五六七八九十]+\)"
    )
    _APPENDIX_PATTERN = re.compile("(附录|附件|附表)[：:]", re.IGNORECASE)
    # Markdown separator row (|---|) or ASCII-art table border (+--+).
    _TABLE_PATTERN = re.compile(r"\|[\s\-:]+\||\+-{2,}\+")

    # Tax terms reported as keywords, in output order.
    _TAX_KEYWORDS = [
        "增值税",
        "企业所得税",
        "个人所得税",
        "消费税",
        "印花税",
        "土地增值税",
        "房产税",
        "车辆购置税",
        "资源税",
        "环境保护税",
        "关税",
        "契税",
        "耕地占用税",
        "城市维护建设税",
        "小规模纳税人",
        "一般纳税人",
        "发票",
        "抵扣",
        "退税",
        "免税",
        "纳税申报",
        "税收优惠",
        "减免税",
        "出口退税",
    ]

    def classify(self, text: str, filename: str = "") -> DocumentClassification:
        """Classify a document and log the resulting type.

        Type, document number, authority and keywords are looked up in the
        filename plus the first ``_HEADER_CHARS`` characters of ``text``;
        the effective date and the structure statistics use the whole text.

        Args:
            text: Full document body.
            filename: Optional file name; also used as ``doc_id`` in the log.

        Returns:
            The populated ``DocumentClassification``.
        """
        header = f"{filename}\n{text[:self._HEADER_CHARS]}"
        classification = DocumentClassification(
            doc_type=self._classify_type(header),
            doc_number=self._extract_doc_number(header),
            effective_date=self._extract_effective_date(text),
            issuing_authority=self._extract_authority(header),
            structure=self._analyze_structure(text),
            keywords=self._extract_keywords(header),
        )
        logger.bind(
            doc_id=filename, category=classification.doc_type.value
        ).info("document classified")
        return classification

    def _classify_type(self, text: str) -> DocumentType:
        """Return the type of the first rule with a genuine keyword hit.

        Rules are tried in ``_TYPE_RULES`` order. A keyword does not count
        when each of its occurrences is merely part of a longer keyword owned
        by a lower-priority rule (``法`` inside ``违法``, ``决定`` inside
        ``处罚决定``); otherwise those longer keywords could never decide the
        outcome. ``DocumentType.UNKNOWN`` is returned when nothing matches.
        """
        rules = self._TYPE_RULES
        for rank, (doc_type, keywords) in enumerate(rules):
            later_keywords = [kw for _, kws in rules[rank + 1:] for kw in kws]
            for kw in keywords:
                if kw not in text:
                    continue
                remainder = text
                for longer in later_keywords:
                    if len(longer) > len(kw) and kw in longer:
                        # "\0" keeps the neighbours apart so that removing
                        # the longer keyword cannot fabricate a new match.
                        remainder = remainder.replace(longer, "\0")
                if kw in remainder:
                    return doc_type
        return DocumentType.UNKNOWN

    def _extract_doc_number(self, text: str) -> str | None:
        """Return the first document-number match, or ``None`` if none is found."""
        for pattern in self._DOC_NUMBER_PATTERNS:
            if match := pattern.search(text):
                return match.group(0)
        return None

    def _extract_authority(self, text: str) -> str | None:
        """Return up to ``_MAX_AUTHORITIES`` authorities joined by ``/``.

        Authorities appear in ``_AUTHORITY_KEYWORDS`` order, not in the order
        they occur in ``text``. Returns ``None`` when none is present.
        """
        found = [name for name in self._AUTHORITY_KEYWORDS if name in text]
        if not found:
            return None
        return "/".join(found[: self._MAX_AUTHORITIES])

    def _extract_effective_date(self, text: str) -> str | None:
        """Return the date captured by the first matching date pattern, else ``None``."""
        for pattern in self._DATE_PATTERNS:
            if match := pattern.search(text):
                return match.group(1)
        return None

    def _analyze_structure(self, text: str) -> DocumentStructure:
        """Count structural markers and detect tables and appendices in ``text``."""
        return DocumentStructure(
            chapters=len(self._CHAPTER_PATTERN.findall(text)),
            sections=len(self._SECTION_PATTERN.findall(text)),
            articles=len(self._ARTICLE_PATTERN.findall(text)),
            clauses=len(self._CLAUSE_PATTERN.findall(text)),
            has_table=bool(self._TABLE_PATTERN.search(text)),
            has_appendix=bool(self._APPENDIX_PATTERN.search(text)),
        )

    def _extract_keywords(self, text: str) -> list[str]:
        """Return every entry of ``_TAX_KEYWORDS`` found in ``text``, in list order."""
        return [kw for kw in self._TAX_KEYWORDS if kw in text]


if __name__ == "__main__":
    # Manual smoke test: classify a two-line sample announcement and log the
    # resulting classification.
    sample_text = (
        "财政部 税务总局公告2024年第9号\n"
        "关于增值税小规模纳税人减免增值税政策的公告"
    )
    outcome = TaxDocumentClassifier().classify(sample_text)
    logger.info(str(outcome))
