from typing import Any

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models import DocumentTag, KnowledgeCategory, KnowledgeTag

from common_logging import get_logger

logger = get_logger(__name__)



def get_all_child_category_ids(
    db: Session, parent_id: int, max_depth: int = 10, _current_depth: int = 0, _visited: set = None
) -> list[int]:
    """Return ``parent_id`` followed by the ids of all its descendant categories.

    The hierarchy is walked depth-first (pre-order), issuing one query per
    expanded category to fetch its direct children.

    Args:
        db: Active SQLAlchemy session.
        parent_id: Id of the category whose subtree is collected.
        max_depth: Maximum number of levels to descend. A category reached at
            this depth is included in the result, but its children are not
            fetched and a warning is logged.
        _current_depth: Internal recursion depth; callers should not pass it.
        _visited: Internal set of already-expanded ids, used to break cycles in
            the ``parent_id`` graph; callers should not pass it.

    Returns:
        Category ids in depth-first pre-order, starting with ``parent_id``.
    """
    if _visited is None:
        _visited = set()
    # Detect cycles before applying the depth limit: otherwise a cycle that
    # closes exactly at the limit would re-emit an already-collected id (a
    # duplicate in the result) and be misreported as a depth overflow.
    if parent_id in _visited:
        logger.warning(f"Circular reference detected for category {parent_id}")
        return []
    if _current_depth >= max_depth:
        logger.warning(f"Max recursion depth {max_depth} reached for category {parent_id}")
        return [parent_id]
    _visited.add(parent_id)
    child_ids = [parent_id]
    children = db.query(KnowledgeCategory).filter(KnowledgeCategory.parent_id == parent_id).all()
    for child in children:
        child_ids.extend(
            get_all_child_category_ids(db, child.id, max_depth, _current_depth + 1, _visited)
        )
    return child_ids


def get_all_child_tag_ids(
    db: Session, parent_id: int, max_depth: int = 10, _current_depth: int = 0, _visited: set = None
) -> list[int]:
    """Return ``parent_id`` followed by the ids of all its descendant tags.

    The hierarchy is walked depth-first (pre-order), issuing one query per
    expanded tag to fetch its direct children.

    Args:
        db: Active SQLAlchemy session.
        parent_id: Id of the tag whose subtree is collected.
        max_depth: Maximum number of levels to descend. A tag reached at this
            depth is included in the result, but its children are not fetched
            and a warning is logged.
        _current_depth: Internal recursion depth; callers should not pass it.
        _visited: Internal set of already-expanded ids, used to break cycles in
            the ``parent_id`` graph; callers should not pass it.

    Returns:
        Tag ids in depth-first pre-order, starting with ``parent_id``.
    """
    if _visited is None:
        _visited = set()
    # Detect cycles before applying the depth limit: otherwise a cycle that
    # closes exactly at the limit would re-emit an already-collected id (a
    # duplicate in the result) and be misreported as a depth overflow.
    if parent_id in _visited:
        logger.warning(f"Circular reference detected for tag {parent_id}")
        return []
    if _current_depth >= max_depth:
        logger.warning(f"Max recursion depth {max_depth} reached for tag {parent_id}")
        return [parent_id]
    _visited.add(parent_id)
    child_ids = [parent_id]
    children = db.query(KnowledgeTag).filter(KnowledgeTag.parent_id == parent_id).all()
    for child in children:
        child_ids.extend(
            get_all_child_tag_ids(db, child.id, max_depth, _current_depth + 1, _visited)
        )
    return child_ids


