#!/usr/bin/env python3
"""
OpenAI-compatible embedding server using sentence-transformers.
Endpoint: POST /v1/embeddings
"""
import base64
import os
import secrets
from contextlib import asynccontextmanager

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

from common_logging import get_logger, setup_logging
from common_metrics import setup_metrics

setup_logging()
logger = get_logger(__name__)

# Filesystem path (or HF repo id) of the sentence-transformers model to serve.
MODEL_PATH = os.environ.get(
    "EMBEDDING_MODEL_PATH",
    "/lsinfo/ai/hellotax_ai/llm_service/base_models/Qwen3-Embedding-8B"
)
# Inference device; defaults to CUDA when available, otherwise CPU.
DEVICE = os.environ.get("EMBEDDING_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
# Model name echoed back in API responses (OpenAI "model" field).
SERVED_MODEL_NAME = os.environ.get("SERVED_MODEL_NAME", "Qwen3-Embedding-8B")
# Bearer token for request auth; empty string disables authentication entirely.
API_KEY = os.environ.get("API_KEY", "")
# auto_error=False so missing credentials reach verify_key instead of a generic 403.
security = HTTPBearer(auto_error=False)


def verify_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> None:
    """FastAPI security dependency: validate the Bearer token against API_KEY.

    Authentication is disabled entirely when API_KEY is empty/unset.

    Raises:
        HTTPException: 401 when a key is configured and the request's
            credentials are missing or do not match.
    """
    if not API_KEY:
        return  # no key configured -> auth disabled
    # compare_digest is constant-time, avoiding a timing side channel that a
    # plain `!=` comparison would leak on partial key matches.
    if credentials is None or not secrets.compare_digest(credentials.credentials, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")

model: SentenceTransformer = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load the embedding model at startup, release it at shutdown.

    Populates the module-level ``model`` global that the request handlers read.
    """
    global model
    logger.info(f"Loading model from {MODEL_PATH} on {DEVICE}")
    # fp16 halves memory footprint; assumes the model/device support half
    # precision -- TODO confirm this is intended when DEVICE falls back to CPU.
    model = SentenceTransformer(
        MODEL_PATH,
        device=DEVICE,
        trust_remote_code=True,
        model_kwargs={"torch_dtype": torch.float16}
    )
    logger.info("Model loaded.")
    yield  # application serves requests while suspended here
    # Shutdown: drop the reference so the model weights can be garbage-collected.
    del model


app = FastAPI(lifespan=lifespan)
# Attach Prometheus-style metrics under the service name "embedding-server".
setup_metrics(app, "embedding-server")


class EmbeddingRequest(BaseModel):
    """Request body for POST /v1/embeddings (OpenAI-compatible schema)."""
    input: str | list[str]  # single text or batch of texts to embed
    model: str = SERVED_MODEL_NAME  # accepted for API compatibility; not used for routing
    encoding_format: str = "float"  # requested embedding encoding (OpenAI field)
    input_type: str = "document"  # prompt selector; normalized via PROMPT_NAME_ALIASES


# Map common client-side input_type spellings onto the prompt name the model
# expects. Unrecognized values pass through unchanged.
PROMPT_NAME_ALIASES = {
    "passage": "document",
    "doc": "document",
}


@app.get("/health")
def health():
    """Liveness probe: report the service as up."""
    return dict(status="ok")


@app.post("/v1/embeddings")
def create_embeddings(req: EmbeddingRequest, _=Security(verify_key)):
    """OpenAI-compatible embeddings endpoint.

    Embeds one or more input texts and returns them in the OpenAI response
    shape. Honors ``encoding_format``: "float" returns lists of floats,
    "base64" returns base64-encoded little-endian float32 bytes (as the
    OpenAI API does); anything else is rejected.

    Raises:
        HTTPException: 503 before the lifespan handler has loaded the model;
            400 for an unsupported encoding_format.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    if req.encoding_format not in ("float", "base64"):
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported encoding_format: {req.encoding_format}",
        )

    texts = [req.input] if isinstance(req.input, str) else req.input
    prompt_name = PROMPT_NAME_ALIASES.get(req.input_type, req.input_type)
    embeddings = model.encode(texts, prompt_name=prompt_name, normalize_embeddings=True, convert_to_numpy=True)

    data = []
    for i, emb in enumerate(embeddings):
        if req.encoding_format == "base64":
            # "<f4" forces little-endian float32 regardless of host byte order,
            # matching the OpenAI base64 embedding wire format.
            payload = base64.b64encode(emb.astype("<f4").tobytes()).decode("ascii")
        else:
            payload = emb.tolist()
        data.append({"object": "embedding", "index": i, "embedding": payload})

    # Whitespace word count: a cheap approximation, not real tokenizer counts.
    total_tokens = sum(len(t.split()) for t in texts)
    return {
        "object": "list",
        "data": data,
        "model": SERVED_MODEL_NAME,
        "usage": {"prompt_tokens": total_tokens, "total_tokens": total_tokens},
    }


if __name__ == "__main__":
    import argparse

    # Minimal CLI: bind address and port for the uvicorn server.
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8200)
    opts = cli.parse_args()
    uvicorn.run(app, host=opts.host, port=opts.port)
