from sqlalchemy.orm import Session

from app.crud.base import CRUDBase
from app.models.knowledge_base import DocumentMetadataValue, KnowledgeMetadataField
from app.schemas.metadata import MetadataFieldCreate, MetadataFieldUpdate
from common_logging import get_logger

logger = get_logger(__name__)


class CRUDMetadataField(CRUDBase[KnowledgeMetadataField, MetadataFieldCreate, MetadataFieldUpdate]):
    """CRUD helpers for the metadata field definitions of a knowledge base."""

    def get_by_knowledge_base(
        self, db: Session, *, knowledge_base_id: int
    ) -> list[KnowledgeMetadataField]:
        """Return all fields of a knowledge base, ordered by ``sort_order``.

        Args:
            db: Session used to run the query.
            knowledge_base_id: ID of the knowledge base whose fields are listed.

        Returns:
            The matching fields in ascending ``sort_order``; empty if none exist.
        """
        return (
            db.query(self.model)
            .filter(KnowledgeMetadataField.knowledge_base_id == knowledge_base_id)
            .order_by(KnowledgeMetadataField.sort_order)
            .all()
        )

    def get_by_field_key(
        self, db: Session, *, knowledge_base_id: int, field_key: str
    ) -> KnowledgeMetadataField | None:
        """Return the first field with ``field_key`` in a knowledge base, or ``None``.

        Args:
            db: Session used to run the query.
            knowledge_base_id: ID of the knowledge base to search in.
            field_key: Key of the field to look up.
        """
        return (
            db.query(self.model)
            .filter(
                KnowledgeMetadataField.knowledge_base_id == knowledge_base_id,
                KnowledgeMetadataField.field_key == field_key,
            )
            .first()
        )

    def create_with_knowledge_base(
        self, db: Session, *, obj_in: MetadataFieldCreate, knowledge_base_id: int
    ) -> KnowledgeMetadataField:
        """Create a field in a knowledge base and append it to the end of the order.

        The new field gets ``sort_order`` equal to the current highest value in
        the knowledge base plus one, or ``0`` when no positioned field exists.

        Args:
            db: Session used to query, insert and commit.
            obj_in: Payload describing the field to create.
            knowledge_base_id: ID of the knowledge base that owns the new field.

        Returns:
            The persisted field, refreshed from the database after commit.
        """
        highest = (
            db.query(KnowledgeMetadataField.sort_order)
            .filter(
                KnowledgeMetadataField.knowledge_base_id == knowledge_base_id,
                # Ignore rows without a position. NULL ordering under DESC is
                # backend-specific, so a NULL could otherwise come back as the
                # "highest" row and make the ``+ 1`` below raise TypeError.
                # NOTE(review): whether sort_order is nullable is not visible
                # here; the filter is a no-op if it is not.
                KnowledgeMetadataField.sort_order.isnot(None),
            )
            .order_by(KnowledgeMetadataField.sort_order.desc())
            .first()
        )
        next_order = highest[0] + 1 if highest is not None else 0
        db_obj = KnowledgeMetadataField(
            knowledge_base_id=knowledge_base_id,
            field_name=obj_in.field_name,
            field_key=obj_in.field_key,
            field_type=obj_in.field_type,
            field_options=obj_in.field_options,
            default_value=obj_in.default_value,
            is_required=obj_in.is_required,
            description=obj_in.description,
            sort_order=next_order,
        )
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        logger.bind(metadata_field_id=db_obj.id).info("Metadata field created")
        return db_obj

    def batch_reorder(self, db: Session, *, reorder_data: list[dict[str, int]]) -> bool:
        """Apply new ``sort_order`` values to several fields in one commit.

        Args:
            db: Session used to issue the updates and commit.
            reorder_data: Items with a ``"field_id"`` and a ``"sort_order"`` key.

        Returns:
            ``True`` once the updates have been committed.

        Raises:
            KeyError: If an item lacks ``"field_id"`` or ``"sort_order"``; raised
                before any row is modified.
        """
        # Read every item up front so a malformed entry fails before any
        # UPDATE has been issued.
        updates = [(item["field_id"], item["sort_order"]) for item in reorder_data]
        try:
            for field_id, sort_order in updates:
                db.query(self.model).filter(self.model.id == field_id).update(
                    {"sort_order": sort_order}
                )
            db.commit()
        except Exception:
            # Discard the partially applied updates, then let the caller see
            # the original error.
            db.rollback()
            raise
        logger.info("Metadata fields reordered")
        return True


