import csv
import io
import json
from enum import Enum
from typing import Any

from sqlalchemy.orm import Session
from common_logging import get_logger, log_execution

logger = get_logger(__name__)

class DatasetSplit(str, Enum):
    """Which partition of a dataset a sample belongs to.

    Inherits from ``str`` so members compare equal to their raw string
    values and serialize cleanly (e.g. to JSON or a string DB column).
    """
    TRAIN = 'train'
    VALIDATION = 'validation'
    TEST = 'test'

class ExportFormat(str, Enum):
    """Supported serialization formats for dataset export.

    Inherits from ``str`` so members can be compared against plain strings
    supplied by callers (e.g. query parameters).
    """
    JSON = 'json'
    CSV = 'csv'
    JSONL = 'jsonl'

class TrainingDatasetManager:
    """Create, version, populate, split, inspect and export training datasets.

    All mutating methods commit on success and roll back on failure so the
    shared session is never left in a broken transaction state.
    """

    def __init__(self, db: Session):
        # Session used for every query and mutation; owned by the caller.
        self.db = db

    @log_execution(logger)
    def create_dataset(self, name: str, description: str, dataset_type: str, created_by: int, metadata: dict | None=None) -> dict[str, Any]:
        """Create a new dataset (version 1, zero samples) and return it as a dict.

        Args:
            name: Human-readable dataset name.
            description: Free-text description.
            dataset_type: Project-defined type tag for the dataset.
            created_by: ID of the creating user.
            metadata: Optional extra metadata stored on the row.
        """
        from app.models import TrainingDataset
        try:
            dataset = TrainingDataset(
                name=name,
                description=description,
                dataset_type=dataset_type,
                version=1,
                created_by=created_by,
                total_samples=0,
                train_samples=0,
                validation_samples=0,
                test_samples=0,
                # Fix: the model attribute is `meta_data` (see create_version and
                # _dataset_to_dict); `metadata` is reserved by SQLAlchemy's
                # declarative base, so the previous keyword was a defect.
                meta_data=metadata or {},
            )
            self.db.add(dataset)
            self.db.commit()
            self.db.refresh(dataset)
        except Exception:
            # Keep the session usable after a failed flush/commit
            # (consistent with create_version).
            self.db.rollback()
            raise
        logger.info(f'Created dataset {dataset.id}: {name}')
        return self._dataset_to_dict(dataset)

    def create_version(self, dataset_id: int, description: str, created_by: int) -> dict[str, Any]:
        """Create a new, empty version derived from an existing dataset.

        The parent row is locked (SELECT ... FOR UPDATE) so concurrent calls
        cannot allocate the same version number.

        Raises:
            ValueError: If the parent dataset does not exist.
        """
        from app.models import TrainingDataset
        try:
            parent = self.db.query(TrainingDataset).filter(TrainingDataset.id == dataset_id).with_for_update().first()
            if not parent:
                raise ValueError(f'Dataset {dataset_id} not found')
            dataset = TrainingDataset(
                name=parent.name,
                description=description,
                dataset_type=parent.dataset_type,
                version=parent.version + 1,
                parent_version_id=dataset_id,
                created_by=created_by,
                total_samples=0,
                train_samples=0,
                validation_samples=0,
                test_samples=0,
                # Copy so later edits to the new version don't mutate the
                # parent's metadata dict in place.
                meta_data=parent.meta_data.copy(),
            )
            self.db.add(dataset)
            self.db.commit()
            self.db.refresh(dataset)
            logger.info(f'Created dataset version {dataset.version} for dataset {parent.name}')
            return self._dataset_to_dict(dataset)
        except Exception:
            self.db.rollback()
            raise

    def add_sample(self, dataset_id: int, content: str, label: str, split: DatasetSplit, source_task_id: int | None=None, metadata: dict | None=None) -> dict[str, Any]:
        """Add one labelled sample to a dataset and bump its split counters.

        Raises:
            ValueError: If the dataset does not exist.
        """
        from app.models import DatasetSample, TrainingDataset
        dataset = self.db.query(TrainingDataset).filter(TrainingDataset.id == dataset_id).first()
        if not dataset:
            raise ValueError(f'Dataset {dataset_id} not found')
        try:
            sample = DatasetSample(
                dataset_id=dataset_id,
                content=content,
                label=label,
                split=split,
                source_task_id=source_task_id,
                # Fix: the model attribute is `meta_data` (see _sample_to_dict),
                # not `metadata`.
                meta_data=metadata or {},
            )
            self.db.add(sample)
            # Denormalized per-split counters kept on the dataset row.
            dataset.total_samples += 1
            if split == DatasetSplit.TRAIN:
                dataset.train_samples += 1
            elif split == DatasetSplit.VALIDATION:
                dataset.validation_samples += 1
            elif split == DatasetSplit.TEST:
                dataset.test_samples += 1
            self.db.commit()
            self.db.refresh(sample)
        except Exception:
            # Consistent rollback-on-failure, as in create_version.
            self.db.rollback()
            raise
        logger.info(f'Added sample to dataset {dataset_id}, split={split}')
        return self._sample_to_dict(sample)

    @log_execution(logger)
    def split_dataset(self, dataset_id: int, train_ratio: float=0.7, validation_ratio: float=0.15, test_ratio: float=0.15, seed: int | None=None) -> dict[str, int]:
        """Randomly reassign every sample to train/validation/test partitions.

        Args:
            dataset_id: Dataset whose samples are re-split.
            train_ratio: Fraction of samples assigned to TRAIN.
            validation_ratio: Fraction assigned to VALIDATION.
            test_ratio: Fraction assigned to TEST (remainder absorbs rounding).
            seed: Optional seed for a reproducible shuffle; the default None
                keeps the previous non-deterministic behaviour.

        Returns:
            Mapping of split name to number of samples assigned.

        Raises:
            ValueError: If a ratio is negative, the ratios do not sum to 1.0,
                or the dataset does not exist.
        """
        import random

        from app.models import DatasetSample, TrainingDataset
        if min(train_ratio, validation_ratio, test_ratio) < 0:
            raise ValueError('Split ratios must be non-negative')
        if abs(train_ratio + validation_ratio + test_ratio - 1.0) > 0.01:
            raise ValueError('Split ratios must sum to 1.0')
        dataset = self.db.query(TrainingDataset).filter(TrainingDataset.id == dataset_id).first()
        if not dataset:
            raise ValueError(f'Dataset {dataset_id} not found')
        try:
            samples = self.db.query(DatasetSample).filter(DatasetSample.dataset_id == dataset_id).all()
            # Private Random instance: a caller-supplied seed must not reseed
            # the process-global RNG.
            random.Random(seed).shuffle(samples)
            total = len(samples)
            train_end = int(total * train_ratio)
            val_end = train_end + int(total * validation_ratio)
            for i, sample in enumerate(samples):
                if i < train_end:
                    sample.split = DatasetSplit.TRAIN
                elif i < val_end:
                    sample.split = DatasetSplit.VALIDATION
                else:
                    sample.split = DatasetSplit.TEST
            # Fix: the denormalized counters were previously left stale after a
            # re-split; keep them in sync with the new assignment.
            dataset.total_samples = total
            dataset.train_samples = train_end
            dataset.validation_samples = val_end - train_end
            dataset.test_samples = total - val_end
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        logger.info(f'Split dataset {dataset_id}: train={train_end}, val={val_end - train_end}, test={total - val_end}')
        return {'train': train_end, 'validation': val_end - train_end, 'test': total - val_end}

    def get_quality_metrics(self, dataset_id: int) -> dict[str, Any]:
        """Return the stored split counts plus the label distribution.

        Raises:
            ValueError: If the dataset does not exist.
        """
        from app.models import DatasetSample, TrainingDataset
        dataset = self.db.query(TrainingDataset).filter(TrainingDataset.id == dataset_id).first()
        if not dataset:
            raise ValueError(f'Dataset {dataset_id} not found')
        samples = self.db.query(DatasetSample).filter(DatasetSample.dataset_id == dataset_id).all()
        label_dist: dict[str, int] = {}
        for sample in samples:
            label_dist[sample.label] = label_dist.get(sample.label, 0) + 1
        return {
            'total_samples': dataset.total_samples,
            'train_samples': dataset.train_samples,
            'validation_samples': dataset.validation_samples,
            'test_samples': dataset.test_samples,
            'label_distribution': label_dist,
            'num_labels': len(label_dist),
        }

    def export_dataset(self, dataset_id: int, format: ExportFormat, split: DatasetSplit | None=None) -> str:
        """Serialize a dataset's samples (optionally one split) to a string.

        Args:
            dataset_id: Dataset to export.
            format: One of ExportFormat.JSON / JSONL / CSV.
            split: If given, export only samples from this split.

        Raises:
            ValueError: If *format* is not a supported export format.
        """
        from app.models import DatasetSample
        query = self.db.query(DatasetSample).filter(DatasetSample.dataset_id == dataset_id)
        if split:
            query = query.filter(DatasetSample.split == split)
        samples = query.all()
        if format == ExportFormat.JSON:
            return json.dumps([self._sample_to_dict(s) for s in samples], indent=2)
        if format == ExportFormat.JSONL:
            return '\n'.join(json.dumps(self._sample_to_dict(s)) for s in samples)
        if format == ExportFormat.CSV:
            output = io.StringIO()
            writer = csv.DictWriter(output, fieldnames=['id', 'content', 'label', 'split'])
            writer.writeheader()
            for s in samples:
                writer.writerow({'id': s.id, 'content': s.content, 'label': s.label, 'split': s.split})
            return output.getvalue()
        # Fix: fail loudly instead of implicitly returning None for an
        # unrecognized format.
        raise ValueError(f'Unsupported export format: {format}')

    def _dataset_to_dict(self, dataset) -> dict[str, Any]:
        """Convert a TrainingDataset row to a JSON-serializable dict."""
        return {
            'id': dataset.id,
            'name': dataset.name,
            'description': dataset.description,
            'dataset_type': dataset.dataset_type,
            'version': dataset.version,
            'parent_version_id': dataset.parent_version_id,
            'total_samples': dataset.total_samples,
            'train_samples': dataset.train_samples,
            'validation_samples': dataset.validation_samples,
            'test_samples': dataset.test_samples,
            'created_by': dataset.created_by,
            'created_at': dataset.created_at.isoformat() if dataset.created_at else None,
            'updated_at': dataset.updated_at.isoformat() if dataset.updated_at else None,
            'meta_data': dataset.meta_data,
        }

    def _sample_to_dict(self, sample) -> dict[str, Any]:
        """Convert a DatasetSample row to a JSON-serializable dict."""
        return {
            'id': sample.id,
            'dataset_id': sample.dataset_id,
            'content': sample.content,
            'label': sample.label,
            'split': sample.split,
            'source_task_id': sample.source_task_id,
            'created_at': sample.created_at.isoformat() if sample.created_at else None,
            'meta_data': sample.meta_data,
        }

def get_training_dataset_manager(db: Session) -> TrainingDatasetManager:
    """Return a TrainingDatasetManager bound to the given session."""
    manager = TrainingDatasetManager(db)
    return manager
