from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.core.websocket_manager import manager

from common_logging import get_logger

logger = get_logger(__name__)

router = APIRouter()

@router.websocket('/ws/training/{task_id}')
async def websocket_endpoint(websocket: WebSocket, task_id: str):
    """Hold a WebSocket open for a training task until the connection ends.

    The socket is registered with the shared ``manager`` under ``task_id``
    and then kept alive by a receive loop. Anything the client sends is read
    and discarded; this handler itself never writes to the socket.

    Args:
        websocket: The incoming WebSocket connection.
        task_id: Path parameter naming the training task this socket is
            registered under.

    Raises:
        Exception: Any error other than ``WebSocketDisconnect`` raised while
            receiving is propagated after the socket has been deregistered.
    """
    # Bind the task id once so both log lines carry the same context.
    log = logger.bind(task_id=task_id)

    # NOTE(review): presumably manager.connect() also accepts the handshake;
    # confirm against app.core.websocket_manager.
    await manager.connect(task_id, websocket)
    log.info("WebSocket connected")
    try:
        # Inbound payloads are ignored; awaiting them is what lets us notice
        # that the client has gone away.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        log.info("WebSocket disconnected")
    finally:
        # Deregister on every exit path (client disconnect, task
        # cancellation, unexpected receive error) so the manager is never
        # left holding a stale socket for this task.
        await manager.disconnect(task_id, websocket)
