#!/usr/bin/env python3
"""
OpenAI-compatible reranking server using sentence-transformers CrossEncoder.
Endpoint: POST /v1/rerank
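
Example request (illustrative; the values are placeholders, and the Authorization
header is only checked when the API_KEY environment variable is set):

    curl -X POST http://localhost:8300/v1/rerank \
        -H "Authorization: Bearer $API_KEY" \
        -H "Content-Type: application/json" \
        -d '{"query": "refund policy", "documents": ["doc one", "doc two"], "top_n": 1}'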
"""
import os
from contextlib import asynccontextmanager

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from sentence_transformers import CrossEncoder

from common_logging import get_logger, setup_logging
from common_metrics import setup_metrics

setup_logging()
logger = get_logger(__name__)

MODEL_PATH = os.environ.get(
    "RERANK_MODEL_PATH",
    "/lsinfo/ai/hellotax_ai/llm_service/base_models/Qwen3-Reranker-8B"
)
DEVICE = os.environ.get("RERANK_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
SERVED_MODEL_NAME = os.environ.get("SERVED_MODEL_NAME", "Qwen3-Reranker-8B")
API_KEY = os.environ.get("API_KEY", "")
security = HTTPBearer(auto_error=False)


def verify_key(credentials: HTTPAuthorizationCredentials = Security(security)):
    if API_KEY and (not credentials or credentials.credentials != API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")

model: CrossEncoder | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global model
    logger.info(f"Loading reranker from {MODEL_PATH} on {DEVICE}")
    model = CrossEncoder(
        MODEL_PATH,
        device=DEVICE,
        trust_remote_code=True,
        max_length=1024,
        model_kwargs={"torch_dtype": torch.float16},
    )
    logger.info("Reranker loaded.")
    yield
    # Drop the model reference on shutdown so its (GPU) memory can be reclaimed.
    del model


app = FastAPI(lifespan=lifespan)
setup_metrics(app, "rerank-server")


class RerankRequest(BaseModel):
    query: str
    documents: list[str] | list[dict]
    model: str = SERVED_MODEL_NAME
    top_n: int = 5


@app.get("/health")
def health():
    return {"status": "ok"}


@app.post("/rerank")
@app.post("/v1/rerank")
def rerank(req: RerankRequest, _=Security(verify_key)):
    if not req.documents:
        raise HTTPException(status_code=400, detail="documents must not be empty")

    # Documents may be plain strings or dicts carrying the text under a "text" key.
    if isinstance(req.documents[0], dict):
        texts = [doc.get("text", "") for doc in req.documents]
    else:
        texts = req.documents

    pairs = [[req.query, text] for text in texts]
    # batch_size=1 scores one pair at a time, trading throughput for a small memory footprint.
    scores = model.predict(pairs, convert_to_numpy=True, batch_size=1).tolist()

    results = [{"index": i, "relevance_score": score} for i, score in enumerate(scores)]
    results.sort(key=lambda x: x["relevance_score"], reverse=True)

    return {
        "object": "list",
        "results": results[:req.top_n],
        "model": SERVED_MODEL_NAME,
        "usage": {"total_tokens": len(pairs)},
    }
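

# Illustrative client sketch (not part of the server): one way a Python caller could
# consume the endpoint above. The URL, key, and documents below are assumptions.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8300/v1/rerank",
#       headers={"Authorization": "Bearer <API_KEY>"},
#       json={"query": "refund policy", "documents": ["doc one", "doc two"], "top_n": 2},
#       timeout=30,
#   )
#   for hit in resp.json()["results"]:
#       print(hit["index"], hit["relevance_score"])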


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8300)
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)
