from datetime import datetime
from pathlib import Path

from fastapi import APIRouter, Depends, HTTPException
from common_logging import get_logger
from sqlalchemy.orm import Session

from app.api.deps import get_current_user
from app.config import settings
from app.db.session import get_db
from app.models.local_model import LocalModel
from app.models.provider import Model, ModelProvider
from app.models.user import User
from app.schemas.local_model import LocalModelCreate, LocalModelResponse, LocalModelUpdate

logger = get_logger(__name__)
router = APIRouter()


@router.get("", response_model=list[LocalModelResponse])
def get_local_models(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
    """List every local model that belongs to the caller's tenant."""
    same_tenant = LocalModel.tenant_id == current_user.tenant_id
    return db.query(LocalModel).filter(same_tenant).all()


@router.post("", response_model=LocalModelResponse)
def create_local_model(
    model_data: LocalModelCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a local model record owned by the caller's tenant.

    Every field of the validated request body is copied onto a new
    ``LocalModel``; the tenant is taken from the authenticated user.

    Returns:
        The persisted ``LocalModel``, reloaded after commit.
    """
    logger.info(f"Creating local model: name={model_data.name}, user={current_user.id}")
    payload = model_data.model_dump()
    new_model = LocalModel(**payload, tenant_id=current_user.tenant_id)
    db.add(new_model)
    db.commit()
    # Reload the row so DB-populated columns such as ``id`` are available.
    db.refresh(new_model)
    logger.info(f"Local model created: id={new_model.id}, identifier={new_model.identifier}")
    return new_model


@router.put("/{model_id}", response_model=LocalModelResponse)
def update_local_model(
    model_id: int,
    model_data: LocalModelUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Apply a partial update to one of the caller's tenant's local models.

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``), so omitted fields keep their stored values.

    Raises:
        HTTPException: 404 when no model with ``model_id`` exists in the tenant.
    """
    lookup = db.query(LocalModel).filter(
        LocalModel.id == model_id,
        LocalModel.tenant_id == current_user.tenant_id,
    )
    target = lookup.first()
    if not target:
        raise HTTPException(status_code=404, detail="Model not found")

    changes = model_data.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])

    db.commit()
    db.refresh(target)
    return target


@router.post("/{model_id}/publish")
def publish_local_model(
    model_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
):
    """Expose a tenant's local model as a chat ``Model`` under the local MLX provider.

    A new ``Model`` row is created under the first ``ModelProvider`` whose
    ``provider_kind`` is ``"local_mlx"``. Its id is stored on the local model's
    ``published_model_id`` in the same commit.

    Returns:
        ``{"success": True, "published_model_id": <new Model id>}``.

    Raises:
        HTTPException: 404 if the local model is not in the caller's tenant,
            400 if it is already published, 500 if no local MLX provider exists.
    """
    logger.info(f"Publishing local model: id={model_id}, user={current_user.id}")
    tenant_scope = (LocalModel.id == model_id, LocalModel.tenant_id == current_user.tenant_id)
    local_model = db.query(LocalModel).filter(*tenant_scope).first()
    if not local_model:
        logger.warning(f"Local model not found: id={model_id}")
        raise HTTPException(status_code=404, detail="Local model not found")
    if local_model.published_model_id:
        logger.warning(f"Model already published: id={model_id}")
        raise HTTPException(status_code=400, detail="Model already published")

    local_provider = db.query(ModelProvider).filter_by(provider_kind="local_mlx").first()
    if not local_provider:
        logger.error("Local MLX provider not found")
        raise HTTPException(status_code=500, detail="Local MLX provider not found")

    identifier = local_model.identifier
    published_model = Model(
        provider_id=local_provider.id,
        code=identifier,
        name=local_model.name,
        type="chat",
        enabled=True,
        remote_model_id=identifier,
        supports_stream=True,
        supports_tools=False,
        priority=50,
    )
    db.add(published_model)
    # Flush so the database assigns ``published_model.id`` before it is linked.
    db.flush()
    local_model.published_model_id = published_model.id
    db.commit()
    db.refresh(local_model)
    logger.info(f"Local model published: id={model_id}, published_model_id={published_model.id}")
    return {"success": True, "published_model_id": published_model.id}


@router.post("/{model_id}/unpublish")
def unpublish_local_model(
    model_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
):
    """Withdraw a published local model from the model catalog.

    Deletes the linked ``Model`` row (if it still exists) and clears the local
    model's ``published_model_id``.

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException: 404 if the local model is not in the caller's tenant,
            400 if it is not currently published.
    """
    tenant_scope = (LocalModel.id == model_id, LocalModel.tenant_id == current_user.tenant_id)
    local_model = db.query(LocalModel).filter(*tenant_scope).first()
    if not local_model:
        raise HTTPException(status_code=404, detail="Local model not found")

    linked_id = local_model.published_model_id
    if not linked_id:
        raise HTTPException(status_code=400, detail="Model not published")

    published_model = db.query(Model).filter(Model.id == linked_id).first()
    # Tolerate a catalog row that is already gone; the link is cleared either way.
    if published_model:
        db.delete(published_model)
    local_model.published_model_id = None
    db.commit()
    db.refresh(local_model)
    return {"success": True}


