
from fastapi import WebSocket
from common_logging import get_logger

logger = get_logger(__name__)


class WebSocketManager:
    """Registry of open WebSocket connections, grouped by task id.

    Connections are kept per ``task_id`` in the order they were registered,
    and a ``task_id`` entry is removed as soon as its last connection is
    unregistered.
    """

    def __init__(self):
        # task_id -> connections registered via connect(), in connect order.
        self.connections: dict[str, list[WebSocket]] = {}

    async def connect(self, task_id: str, websocket: WebSocket):
        """Accept ``websocket`` and register it under ``task_id``.

        The handshake is awaited before registration, so broadcast() never
        sends on a socket that has not been accepted yet.
        """
        await websocket.accept()
        self.connections.setdefault(task_id, []).append(websocket)

    async def disconnect(self, task_id: str, websocket: WebSocket):
        """Unregister ``websocket`` from ``task_id``.

        Idempotent: unknown task ids and sockets that are not (or no longer)
        registered are ignored instead of raising, so callers may safely call
        this from more than one cleanup path. The socket itself is not closed.
        The ``task_id`` entry is dropped once it has no connections left.
        """
        subscribers = self.connections.get(task_id)
        if subscribers is None:
            return
        try:
            subscribers.remove(websocket)
        except ValueError:
            # Already unregistered (e.g. a repeated disconnect); nothing to do.
            pass
        if not subscribers:
            del self.connections[task_id]

    async def broadcast(self, task_id: str, message: dict):
        """Send ``message`` as JSON to every connection registered under ``task_id``.

        Best-effort: a failed send is logged and the remaining connections are
        still attempted. An unknown ``task_id`` is a no-op.
        """
        # Iterate over a snapshot. Each send is an await point, and during it
        # connect()/disconnect() may mutate the live list, which would make
        # this loop skip sockets.
        for connection in list(self.connections.get(task_id, ())):
            try:
                await connection.send_json(message)
            except Exception:
                # NOTE(review): the failing socket is not unregistered here;
                # cleanup relies on the caller invoking disconnect().
                logger.warning("WebSocket send failed", task_id=task_id)


# Shared module-level instance.
manager = WebSocketManager()
