from __future__ import annotations

import csv
import os
import re
import socket
import subprocess
import time
from datetime import UTC, datetime
from io import StringIO
from pathlib import Path
from threading import Lock
from typing import Any

import httpx

from app.config import LLM_SERVICE_DIR
from common_logging import get_logger

logger = get_logger(__name__)
# Fallback weight directories for embedding / reranker servers. When neither the
# command line nor the environment names a model path, the last path component
# of these is used as the model name (see _infer_model_name).
DEFAULT_EMBEDDING_MODEL_PATH = LLM_SERVICE_DIR / "base_models" / "bge-m3"
DEFAULT_RERANK_MODEL_PATH = LLM_SERVICE_DIR / "base_models" / "bge-reranker-v2-m3"
# Substrings of a process command line that mark it as a process to monitor.
# Matching is case-insensitive (see _looks_like_model_process).
MODEL_PROCESS_HINTS = (
    # Each of these maps to a specific runtime label in _infer_runtime.
    "vllm.entrypoints.openai.api_server",
    "embedding_server.py",
    "rerank_server.py",
    "mlx_lm.server",
    "llamafactory",
    "torchrun",
    "deepspeed",
    "accelerate launch",
    # Generic model flags; a process matched only by these gets the "Runtime" label.
    "--model_name_or_path",
    "--served-model-name",
)


def _utc_now_iso() -> str:
    return datetime.now(UTC).isoformat()