def _trained_model_display_name(base_model: str, job_name: str) -> str:
    """Return a human-readable name for a trained job directory.

    Job directories are parsed as ``job_YYYYmmdd_HHMMSS``. When that suffix
    parses, the result is ``"<base> LoRA YYYY-mm-dd HH:MM"``. Otherwise the
    import identifier ``"<base>__<job>"`` is returned unchanged.
    """
    try:
        started = datetime.strptime(job_name.removeprefix("job_"), "%Y%m%d_%H%M%S")
    except ValueError:
        return f"{base_model}__{job_name}"
    return f"{base_model} LoRA {started.strftime('%Y-%m-%d %H:%M')}"


def _has_adapter_artifacts(final_model_dir: Path) -> bool:
    """Return True if *final_model_dir* is a directory holding both adapter files."""
    return (
        final_model_dir.is_dir()
        and (final_model_dir / "adapters.safetensors").is_file()
        and (final_model_dir / "adapter_config.json").is_file()
    )


@router.post("/import-trained")
def import_trained_models(
    db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
):
    """Register LoRA adapters found under ``settings.TRAINED_MODELS_DIR`` as local models.

    Expected layout: ``<root>/<base_model>/jobs/job_*/final_model/`` containing
    ``adapters.safetensors`` and ``adapter_config.json``. Each complete job
    becomes a ``LocalModel`` with identifier ``"<base_model>__<job>"`` for the
    caller's tenant.

    The following jobs are counted as skipped:
    - jobs without a ``final_model`` directory or its two adapter files;
    - jobs whose identifier already exists for the tenant.

    Base-model directories without a ``jobs`` directory are ignored entirely.
    All new rows are committed together at the end.

    Returns:
        ``{"imported": int, "skipped": int}``, plus a ``"message"`` key when
        the trained-models directory does not exist.
    """
    logger.info(f"Importing trained models, user={current_user.id}")
    trained_root = Path(settings.TRAINED_MODELS_DIR)
    # is_dir() rather than exists(): a plain file at this path would otherwise
    # make iterdir() raise NotADirectoryError and surface as a 500.
    if not trained_root.is_dir():
        logger.warning(f"Trained models directory not found: {settings.TRAINED_MODELS_DIR}")
        return {"imported": 0, "skipped": 0, "message": "trained_models not found"}

    tenant_id = current_user.tenant_id
    # Load the tenant's existing identifiers once instead of issuing one
    # SELECT per job directory.
    known_identifiers = {
        identifier
        for (identifier,) in db.query(LocalModel.identifier)
        .filter(LocalModel.tenant_id == tenant_id)
        .all()
    }

    imported = 0
    skipped = 0
    for model_dir in sorted(path for path in trained_root.iterdir() if path.is_dir()):
        jobs_dir = model_dir / "jobs"
        if not jobs_dir.is_dir():
            continue
        for job_dir in sorted(jobs_dir.glob("job_*")):
            final_model_dir = job_dir / "final_model"
            if not _has_adapter_artifacts(final_model_dir):
                skipped += 1
                continue
            identifier = f"{model_dir.name}__{job_dir.name}"
            if identifier in known_identifiers:
                skipped += 1
                continue
            db.add(
                LocalModel(
                    name=_trained_model_display_name(model_dir.name, job_dir.name),
                    identifier=identifier,
                    model_type="lora",
                    base_model=model_dir.name,
                    model_path=str(final_model_dir),
                    status="active",
                    tenant_id=tenant_id,
                    asset_kind="adapter",
                    runtime_kind="mlx",
                    source_kind="trained",
                )
            )
            # Track in-batch additions too: distinct directory pairs can still
            # produce the same identifier (e.g. "a__job_x"/"job_y" vs "a"/"job_x__job_y").
            known_identifiers.add(identifier)
            imported += 1
    db.commit()
    logger.info(f"Trained models imported: imported={imported}, skipped={skipped}")
    return {"imported": imported, "skipped": skipped}


@router.delete("/{model_id}")
def delete_local_model(
    model_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
):
    """Delete one of the caller's tenant's local models.

    If the model is currently published, its catalog ``Model`` row is deleted
    too, the same way the unpublish endpoint removes it. Otherwise that row,
    created with ``enabled=True`` at publish time, would stay in the catalog
    while referring to a local model that no longer exists.

    Returns:
        ``{"message": "Model deleted successfully"}``.

    Raises:
        HTTPException: 404 when no model with ``model_id`` exists in the tenant.
    """
    model = (
        db.query(LocalModel)
        .filter(LocalModel.id == model_id, LocalModel.tenant_id == current_user.tenant_id)
        .first()
    )
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    published_model_id = model.published_model_id
    db.delete(model)
    if published_model_id:
        # Flush the local-model delete first so nothing still references the
        # catalog row when it is removed.
        db.flush()
        published_model = db.query(Model).filter(Model.id == published_model_id).first()
        if published_model:
            db.delete(published_model)
    db.commit()
    logger.info(f"Local model deleted: id={model_id}, published_model_id={published_model_id}")
    return {"message": "Model deleted successfully"}
