from sqlalchemy.orm import Session

from app.models import DocumentTag, KnowledgeDocument, KnowledgeTag
from common_logging import get_logger, log_execution

# Module-level logger used by the @log_execution decorator and the log calls below.
logger = get_logger(__name__)


@log_execution(logger)
def auto_tag_document_on_upload(
    db: Session,
    document_id: int,
    confidence_threshold: float = 0.5,
    max_tags: int = 5,
    min_tags: int = 3,
) -> list[int]:
    """Score predefined tags against a document's text and attach the best ones.

    Loads the document and every enabled, predefined tag that has a category.
    For each tag, case-insensitive substring occurrences of its keywords
    (falling back to the tag name) are counted in the title and content. The
    counts are turned into a confidence score. Tags scoring at or above the
    (possibly relaxed) threshold go to ``_select_balanced_tags``. Every
    selected tag not already linked to the document is stored as a
    ``DocumentTag``, and the session is committed.

    Args:
        db: SQLAlchemy session used for all queries; this function commits it.
        document_id: Primary key of the ``KnowledgeDocument`` to tag.
        confidence_threshold: Minimum confidence for a tag to be kept.
            Documents shorter than 500 characters use 60% of this value.
        max_tags: Upper bound on the number of tags selected.
        min_tags: Number of tags the selector tries to reach by adding
            next-best candidates.

    Returns:
        IDs of the selected tags, including ones that were already linked.
        Returns an empty list if the document does not exist, no predefined
        tags exist, or no tag qualifies.
    """
    logger.bind(document_id=document_id).info("Starting auto-tag")
    document = db.query(KnowledgeDocument).filter(KnowledgeDocument.id == document_id).first()
    if not document:
        logger.warning(f"Document {document_id} not found")
        return []
    tags = (
        db.query(KnowledgeTag)
        .filter(
            KnowledgeTag.is_predefined,
            KnowledgeTag.tag_category_id.isnot(None),
            KnowledgeTag.status == "enabled",
        )
        .all()
    )
    logger.info(f"Found {len(tags)} predefined tags for auto-tagging")
    if not tags:
        logger.warning("No predefined tags found")
        return []

    # Treat a NULL title/content as empty so .lower() and len() cannot raise.
    title = document.title or ""
    content = document.content or ""
    title_text = title.lower()
    content_text = content.lower()
    doc_text = f"{title_text} {content_text}"
    doc_length = len(title) + len(content)
    logger.info(f"Document length: {doc_length} chars, text preview: {doc_text[:200]}...")
    # Short documents produce fewer keyword hits, so relax the threshold for them.
    if doc_length < 500:
        effective_threshold = confidence_threshold * 0.6
        logger.info(f"Short document detected, lowering threshold to {effective_threshold:.2f}")
    else:
        effective_threshold = confidence_threshold

    tag_scores = []
    for tag in tags:
        raw_keywords = tag.keywords if tag.keywords else [tag.name]
        # Drop empty, whitespace-only and non-string entries. str.count("")
        # returns len(text) + 1, so an empty keyword would "match" every
        # document, and a None entry would crash on .lower().
        keywords = [kw for kw in raw_keywords if isinstance(kw, str) and kw.strip()]
        logger.debug(f"Tag '{tag.name}' (ID:{tag.id}) keywords: {keywords}")
        if not keywords:
            continue

        matched_keywords = []
        title_matches = 0
        content_matches = 0
        for keyword in keywords:
            kw_lower = keyword.lower()
            title_count = title_text.count(kw_lower)
            content_count = content_text.count(kw_lower)
            if title_count > 0 or content_count > 0:
                matched_keywords.append(keyword)
                title_matches += title_count
                content_matches += content_count
                logger.debug(
                    f"  Matched keyword '{keyword}' - title: {title_count}, content: {content_count}"
                )
        if not matched_keywords:
            continue

        # Fraction of the tag's keywords that occur anywhere in the document.
        base_confidence = min(len(matched_keywords) / len(keywords), 1.0)
        # Title hits count double; total occurrences add at most +0.3.
        weighted_matches = title_matches * 2.0 + content_matches
        frequency_boost = min(weighted_matches / 10, 0.3)
        # search_weight is capped at 2.0 and halved. The fallback weight of 1.0
        # therefore halves the confidence, and only a weight >= 2.0 leaves it
        # unscaled.
        # NOTE(review): a search_weight of 0 is falsy and is treated as 1.0 here — confirm intent.
        weight_factor = min(tag.search_weight or 1.0, 2.0) / 2.0
        confidence = min(base_confidence + frequency_boost, 1.0) * weight_factor
        logger.info(f"Tag '{tag.name}' matched with confidence {confidence:.2f}")
        if confidence >= effective_threshold:
            tag_scores.append(
                {"tag": tag, "confidence": confidence, "category_id": tag.tag_category_id}
            )

    if not tag_scores and doc_length < 200 and document.category_id:
        logger.info(f"No tags matched for short document {document_id}, using fallback tags")
        fallback_tags = (
            db.query(KnowledgeTag)
            .filter(KnowledgeTag.is_predefined, KnowledgeTag.status == "enabled")
            # LIMIT without ORDER BY returns arbitrary rows; order by id so the
            # fallback choice is deterministic.
            .order_by(KnowledgeTag.id)
            .limit(2)
            .all()
        )
        tag_scores.extend(
            {"tag": tag, "confidence": 0.4, "category_id": tag.tag_category_id}
            for tag in fallback_tags
        )
        logger.info(f"Added {len(fallback_tags)} fallback tags")
    if not tag_scores:
        logger.info(f"No tags matched for document {document_id}")
        return []

    selected_tags = _select_balanced_tags(tag_scores, min_tags, max_tags)
    if selected_tags:
        # Fetch all already-linked tag ids in one query instead of one per tag.
        existing_rows = (
            db.query(DocumentTag.tag_id)
            .filter(
                DocumentTag.document_id == document_id,
                DocumentTag.tag_id.in_(selected_tags),
            )
            .all()
        )
        existing_tag_ids = {row[0] for row in existing_rows}
        db.add_all(
            [
                DocumentTag(document_id=document_id, tag_id=tag_id)
                for tag_id in selected_tags
                if tag_id not in existing_tag_ids
            ]
        )
    db.commit()
    logger.info(f"Auto-tagged document {document_id} with {len(selected_tags)} tags")
    return selected_tags


def _select_balanced_tags(tag_scores: list[dict], min_tags: int, max_tags: int) -> list[int]:
    tag_scores.sort(key=lambda x: x["confidence"], reverse=True)
    by_category = {}
    for item in tag_scores:
        cat_id = item["category_id"]
        if cat_id not in by_category:
            by_category[cat_id] = []
        by_category[cat_id].append(item)
    selected = []
    for _cat_id, items in by_category.items():
        if items and len(selected) < max_tags:
            selected.append(items[0]["tag"].id)
    if len(selected) < min_tags:
        for item in tag_scores:
            if item["tag"].id not in selected and len(selected) < max_tags:
                selected.append(item["tag"].id)
            if len(selected) >= min_tags:
                break
    return selected[:max_tags]