class GpuRuntimeMonitoringService:

    def __init__(self, cache_ttl_seconds: float = 3.0):
        self._cache_ttl_seconds = cache_ttl_seconds
        self._cache_lock = Lock()
        self._cached_snapshot: dict[str, Any] | None = None
        self._cached_at = 0.0

    def get_snapshot(self) -> dict[str, Any]:
        now = time.monotonic()
        with self._cache_lock:
            if self._cached_snapshot and now - self._cached_at < self._cache_ttl_seconds:
                return self._cached_snapshot
            snapshot = self._build_snapshot()
            self._cached_snapshot = snapshot
            self._cached_at = now
            return snapshot

    def _build_snapshot(self) -> dict[str, Any]:
        gpu_rows = self._collect_gpu_rows()
        gpu_by_uuid = {row["uuid"]: row for row in gpu_rows}
        processes = self._collect_runtime_processes()
        for compute_row in self._collect_compute_process_rows():
            pid = compute_row["pid"]
            target_pid = self._find_model_parent_process(pid, processes)
            process = processes.get(target_pid)
            if process is None:
                cmdline = self._read_process_cmdline(target_pid) or compute_row["process_name"]
                process = self._describe_process(target_pid, cmdline)
                processes[target_pid] = process
            process["gpu_uuids"].add(compute_row["gpu_uuid"])
            process["gpu_memory_mb"] += compute_row["used_gpu_memory_mb"]
        models = self._build_models(processes, gpu_by_uuid)
        devices = self._build_devices(gpu_rows, processes)
        overview = self._build_overview(devices, models)
        return {"overview": overview, "devices": devices, "models": models}

    def _build_overview(
        self, devices: list[dict[str, Any]], models: list[dict[str, Any]]
    ) -> dict[str, Any]:
        alert_count = sum(1 for device in devices if device["status"] in {"warning", "offline"})
        alert_count += sum(1 for model in models if model["status"] == "failed")
        running_model_count = sum(1 for model in models if model["status"] != "failed")
        return {
            "totalGpuCount": len(devices),
            "onlineGpuCount": sum(1 for device in devices if device["status"] != "offline"),
            "runningModelCount": running_model_count,
            "alertCount": alert_count,
            "updatedAt": _utc_now_iso(),
        }

    def _build_devices(
        self, gpu_rows: list[dict[str, Any]], processes: dict[int, dict[str, Any]]
    ) -> list[dict[str, Any]]:
        running_models_by_gpu: dict[str, set[str]] = {}
        for process in processes.values():
            model_name = process["model_name"]
            if model_name.startswith("GPU Process "):
                continue
            for gpu_uuid in process["gpu_uuids"]:
                running_models_by_gpu.setdefault(gpu_uuid, set()).add(model_name)
        host_name = socket.gethostname()
        devices: list[dict[str, Any]] = []
        for gpu_row in gpu_rows:
            running_models = sorted(running_models_by_gpu.get(gpu_row["uuid"], set()))
            memory_total_mb = gpu_row["memory_total_mb"]
            memory_used_mb = gpu_row["memory_used_mb"]
            memory_usage_percent = 0
            if memory_total_mb > 0:
                memory_usage_percent = round(memory_used_mb / memory_total_mb * 100)
            status = self._derive_gpu_status(
                memory_usage_percent=memory_usage_percent,
                utilization_percent=gpu_row["utilization_percent"],
                running_model_names=running_models,
                memory_used_mb=memory_used_mb,
            )
            devices.append(
                {
                    "id": f"gpu-{gpu_row['index']}",
                    "name": f"{gpu_row['name']} #{gpu_row['index']}",
                    "hostName": host_name,
                    "status": status,
                    "memoryUsedGb": round(memory_used_mb / 1024, 1),
                    "memoryTotalGb": round(memory_total_mb / 1024, 1),
                    "memoryUsagePercent": memory_usage_percent,
                    "utilizationPercent": gpu_row["utilization_percent"],
                    "runningModelNames": running_models,
                }
            )
        return devices

    def _build_models(
        self, processes: dict[int, dict[str, Any]], gpu_by_uuid: dict[str, dict[str, Any]]
    ) -> list[dict[str, Any]]:
        models: list[dict[str, Any]] = []
        for process in processes.values():
            model_name = process["model_name"]
            if model_name.startswith("GPU Process "):
                continue
            gpu_rows = [
                gpu_by_uuid[gpu_uuid]
                for gpu_uuid in sorted(process["gpu_uuids"])
                if gpu_uuid in gpu_by_uuid
            ]
            gpu_ids = ", ".join(str(row["index"]) for row in gpu_rows) or "N/A"
            gpu_names = (
                ", ".join(f"{row['name']} #{row['index']}" for row in gpu_rows) or "Unassigned"
            )
            models.append(
                {
                    "id": f"pid-{process['pid']}",
                    "name": model_name,
                    "version": self._build_version_label(process),
                    "status": self._derive_model_status(process),
                    "gpuId": gpu_ids,
                    "gpuName": gpu_names,
                    "memoryUsedGb": round(process["gpu_memory_mb"] / 1024, 1),
                    "startedAt": process["started_at"],
                }
            )
        status_order = {"running": 0, "loading": 1, "idle": 2, "failed": 3}
        models.sort(
            key=lambda item: (
                status_order.get(item["status"], 99),
                -item["memoryUsedGb"],
                item["name"],
            )
        )
        return models

    def _collect_gpu_rows(self) -> list[dict[str, Any]]:
        output = self._run_command(
            [
                "nvidia-smi",
                "--query-gpu=index,name,uuid,memory.used,memory.total,utilization.gpu",
                "--format=csv,noheader,nounits",
            ]
        )
        if not output:
            return []
        rows: list[dict[str, Any]] = []
        for record in self._parse_csv_output(output):
            if len(record) < 6:
                continue
            rows.append(
                {
                    "index": self._safe_int(record[0]),
                    "name": record[1].strip(),
                    "uuid": record[2].strip(),
                    "memory_used_mb": self._safe_int(record[3]),
                    "memory_total_mb": self._safe_int(record[4]),
                    "utilization_percent": self._safe_int(record[5]),
                }
            )
        return rows

    def _collect_compute_process_rows(self) -> list[dict[str, Any]]:
        output = self._run_command(
            [
                "nvidia-smi",
                "--query-compute-apps=gpu_uuid,pid,process_name,used_gpu_memory",
                "--format=csv,noheader,nounits",
            ]
        )
        if not output or "No running processes found" in output:
            return []
        rows: list[dict[str, Any]] = []
        for record in self._parse_csv_output(output):
            if len(record) < 4:
                continue
            rows.append(
                {
                    "gpu_uuid": record[0].strip(),
                    "pid": self._safe_int(record[1]),
                    "process_name": record[2].strip(),
                    "used_gpu_memory_mb": self._safe_int(record[3]),
                }
            )
        return [row for row in rows if row["pid"] > 0]

    def _collect_runtime_processes(self) -> dict[int, dict[str, Any]]:
        output = self._run_command(["ps", "-eo", "pid=,args="])
        if not output:
            return {}
        processes: dict[int, dict[str, Any]] = {}
        for line in output.splitlines():
            line = line.strip()
            if not line:
                continue
            parts = line.split(None, 1)
            if len(parts) != 2:
                continue
            pid = self._safe_int(parts[0])
            args = parts[1].strip()
            if pid <= 0 or not self._looks_like_model_process(args):
                continue
            processes[pid] = self._describe_process(pid, args)
        return processes

    def _describe_process(self, pid: int, args: str) -> dict[str, Any]:
        env_map = self._read_process_env(pid)
        runtime = self._infer_runtime(args)
        port = self._extract_flag_value(args, "--port")
        started_at = self._get_process_started_at(pid)
        return {
            "pid": pid,
            "args": args,
            "runtime": runtime,
            "port": self._safe_int(port) if port else None,
            "healthy": self._check_service_health(self._safe_int(port) if port else None, runtime),
            "started_at": started_at,
            "gpu_memory_mb": 0,
            "gpu_uuids": set(),
            "model_name": self._infer_model_name(pid, args, env_map, runtime),
        }

    def _looks_like_model_process(self, args: str) -> bool:
        normalized = args.lower()
        return any(hint.lower() in normalized for hint in MODEL_PROCESS_HINTS)

    def _infer_runtime(self, args: str) -> str:
        normalized = args.lower()
        if "vllm.entrypoints.openai.api_server" in normalized:
            return "vLLM"
        if "embedding_server.py" in normalized:
            return "Embedding"
        if "rerank_server.py" in normalized:
            return "Reranker"
        if "mlx_lm.server" in normalized:
            return "MLX"
        if "llamafactory" in normalized:
            return "LLaMA-Factory"
        if (
            "torchrun" in normalized
            or "deepspeed" in normalized
            or "accelerate launch" in normalized
        ):
            return "Training"
        return "Runtime"

    def _infer_model_name(self, pid: int, args: str, env_map: dict[str, str], runtime: str) -> str:
        served_model_name = self._extract_flag_value(args, "--served-model-name") or env_map.get(
            "SERVED_MODEL_NAME"
        )
        if served_model_name:
            return served_model_name
        model_path = (
            self._extract_flag_value(args, "--model")
            or self._extract_flag_value(args, "--model_name_or_path")
            or env_map.get("MODEL_PATH")
            or env_map.get("MODEL_NAME_OR_PATH")
            or env_map.get("EMBEDDING_MODEL_PATH")
            or env_map.get("RERANK_MODEL_PATH")
        )
        if model_path:
            return Path(model_path).name
        if runtime == "Embedding":
            return Path(env_map.get("EMBEDDING_MODEL_PATH", str(DEFAULT_EMBEDDING_MODEL_PATH))).name
        if runtime == "Reranker":
            return Path(env_map.get("RERANK_MODEL_PATH", str(DEFAULT_RERANK_MODEL_PATH))).name
        path_match = re.search("(/[^\\s]+(?:base_models|trained_models)[^\\s]*)", args)
        if path_match:
            return Path(path_match.group(1)).name
        return f"GPU Process {pid}"

    def _build_version_label(self, process: dict[str, Any]) -> str:
        runtime = process["runtime"]
        port = process["port"]
        if port:
            return f"{runtime} · :{port}"
        return f"{runtime} · PID {process['pid']}"

    def _derive_gpu_status(
        self,
        memory_usage_percent: int,
        utilization_percent: int,
        running_model_names: list[str],
        memory_used_mb: int,
    ) -> str:
        if memory_usage_percent >= 90 or utilization_percent >= 95:
            return "warning"
        if running_model_names or utilization_percent > 0 or memory_used_mb > 0:
            return "busy"
        return "online"

    def _derive_model_status(self, process: dict[str, Any]) -> str:
        healthy = process["healthy"]
        has_gpu_memory = process["gpu_memory_mb"] > 0 or bool(process["gpu_uuids"])
        if healthy is True:
            return "running"
        if healthy is False and process["port"]:
            return "loading" if self._started_recently(process["started_at"]) else "failed"
        if has_gpu_memory:
            return "running"
        return "idle"

    def _started_recently(self, started_at: str, threshold_seconds: int = 300) -> bool:
        try:
            started_dt = datetime.fromisoformat(started_at)
        except ValueError:
            return False
        if started_dt.tzinfo is None:
            started_dt = started_dt.replace(tzinfo=UTC)
        return (datetime.now(UTC) - started_dt).total_seconds() < threshold_seconds

    def _check_service_health(self, port: int | None, runtime: str) -> bool | None:
        if not port:
            return None
        urls = [f"http://127.0.0.1:{port}/health"]
        if runtime in {"vLLM", "MLX"}:
            urls.append(f"http://127.0.0.1:{port}/v1/models")
        for url in urls:
            try:
                response = httpx.get(url, timeout=0.8)
                if response.status_code < 500:
                    return response.is_success
            except httpx.HTTPError:
                continue
        return False

    def _extract_flag_value(self, args: str, flag: str) -> str | None:
        pattern = re.compile(f"""{re.escape(flag)}(?:=|\\s+)(\\"[^\\"]+\\"|'[^']+'|\\S+)""")
        match = pattern.search(args)
        if not match:
            return None
        return match.group(1).strip("'\"")

    def _get_process_started_at(self, pid: int) -> str:
        output = self._run_command(["ps", "-p", str(pid), "-o", "lstart="])
        if not output:
            return _utc_now_iso()
        raw_value = output.strip()
        try:
            started_at = datetime.strptime(raw_value, "%a %b %d %H:%M:%S %Y")
            return started_at.astimezone(UTC).isoformat()
        except ValueError:
            return raw_value

    def _read_process_cmdline(self, pid: int) -> str | None:
        cmdline_path = Path(f"/proc/{pid}/cmdline")
        try:
            content = cmdline_path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            return None
        return " ".join(part for part in content.split("\x00") if part).strip() or None

    def _find_model_parent_process(
        self, pid: int, known_processes: dict[int, dict[str, Any]]
    ) -> int:
        visited = set()
        current = pid
        while current > 1 and current not in visited:
            if current in known_processes:
                return current
            visited.add(current)
            try:
                stat_content = Path(f"/proc/{current}/stat").read_text(
                    encoding="utf-8", errors="ignore"
                )
                parts = stat_content.split(")")
                if len(parts) >= 2:
                    fields = parts[1].strip().split()
                    if len(fields) >= 2:
                        current = self._safe_int(fields[1])
                        continue
            except OSError:
                pass
            break
        return pid

    def _read_process_env(self, pid: int) -> dict[str, str]:
        env_path = Path(f"/proc/{pid}/environ")
        try:
            raw = env_path.read_bytes()
        except OSError:
            return {}
        env_map: dict[str, str] = {}
        for item in raw.split(b"\x00"):
            if not item or b"=" not in item:
                continue
            key, value = item.split(b"=", 1)
            env_map[key.decode("utf-8", errors="ignore")] = value.decode("utf-8", errors="ignore")
        return env_map

    def _parse_csv_output(self, output: str) -> list[list[str]]:
        return [row for row in csv.reader(StringIO(output)) if row]

    def _safe_int(self, value: Any) -> int:
        try:
            return int(str(value).strip())
        except (TypeError, ValueError):
            return 0

    def _run_command(self, command: list[str], timeout: float = 3.0) -> str | None:
        try:
            result = subprocess.run(
                command, check=False, capture_output=True, text=True, timeout=timeout
            )
        except (OSError, subprocess.SubprocessError) as exc:
            logger.warning("GPU runtime command failed: %s", exc)
            return None
        if result.returncode != 0:
            stderr = result.stderr.strip()
            if stderr:
                logger.info("GPU runtime command returned %s: %s", result.returncode, stderr)
            return None
        return result.stdout.strip()


_gpu_runtime_service: GpuRuntimeMonitoringService | None = None


def get_gpu_runtime_service() -> GpuRuntimeMonitoringService:
    """Return the module-wide monitoring service, creating it on first call."""
    global _gpu_runtime_service
    service = _gpu_runtime_service
    if service is None:
        service = GpuRuntimeMonitoringService()
        _gpu_runtime_service = service
    return service
