from typing import Any

from sqlalchemy.orm import Session

from common_logging import get_logger

logger = get_logger(__name__)




class SlidingWindowRetriever:
    """Expand a retrieved chunk with the chunks stored next to it.

    Chunks are ``DocumentVector`` rows looked up by ``milvus_id``. Neighbours
    are reached by following each row's ``prev_chunk_id`` / ``next_chunk_id``
    link one hop at a time, issuing one query per hop.

    NOTE(review): ``knowledge_base_id`` and ``tenant_id`` are stored on the
    instance but no query in this class filters on them, so every lookup is
    by ``milvus_id`` alone. Confirm whether those ids are unique across
    tenants or whether the queries should be scoped.
    """

    def __init__(self, db: Session, knowledge_base_id: int, tenant_id: int):
        """Store the session and the knowledge-base / tenant identifiers."""
        self.db = db
        self.knowledge_base_id = knowledge_base_id
        self.tenant_id = tenant_id

    def _fetch_chunk(self, chunk_id: Any) -> Any:
        """Return the ``DocumentVector`` row with this ``milvus_id``, or ``None``."""
        # Imported at call time, as in the rest of this class -- presumably to
        # avoid an import cycle with the models package (TODO confirm).
        from app.models.knowledge_base import DocumentVector

        return (
            self.db.query(DocumentVector)
            .filter(DocumentVector.milvus_id == chunk_id)
            .first()
        )

    def _walk(self, start: Any, link_attr: str, hops: int) -> list[Any]:
        """Follow ``link_attr`` from ``start`` for at most ``hops`` steps.

        Returns the rows reached, nearest to ``start`` first. The walk stops
        early when a row has no (or an empty) link, or when the linked id
        does not resolve to a row. A non-positive ``hops`` yields ``[]``.
        """
        reached: list[Any] = []
        cursor = start
        for _ in range(hops):
            # getattr with a default also covers models lacking the link column.
            linked_id = getattr(cursor, link_attr, None)
            if not linked_id:
                break
            row = self._fetch_chunk(linked_id)
            if row is None:
                break
            reached.append(row)
            cursor = row
        return reached

    @staticmethod
    def _to_neighbor(row: Any, direction: str) -> dict[str, Any]:
        """Build the result dict describing ``row`` as a sliding-window neighbour."""
        return {
            "id": row.milvus_id,
            "text": row.chunk_text,
            "score": 0.0,
            "source": "sliding_window",
            "direction": direction,
        }

    def get_expanded_context(self, chunk_id: str, window_size: int = 2) -> str:
        """Return the target chunk's text together with its surrounding chunks.

        Args:
            chunk_id: ``milvus_id`` of the chunk to expand.
            window_size: Maximum number of chunks to include on each side.

        Returns:
            The texts of the preceding chunks, the target and the following
            chunks, in document order, joined by a blank line. Returns ``""``
            (and logs a warning) when the target chunk is not found.
        """
        target = self._fetch_chunk(chunk_id)
        if target is None:
            logger.warning(f"Chunk not found: {chunk_id}")
            return ""

        before = self._walk(target, "prev_chunk_id", window_size)
        after = self._walk(target, "next_chunk_id", window_size)

        # ``before`` is nearest-first; reverse it to restore document order.
        ordered = [*reversed(before), target, *after]
        # A row whose text is None would make str.join raise TypeError;
        # treat it as empty text instead.
        return "\n\n".join(chunk.chunk_text or "" for chunk in ordered)

    def get_neighbors(self, chunk_id: str, window_size: int = 1) -> list[dict[str, Any]]:
        """Return the chunks around the target as result dicts, target excluded.

        Args:
            chunk_id: ``milvus_id`` of the chunk whose neighbours are wanted.
            window_size: Maximum number of neighbours to return on each side.

        Returns:
            Dicts with keys ``id``, ``text``, ``score`` (always ``0.0``),
            ``source`` (always ``"sliding_window"``) and ``direction``
            (``"prev"`` or ``"next"``). Preceding chunks come first, in
            document order, followed by the following chunks. Returns ``[]``
            when the target chunk is not found.
        """
        target = self._fetch_chunk(chunk_id)
        if target is None:
            return []

        before = self._walk(target, "prev_chunk_id", window_size)
        after = self._walk(target, "next_chunk_id", window_size)

        neighbors = [self._to_neighbor(row, "prev") for row in reversed(before)]
        neighbors.extend(self._to_neighbor(row, "next") for row in after)
        return neighbors


def get_sliding_window_retriever(
    db: Session, knowledge_base_id: int, tenant_id: int
) -> SlidingWindowRetriever:
    """Build a ``SlidingWindowRetriever`` from the given session and identifiers."""
    retriever = SlidingWindowRetriever(
        db=db,
        knowledge_base_id=knowledge_base_id,
        tenant_id=tenant_id,
    )
    return retriever
