from __future__ import annotations

import uuid
from dataclasses import asdict, dataclass, field
from typing import Any

from common_logging import get_logger

# Module-level logger named after this module; ParentChildStore uses it to
# report the outcome of store operations.
logger = get_logger(__name__)


@dataclass
class ParentChunk:
    """A larger span of document text that groups several child chunks.

    Plain data container: no validation is performed on any field.
    """

    # Identifier of this parent chunk.
    parent_id: str
    # Identifier of the document the text belongs to.
    document_id: str
    # Text of the parent span.
    text: str
    # Label describing which part of the document this parent covers.
    article_range: str
    # Identifiers of the child chunks grouped under this parent.
    child_ids: list[str] = field(default_factory=list)
    # Optional chapter label.
    chapter: str | None = None
    # Free-form extra attributes.
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict.

        ``dataclasses.asdict`` copies nested lists and dicts, so the result
        does not share containers with this instance.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ParentChunk:
        """Build a chunk from a mapping of field names to values.

        The ``child_ids`` list and ``metadata`` dict are shallow-copied so
        that later changes to the chunk do not leak back into ``data`` (and
        vice versa), mirroring the independence ``to_dict`` provides.

        Raises:
            TypeError: If a required field is missing or ``data`` contains a
                key that is not a field of this class.
        """
        kwargs = dict(data)
        child_ids = kwargs.get("child_ids")
        if isinstance(child_ids, list):
            kwargs["child_ids"] = list(child_ids)
        metadata = kwargs.get("metadata")
        if isinstance(metadata, dict):
            kwargs["metadata"] = dict(metadata)
        return cls(**kwargs)


@dataclass
class ChildChunk:
    """A small unit of document text that belongs to one parent chunk.

    Plain data container: no validation is performed on any field.
    """

    # Identifier of this child chunk.
    child_id: str
    # Identifier of the parent chunk this child is grouped under.
    parent_id: str
    # Identifier of the document the text belongs to.
    document_id: str
    # Text of the child chunk.
    text: str
    # Level label for this chunk.
    level: str
    # Optional embedding vector; None when no embedding is attached.
    embedding: list[float] | None = None
    # Free-form extra attributes.
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict.

        ``dataclasses.asdict`` copies nested lists and dicts, so the result
        does not share containers with this instance.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ChildChunk:
        """Build a chunk from a mapping of field names to values.

        The ``embedding`` list and ``metadata`` dict are shallow-copied so
        that later changes to the chunk do not leak back into ``data`` (and
        vice versa), mirroring the independence ``to_dict`` provides.

        Raises:
            TypeError: If a required field is missing or ``data`` contains a
                key that is not a field of this class.
        """
        kwargs = dict(data)
        embedding = kwargs.get("embedding")
        if isinstance(embedding, list):
            kwargs["embedding"] = list(embedding)
        metadata = kwargs.get("metadata")
        if isinstance(metadata, dict):
            kwargs["metadata"] = dict(metadata)
        return cls(**kwargs)


