import os
from collections.abc import Iterator
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from common_logging import get_logger

logger = get_logger(__name__)


@dataclass
class BatchResult:
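    """Counters and error messages accumulated over one batch run."""
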
    total: int = 0
    success: int = 0
    failed: int = 0
    skipped: int = 0
    errors: list[str] = field(default_factory=list)

    @property
    def success_rate(self) -> float:
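        """Fraction of documents processed successfully; 0.0 if none were attempted."""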
        return self.success / self.total if self.total > 0 else 0.0


class TaxDocumentBatchPipeline:
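    """Batch pipeline for ingesting local tax-document corpora.

    For each corpus in CORPUS_CONFIG, documents are classified, split into
    type-aware chunks, organized into parent/child pairs, scanned for
    cross-document references, and (unless dry_run) handed to the store.
    """
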
    CORPUS_CONFIG = [
        {"path": "/Users/brian/Downloads/总局2026", "hint": "announcement", "label": "2026法规"},
        {"path": "/Users/brian/Downloads/总局2025", "hint": "announcement", "label": "2025法规"},
        {"path": "/Users/brian/Downloads/2024", "hint": "law", "label": "2024法规"},
        {"path": "/Users/brian/Downloads/纳税评估", "hint": "case", "label": "纳税评估"},
        {"path": "/Users/brian/Downloads/稽查案例", "hint": "case", "label": "稽查案例"},
        {"path": "/Users/brian/Downloads/实务解读", "hint": "practical", "label": "实务解读"},
    ]

    def __init__(
        self, dry_run: bool = False, batch_size: int = 10, max_docs_per_corpus: int | None = None
    ):
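        """
        Args:
            dry_run: Process documents without persisting anything.
            batch_size: Number of documents handled per batch.
            max_docs_per_corpus: Optional per-corpus cap; None means no limit.
        """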
        self.dry_run = dry_run
        self.batch_size = batch_size
        self.max_docs_per_corpus = max_docs_per_corpus
        self._init_services()

    def _init_services(self):
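        """Import and instantiate the sibling service modules.

        This file's own directory is prepended to sys.path so the flat
        sibling modules resolve even when the script is run directly from
        another working directory.
        """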
        import sys

        svc_path = os.path.dirname(__file__)
        if svc_path not in sys.path:
            sys.path.insert(0, svc_path)
        from document_classifier import TaxDocumentClassifier
        from parent_child_store import ParentChildSplitStrategy, ParentChildStore
        from proposition_splitter import AdaptivePropositionSplitter
        from reference_extractor import TaxReferenceExtractor

        self.classifier = TaxDocumentClassifier()
        self.splitter = AdaptivePropositionSplitter()
        self.extractor = TaxReferenceExtractor()
        self.store = ParentChildStore()
        self.split_strategy = ParentChildSplitStrategy()

    def scan_corpus(self) -> dict[str, int]:
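        """Count candidate .md/.txt files per corpus label; missing paths count as 0."""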
        stats = {}
        for cfg in self.CORPUS_CONFIG:
            path = Path(cfg["path"])
            if path.exists():
                md_files = list(path.rglob("*.md"))
                txt_files = list(path.rglob("*.txt"))
                stats[cfg["label"]] = len(md_files) + len(txt_files)
            else:
                stats[cfg["label"]] = 0
        return stats

    def iter_documents(self, corpus_path: str, hint: str) -> Iterator[dict[str, Any]]:
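        """Yield {file_path, text, hint, doc_id} for each readable document.

        Files are read as UTF-8 with decode errors ignored; near-empty files
        (fewer than 50 stripped characters) are skipped.
        """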
        path = Path(corpus_path)
        if not path.exists():
            logger.warning(f"Corpus path not found: {corpus_path}")
            return
        for file_path in sorted(path.rglob("*.md")):
            try:
                text = file_path.read_text(encoding="utf-8", errors="ignore")
                if len(text.strip()) < 50:
                    continue
                yield {
                    "file_path": str(file_path),
                    "text": text,
                    "hint": hint,
                    "doc_id": file_path.stem,
                }
            except Exception as e:
                logger.warning(f"Failed to read {file_path}: {e}")
        for file_path in sorted(path.rglob("*.txt")):
            try:
                text = file_path.read_text(encoding="utf-8", errors="ignore")
                if len(text.strip()) < 50:
                    continue
                yield {
                    "file_path": str(file_path),
                    "text": text,
                    "hint": hint,
                    "doc_id": file_path.stem,
                }
            except Exception as e:
                logger.warning(f"Failed to read {file_path}: {e}")

    def process_document(self, doc: dict[str, Any]) -> dict[str, Any]:
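        """Run the full per-document pipeline.

        Classifies the text, splits it into type-aware chunks, derives
        parent/child pairs, and extracts cross-document relations, returning
        all artifacts plus summary counts in one dict.
        """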
        text = doc["text"]
        doc_id = doc["doc_id"]
        classification = self.classifier.classify(text)
        doc_type = classification.doc_type.value
        chunks = self.splitter.split(text, doc_type=doc_type)
        parents, children = self.split_strategy.create_parent_child_pairs(chunks)
        relations = self.extractor.extract_relations(text, doc_id=doc_id)
        return {
            "doc_id": doc_id,
            "file_path": doc["file_path"],
            "doc_type": doc_type,
            "doc_number": classification.doc_number,
            "issuing_authority": classification.issuing_authority,
            "chunk_count": len(chunks),
            "parent_count": len(parents),
            "child_count": len(children),
            "relation_count": len(relations),
            "chunks": chunks,
            "parents": parents,
            "children": children,
            "relations": relations,
        }

    def run(self, dry_run: bool | None = None) -> BatchResult:
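        """Process every configured corpus and return aggregated counters.

        An explicit dry_run argument overrides (and replaces) the instance
        setting.
        """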
        if dry_run is not None:
            self.dry_run = dry_run
        result = BatchResult()
        for cfg in self.CORPUS_CONFIG:
            logger.info(f"Processing corpus: {cfg['label']}")
            corpus_result = self._process_corpus(cfg)
            result.total += corpus_result.total
            result.success += corpus_result.success
            result.failed += corpus_result.failed
            result.skipped += corpus_result.skipped
            result.errors.extend(corpus_result.errors)
        logger.info(
            f"Batch complete: {result.success}/{result.total} success ({result.success_rate:.1%}), {result.failed} failed"
        )
        return result

    def _process_corpus(self, cfg: dict) -> BatchResult:
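        """Stream one corpus into batches of batch_size, honoring max_docs_per_corpus."""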
        result = BatchResult()
        batch = []
        for doc in self.iter_documents(cfg["path"], cfg["hint"]):
            if self.max_docs_per_corpus is not None and result.total >= self.max_docs_per_corpus:
                break
            batch.append(doc)
            result.total += 1
            if len(batch) >= self.batch_size:
                self._process_batch(batch, result)
                batch = []
        if batch:
            self._process_batch(batch, result)
        logger.info(f"{cfg['label']}: {result.success}/{result.total} processed")
        return result

    def _process_batch(self, batch: list[dict], result: BatchResult):
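        """Process each queued document; failures are logged and counted, not raised."""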
        for doc in batch:
            try:
                processed = self.process_document(doc)
                if not self.dry_run:
                    self._store_document(processed)
                result.success += 1
                logger.debug(
                    f"OK {processed['doc_id']}: {processed['chunk_count']} chunks, {processed['relation_count']} relations"
                )
            except Exception as e:
                result.failed += 1
                msg = f"FAIL {doc['doc_id']}: {e}"
                result.errors.append(msg)
                logger.warning(msg)

    def _store_document(self, processed: dict):
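        """Persist a processed document's parents, children, and relations.

        Intentionally left as a no-op hook in this module. A minimal sketch,
        assuming ParentChildStore exposed hypothetical add_parents /
        add_children / add_relations methods (not confirmed by its import
        here), might look like:

            self.store.add_parents(processed["parents"])
            self.store.add_children(processed["children"])
            self.store.add_relations(processed["relations"])
        """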
        pass


if __name__ == "__main__":
    import sys

    dry_run = "--dry-run" in sys.argv
    pipeline = TaxDocumentBatchPipeline(dry_run=dry_run, batch_size=20)
    logger.info("扫描文档库...")
    stats = pipeline.scan_corpus()
    total = sum(stats.values())
    for label, count in stats.items():
        logger.bind(label=label, count=count).info("Corpus statistics")
    logger.bind(total=total).info("Corpus total")
    if "--scan-only" not in sys.argv:
        logger.info(f"{'[DRY RUN] ' if dry_run else ''}开始批量处理...")
        result = pipeline.run()
        logger.bind(success=result.success, total=result.total, rate=f"{result.success_rate:.1%}", failed=result.failed).info("Batch results")
        for err in result.errors[:5]:
            logger.warning(err)
