from dataclasses import dataclass
from typing import Any

from .parent_child_store import ParentChildSplitStrategy
from .proposition_splitter import AdaptivePropositionSplitter
from .reference_extractor import TaxReferenceExtractor

from common_logging import get_logger

logger = get_logger(__name__)



@dataclass
class TaxChunk:
    """One chunk of a tax document — either a parent chunk or one of its children."""

    chunk_id: str  # unique id of this chunk (parent_id or child_id from the split strategy)
    text: str  # raw chunk text
    is_parent: bool  # True for parent chunks, False for child chunks
    parent_chunk_id: str | None  # owning parent's chunk_id; None when this chunk is itself a parent
    chunk_level: str  # "parent" for parents; the child's own level string otherwise
    chunk_index: int  # position in the combined ordering (all parents first, then children)
    references: list[dict[str, Any]]  # extracted legal references (relation_type, target_doc_number, ...)
    metadata: dict[str, Any]  # document_id, char_count, plus strategy-specific metadata


class TaxAdaptiveSplitter:

    def __init__(
        self,
        window_size: int = 4,
        min_child_size: int = 500,
        max_child_size: int = 800,
        target_parent_size: int = 2000,
    ):
        self.window_size = window_size
        self.min_child_size = min_child_size
        self.max_child_size = max_child_size
        self.target_parent_size = target_parent_size
        self.proposition_splitter = AdaptivePropositionSplitter(
            min_size=min_child_size,
            max_size=max_child_size,
            target_size=(min_child_size + max_child_size) // 2,
        )
        self.parent_child_strategy = ParentChildSplitStrategy()
        self.reference_extractor = TaxReferenceExtractor()

    def split(
        self, text: str, document_id: str, document_number: str = "", doc_type: str = "law"
    ) -> list[TaxChunk]:
        if not text or not text.strip():
            logger.warning(f"Empty text for document {document_id}")
            return []
        logger.info(f"Splitting document {document_id} with type {doc_type}")
        proposition_chunks = self.proposition_splitter.split(text, doc_type=doc_type)
        if not proposition_chunks:
            logger.warning(f"No chunks generated for document {document_id}")
            return []
        logger.info(f"Generated {len(proposition_chunks)} proposition chunks")
        doc_references = self.reference_extractor.extract_relations(
            text=text, doc_id=document_id, doc_number=document_number
        )
        logger.info(f"Extracted {len(doc_references)} document-level references")
        parents, children = self.parent_child_strategy.create_parent_child_pairs(
            chunks=proposition_chunks, window_size=self.window_size
        )
        logger.info(f"Created {len(parents)} parent chunks and {len(children)} child chunks")
        child_references_map = {}
        for child in children:
            refs = self._extract_chunk_references(child.text, document_id, document_number)
            if refs:
                child_references_map[child.child_id] = refs
        tax_chunks = []
        for idx, parent in enumerate(parents):
            parent_refs = self._merge_references(
                [child_references_map.get(cid, []) for cid in parent.child_ids]
            )
            tax_chunk = TaxChunk(
                chunk_id=parent.parent_id,
                text=parent.text,
                is_parent=True,
                parent_chunk_id=None,
                chunk_level="parent",
                chunk_index=idx,
                references=parent_refs,
                metadata={
                    "document_id": parent.document_id,
                    "article_range": parent.article_range,
                    "chapter": parent.chapter,
                    "child_ids": parent.child_ids,
                    "char_count": len(parent.text),
                    **parent.metadata,
                },
            )
            tax_chunks.append(tax_chunk)
        parent_count = len(parents)
        for idx, child in enumerate(children):
            refs = child_references_map.get(child.child_id, [])
            tax_chunk = TaxChunk(
                chunk_id=child.child_id,
                text=child.text,
                is_parent=False,
                parent_chunk_id=child.parent_id,
                chunk_level=child.level,
                chunk_index=parent_count + idx,
                references=refs,
                metadata={
                    "document_id": child.document_id,
                    "level": child.level,
                    "char_count": len(child.text),
                    **child.metadata,
                },
            )
            tax_chunks.append(tax_chunk)
        logger.info(
            f"Generated {len(tax_chunks)} total TaxChunks ({len(parents)} parents + {len(children)} children)"
        )
        return tax_chunks

    def _extract_chunk_references(
        self, chunk_text: str, document_id: str, document_number: str
    ) -> list[dict[str, Any]]:
        relations = self.reference_extractor.extract_relations(
            text=chunk_text, doc_id=document_id, doc_number=document_number
        )
        return [
            {
                "relation_type": rel.relation_type.value,
                "target_doc_number": rel.target_doc_number,
                "article_number": rel.article_number,
                "context": rel.context,
                "confidence": rel.confidence,
            }
            for rel in relations
        ]

    def _merge_references(
        self, reference_lists: list[list[dict[str, Any]]]
    ) -> list[dict[str, Any]]:
        merged_map: dict[tuple, dict[str, Any]] = {}
        for ref_list in reference_lists:
            for ref in ref_list:
                key = (
                    ref.get("relation_type"),
                    ref.get("target_doc_number"),
                    ref.get("article_number"),
                )
                if key in merged_map:
                    if ref.get("confidence", 0) > merged_map[key].get("confidence", 0):
                        merged_map[key] = ref
                else:
                    merged_map[key] = ref
        return list(merged_map.values())