class CRUDDocumentMetadataValue(CRUDBase[DocumentMetadataValue, None, None]):
    """CRUD helpers for the metadata values stored per document and field."""

    def get_by_document(self, db: Session, *, document_id: int) -> list[DocumentMetadataValue]:
        """Return every metadata value recorded for ``document_id``."""
        return db.query(self.model).filter(DocumentMetadataValue.document_id == document_id).all()

    def get_by_document_and_field(
        self, db: Session, *, document_id: int, field_id: int
    ) -> DocumentMetadataValue | None:
        """Return the first value for a document/field pair, or ``None``.

        Args:
            db: Session used to run the query.
            document_id: ID of the document the value belongs to.
            field_id: ID of the metadata field the value is for.
        """
        return (
            db.query(self.model)
            .filter(
                DocumentMetadataValue.document_id == document_id,
                DocumentMetadataValue.field_id == field_id,
            )
            .first()
        )

    def set_value(
        self, db: Session, *, document_id: int, field_id: int, value: str | None
    ) -> DocumentMetadataValue:
        """Create or overwrite the value of one field on one document.

        Args:
            db: Session used to query, write and commit.
            document_id: ID of the document to update.
            field_id: ID of the metadata field to set.
            value: Value to store.

        Returns:
            The persisted row, refreshed from the database after commit.
        """
        db_obj = self.get_by_document_and_field(db, document_id=document_id, field_id=field_id)
        if db_obj:
            db_obj.value = value
        else:
            db_obj = DocumentMetadataValue(document_id=document_id, field_id=field_id, value=value)
            db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        logger.bind(document_id=document_id, field_id=field_id).info("Document metadata value set")
        return db_obj

    def delete_by_document_and_field(self, db: Session, *, document_id: int, field_id: int) -> bool:
        """Delete the value for a document/field pair if one exists.

        Returns:
            ``True`` if a row was found and deleted, ``False`` otherwise.
        """
        db_obj = self.get_by_document_and_field(db, document_id=document_id, field_id=field_id)
        if db_obj:
            db.delete(db_obj)
            db.commit()
            logger.bind(document_id=document_id, field_id=field_id).info("Document metadata value deleted")
            return True
        return False

    def delete_by_field(self, db: Session, *, field_id: int) -> int:
        """Delete every value stored for ``field_id`` across all documents.

        Returns:
            The row count reported by the bulk delete.
        """
        count = db.query(self.model).filter(DocumentMetadataValue.field_id == field_id).delete()
        db.commit()
        logger.bind(field_id=field_id).info("Document metadata values deleted by field")
        return count

    def batch_update_document_metadata(
        self, db: Session, *, document_id: int, metadata: dict[int, str | None]
    ) -> list[DocumentMetadataValue]:
        """Apply several field values to one document in a single transaction.

        A value of ``None`` or ``""`` removes the stored value for that field;
        any other value is created or overwritten. All changes are committed
        together, so a failure leaves the document's metadata as it was.

        Args:
            db: Session used to query, write and commit.
            document_id: ID of the document to update.
            metadata: Mapping of field ID to the new value for that field.

        Returns:
            The rows that were created or overwritten (removed fields are not
            included), refreshed from the database after commit.
        """
        results: list[DocumentMetadataValue] = []
        try:
            for field_id, value in metadata.items():
                existing = self.get_by_document_and_field(
                    db, document_id=document_id, field_id=field_id
                )
                if value is None or value == "":
                    # An empty value means "clear this field".
                    if existing is not None:
                        db.delete(existing)
                    continue
                if existing is None:
                    existing = DocumentMetadataValue(
                        document_id=document_id, field_id=field_id, value=value
                    )
                    db.add(existing)
                else:
                    existing.value = value
                results.append(existing)
            # One commit for the whole batch keeps the update all-or-nothing.
            db.commit()
        except Exception:
            db.rollback()
            raise
        for row in results:
            db.refresh(row)
        logger.bind(document_id=document_id).info("Document metadata batch updated")
        return results


# Module-level CRUD instances, each constructed with the model class it queries.
metadata_field = CRUDMetadataField(KnowledgeMetadataField)
document_metadata_value = CRUDDocumentMetadataValue(DocumentMetadataValue)