class ParentChildStore:
    """In-memory index linking parent chunks to their child chunks.

    Three dictionaries are kept in step: parents by id, child id to parent
    id, and parent id to the list of its children.
    """

    def __init__(self) -> None:
        """Create an empty store."""
        self._parents: dict[str, ParentChunk] = {}
        self._child_to_parent: dict[str, str] = {}
        self._parent_to_children: dict[str, list[ChildChunk]] = {}

    def store_chunks(self, parent: ParentChunk, children: list[ChildChunk]) -> bool:
        """Register ``parent`` together with its ``children``.

        ``parent.child_ids`` is overwritten with the ids of ``children``.
        Storing a parent id that already exists replaces the earlier entry,
        including the child lookups that pointed at it.

        Returns:
            True on success. Any exception is logged as a warning and
            reported as False instead of being raised.
        """
        try:
            # Materialise once so a one-shot iterable is not exhausted by the
            # id extraction before the children themselves are stored.
            stored_children = list(children)
            child_ids = [c.child_id for c in stored_children]
            parent_id = parent.parent_id
            # Drop lookups left behind by a previous store of this parent;
            # otherwise replaced children would still resolve to it.
            for previous in self._parent_to_children.get(parent_id, ()):
                if self._child_to_parent.get(previous.child_id) == parent_id:
                    del self._child_to_parent[previous.child_id]
            parent.child_ids = child_ids
            self._parents[parent_id] = parent
            self._parent_to_children[parent_id] = stored_children
            for child_id in child_ids:
                self._child_to_parent[child_id] = parent_id
            logger.bind(doc_id=parent.document_id).info(
                "stored parent chunk with {child_count} children",
                child_count=len(stored_children),
            )
            return True
        except Exception as e:
            # Deliberate best-effort: callers get a boolean, not an exception.
            logger.bind(doc_id=parent.document_id).warning(
                "failed to store chunks: {error}", error=str(e)
            )
            return False

    def get_parent(self, parent_id: str) -> ParentChunk | None:
        """Return the parent stored under ``parent_id``, or None."""
        return self._parents.get(parent_id)

    def get_children(self, parent_id: str) -> list[ChildChunk]:
        """Return the children of ``parent_id`` (empty list if unknown).

        A new list is returned on every call so that callers cannot alter
        the store's internal index by mutating the result.
        """
        return list(self._parent_to_children.get(parent_id, ()))

    def get_parent_by_child(self, child_id: str) -> ParentChunk | None:
        """Return the parent that owns ``child_id``, or None if unknown."""
        parent_id = self._child_to_parent.get(child_id)
        if parent_id is None:
            return None
        return self._parents.get(parent_id)

    def expand_to_parent_context(self, child_ids: list[str]) -> list[ParentChunk]:
        """Map child ids to their parents, de-duplicated.

        Parents appear in the order in which their first child occurs in
        ``child_ids``; ids with no known parent are skipped.
        """
        seen: set[str] = set()
        parents: list[ParentChunk] = []
        for child_id in child_ids:
            parent = self.get_parent_by_child(child_id)
            if parent is not None and parent.parent_id not in seen:
                seen.add(parent.parent_id)
                parents.append(parent)
        return parents

    def to_dict(self) -> dict[str, Any]:
        """Serialise the store to plain dicts.

        The child-to-parent lookup is not written out; ``from_dict``
        rebuilds it from the two sections emitted here.
        """
        return {
            "parents": {pid: p.to_dict() for pid, p in self._parents.items()},
            "parent_to_children": {
                pid: [c.to_dict() for c in children]
                for pid, children in self._parent_to_children.items()
            },
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ParentChildStore:
        """Rebuild a store from the structure produced by ``to_dict``.

        Missing or None ``parents`` / ``parent_to_children`` sections are
        treated as empty. Child lists whose key has no entry under
        ``parents`` are ignored.
        """
        store = cls()
        # Looked up once rather than on every iteration of the loop below.
        children_by_parent = data.get("parent_to_children") or {}
        for pid, parent_data in (data.get("parents") or {}).items():
            parent = ParentChunk.from_dict(parent_data)
            children = [
                ChildChunk.from_dict(c) for c in children_by_parent.get(pid, [])
            ]
            store.store_chunks(parent, children)
        return store


class ParentChildSplitStrategy:
    """Groups consecutive chunks into fixed-size windows, one parent each."""

    def create_parent_child_pairs(
        self, chunks: list[Any], window_size: int = 3
    ) -> tuple[list[ParentChunk], list[ChildChunk]]:
        """Split ``chunks`` into windows and build parent/child chunks.

        Every chunk becomes one child; each run of up to ``window_size``
        consecutive chunks becomes one parent whose text is the children's
        texts joined by newlines. Chunk attributes (``text``,
        ``document_id``, ``chapter``, ``level``, ``embedding``, ``metadata``,
        ``article_no``) are read with ``getattr`` and are all optional. The
        parent takes ``document_id`` and ``chapter`` from the first chunk of
        its window. Ids are fresh UUID4 strings.

        Args:
            chunks: Source chunks, in document order.
            window_size: Maximum number of chunks per parent; must be >= 1.

        Returns:
            ``(parents, children)``, both in input order.

        Raises:
            ValueError: If ``chunks`` is non-empty and ``window_size`` is
                less than 1.
        """
        if not chunks:
            return ([], [])
        # A negative step would yield an empty range and silently discard
        # every chunk; zero would fail with an opaque range() error.
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        parents: list[ParentChunk] = []
        children: list[ChildChunk] = []
        for window_start in range(0, len(chunks), window_size):
            window = chunks[window_start : window_start + window_size]
            parent_id = str(uuid.uuid4())
            first = window[0]
            document_id: str = getattr(first, "document_id", "") or ""
            chapter: str | None = getattr(first, "chapter", None)
            window_children = [
                self._build_child(chunk, parent_id, document_id) for chunk in window
            ]
            parent = ParentChunk(
                parent_id=parent_id,
                document_id=document_id,
                text="\n".join(child.text for child in window_children),
                article_range=self._article_range(window, window_start),
                child_ids=[child.child_id for child in window_children],
                chapter=chapter,
                metadata={},
            )
            parents.append(parent)
            children.extend(window_children)
        return (parents, children)

    @staticmethod
    def _build_child(chunk: Any, parent_id: str, fallback_document_id: str) -> ChildChunk:
        """Build the child chunk for one source chunk of a window."""
        # Only stringify the chunk when it has no ``text`` attribute; doing
        # it eagerly would compute a throwaway str() for every chunk.
        text: str = chunk.text if hasattr(chunk, "text") else str(chunk)
        return ChildChunk(
            child_id=str(uuid.uuid4()),
            parent_id=parent_id,
            # A missing or empty document id falls back to the window's id so
            # parent and child never disagree (e.g. None vs "").
            document_id=getattr(chunk, "document_id", None) or fallback_document_id,
            text=text,
            level=getattr(chunk, "level", "chunk"),
            embedding=getattr(chunk, "embedding", None),
            # Copy so the child does not share the source chunk's dict.
            metadata=dict(getattr(chunk, "metadata", {}) or {}),
        )

    @staticmethod
    def _article_range(window: list[Any], window_start: int) -> str:
        """Return the label describing which part of the input a window covers."""
        article_nos: list[str] = []
        for chunk in window:
            article_no = getattr(chunk, "article_no", None)
            if article_no:
                article_nos.append(str(article_no))
        if article_nos:
            # Article label: first to last article number found in the window.
            return f"第{article_nos[0]}-{article_nos[-1]}条"
        # No article numbers: fall back to a block label using 1-based
        # positions in the input list.
        start_idx = window_start + 1
        end_idx = window_start + len(window)
        return f"块{start_idx}-{end_idx}"
