import json
from typing import Any

from sqlalchemy import create_engine, text
from common_logging import get_logger, log_execution

logger = get_logger(__name__)


def _parse_content(raw: Any) -> dict[str, Any]:
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {}
    return {}

def load_sft_dataset(dataset_id: int, db_url: str) -> list[dict[str, str]]:
    """Load prompt/completion pairs for one dataset from the database.

    Selects the ``content`` column of every ``hub_global.dataset_samples`` row
    with a matching ``dataset_id``, ordered by ``id``, and decodes each value
    with ``_parse_content``. A row is kept only when it has a truthy
    ``prompt`` and a truthy ``completion`` (``response`` is used when
    ``completion`` is missing or falsy).

    Args:
        dataset_id: Value bound to the ``:dataset_id`` query parameter.
        db_url: Database URL handed to ``create_engine``.

    Returns:
        One ``{'prompt': ..., 'completion': ...}`` dict per kept row, in
        query order.
    """
    logger.bind(dataset_id=dataset_id).info("Loading SFT dataset")
    engine = create_engine(db_url)
    try:
        with engine.connect() as conn:
            result = conn.execute(text('\n                SELECT content\n                FROM hub_global.dataset_samples\n                WHERE dataset_id = :dataset_id\n                ORDER BY id\n            '), {'dataset_id': dataset_id})
            samples = []
            for row in result:
                content = _parse_content(row[0])
                prompt = content.get('prompt', '')
                completion = content.get('completion', '') or content.get('response', '')
                if prompt and completion:
                    samples.append({'prompt': prompt, 'completion': completion})
            return samples
    finally:
        # The engine is created per call; dispose it so its connection pool
        # does not keep database connections open after we return.
        engine.dispose()

def load_dpo_dataset(dataset_id: int, db_url: str) -> list[dict[str, Any]]:
    """Load prompt/chosen/rejected triples for one dataset from the database.

    Selects the ``content`` column of every ``hub_global.dataset_samples`` row
    with a matching ``dataset_id``, ordered by ``id``, and decodes each value
    with ``_parse_content``. A row is kept only when ``prompt``, ``chosen``
    and ``rejected`` are all truthy.

    Args:
        dataset_id: Value bound to the ``:dataset_id`` query parameter.
        db_url: Database URL handed to ``create_engine``.

    Returns:
        One ``{'prompt': ..., 'chosen': ..., 'rejected': ...}`` dict per kept
        row, in query order.
    """
    logger.bind(dataset_id=dataset_id).info("Loading DPO dataset")
    engine = create_engine(db_url)
    try:
        with engine.connect() as conn:
            result = conn.execute(text('\n                SELECT content\n                FROM hub_global.dataset_samples\n                WHERE dataset_id = :dataset_id\n                ORDER BY id\n            '), {'dataset_id': dataset_id})
            samples = []
            for row in result:
                content = _parse_content(row[0])
                prompt = content.get('prompt', '')
                chosen = content.get('chosen', '')
                rejected = content.get('rejected', '')
                if prompt and chosen and rejected:
                    samples.append({'prompt': prompt, 'chosen': chosen, 'rejected': rejected})
            return samples
    finally:
        # The engine is created per call; dispose it so its connection pool
        # does not keep database connections open after we return.
        engine.dispose()