def validate_tag_parent(db: Session, tag_id: int, parent_id: int) -> bool:
    """Check whether ``parent_id`` may become the parent of tag ``tag_id``.

    The assignment is rejected when it would create a cycle, i.e. when the
    proposed parent is the tag itself or one of its descendants. Instead of
    enumerating the subtree of ``tag_id`` (which the child-collection helper
    truncates at a maximum depth, letting deep descendants slip through),
    this walks the ancestor chain of ``parent_id`` upward: if ``tag_id`` is
    on that chain, ``parent_id`` is a descendant of ``tag_id``.

    Args:
        db: Active SQLAlchemy session.
        tag_id: The tag being (re-)parented.
        parent_id: The proposed parent, or ``None`` to make the tag a root.

    Returns:
        ``True`` if the assignment keeps the hierarchy acyclic, else ``False``.
    """
    if parent_id is None:
        return True
    seen = set()
    current = parent_id
    while current is not None:
        if current == tag_id:
            return False
        if current in seen:
            # Pre-existing cycle in the stored data that does not pass through
            # tag_id; stop instead of looping forever.
            logger.warning(f"Circular reference detected for tag {current}")
            return True
        seen.add(current)
        # None when the tag is a root or does not exist, which ends the walk.
        current = (
            db.query(KnowledgeTag.parent_id)
            .filter(KnowledgeTag.id == current)
            .scalar()
        )
    return True


def get_tag_document_count(db: Session, tag_id: int, include_children: bool = False) -> int:
    """Count documents associated with a tag.

    Args:
        db: Active SQLAlchemy session.
        tag_id: Id of the tag to count documents for.
        include_children: When true, count distinct documents linked to the
            tag or to any tag collected by ``get_all_child_tag_ids``; when
            false, count ``DocumentTag`` rows for this tag alone.

    Returns:
        The count, or 0 when the query yields no value.
    """
    if not include_children:
        direct_total = (
            db.query(func.count(DocumentTag.document_id))
            .filter(DocumentTag.tag_id == tag_id)
            .scalar()
        )
        return direct_total or 0

    subtree_ids = get_all_child_tag_ids(db, tag_id)
    # A document may carry several tags from the same subtree, so dedupe.
    subtree_total = (
        db.query(func.count(func.distinct(DocumentTag.document_id)))
        .filter(DocumentTag.tag_id.in_(subtree_ids))
        .scalar()
    )
    return subtree_total or 0


def build_tag_tree(db: Session, knowledge_base_id: int = None) -> list[dict[str, Any]]:
    """Build the nested tag hierarchy as a list of root-tag dictionaries.

    All tags are loaded in one query ordered by ``sort_order`` then ``id``.
    Each node dict carries the tag's fields, its direct document count and a
    ``children`` list filled in that same order.

    Args:
        db: Active SQLAlchemy session.
        knowledge_base_id: Accepted for API compatibility but currently not
            used to filter tags. NOTE(review): confirm whether callers expect
            this to scope the tree to a single knowledge base.

    Returns:
        The root tags (``parent_id is None``) with nested ``children``. Tags
        whose parent does not exist, or that sit on a parent cycle, are never
        reachable from a root and so do not appear in the result.
    """
    tags = db.query(KnowledgeTag).order_by(KnowledgeTag.sort_order, KnowledgeTag.id).all()
    if not tags:
        return []

    # One grouped aggregate instead of a COUNT query per tag (N+1). Uses the
    # same per-tag expression as get_tag_document_count(include_children=False).
    document_counts = {
        counted_tag_id: count
        for counted_tag_id, count in (
            db.query(DocumentTag.tag_id, func.count(DocumentTag.document_id))
            .group_by(DocumentTag.tag_id)
            .all()
        )
    }

    tag_dict = {}
    for tag in tags:
        tag_dict[tag.id] = {
            "id": tag.id,
            "name": tag.name,
            "parent_id": tag.parent_id,
            "description": tag.description,
            "icon": tag.icon,
            "color": tag.color,
            "keywords": tag.keywords,
            "sort_order": tag.sort_order if tag.sort_order is not None else 0,
            "status": tag.status,
            "document_count": document_counts.get(tag.id, 0) or 0,
            "created_at": tag.created_at,
            "updated_at": tag.updated_at,
            "children": [],
        }

    # Iterate in the query order so every children list stays sorted.
    root_tags = []
    for tag in tags:
        if tag.parent_id is None:
            root_tags.append(tag_dict[tag.id])
        elif tag.parent_id in tag_dict:
            tag_dict[tag.parent_id]["children"].append(tag_dict[tag.id])
    return root_tags
