
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from app.api.deps import User, get_current_user
from app.core.database import get_db
from app.schemas import DatasetCreate, DatasetResponse, SampleCreate
from app.services.training_dataset_manager import get_training_dataset_manager

from common_logging import get_logger

# Module-level logger; the handlers below attach per-request context with
# logger.bind(...) before emitting messages.
logger = get_logger(__name__)

# Router that all dataset endpoints in this module register on; presumably
# included by the application under some URL prefix — confirm at the mount site.
router = APIRouter()

@router.post('/', response_model=DatasetResponse)
async def create_dataset(dataset: DatasetCreate, db: Session=Depends(get_db), current_user: User=Depends(get_current_user)):
    """Create a training dataset owned by the authenticated user.

    A missing/empty ``description`` on the request is normalised to ``''``
    before being passed to the dataset manager; the authenticated user's id
    is recorded as ``created_by``.

    Returns:
        Whatever the manager's ``create_dataset`` returns, validated against
        ``DatasetResponse`` by FastAPI. It is indexed with ``["id"]`` for
        logging, so it is expected to be a mapping.
    """
    dataset_manager = get_training_dataset_manager(db)
    created = dataset_manager.create_dataset(
        name=dataset.name,
        description=dataset.description or '',
        dataset_type=dataset.dataset_type,
        created_by=current_user.id,
        metadata=dataset.metadata,
    )
    logger.bind(dataset_id=created["id"]).info("Dataset created")
    return created

@router.get('/')
async def list_datasets(db: Session=Depends(get_db)):
    """Return every training dataset, serialised by the dataset manager.

    No pagination or filtering is applied: all ``TrainingDataset`` rows are
    loaded and converted with the manager's ``_dataset_to_dict``.
    """
    from app.models import TrainingDataset

    rows = db.query(TrainingDataset).all()
    serialise = get_training_dataset_manager(db)._dataset_to_dict
    return list(map(serialise, rows))

@router.get('/{dataset_id}')
async def get_dataset(dataset_id: int, db: Session=Depends(get_db)):
    """Fetch a single training dataset by primary key.

    Raises:
        HTTPException: 404 when no ``TrainingDataset`` has ``id == dataset_id``.
    """
    from app.models import TrainingDataset

    record = (
        db.query(TrainingDataset)
        .filter(TrainingDataset.id == dataset_id)
        .first()
    )
    if record:
        return get_training_dataset_manager(db)._dataset_to_dict(record)
    raise HTTPException(status_code=404, detail='Dataset not found')

@router.get('/{dataset_id}/versions')
async def get_versions(dataset_id: int, db: Session=Depends(get_db)):
    """List datasets whose ``parent_version_id`` equals ``dataset_id``.

    The parent itself is not looked up, so an unknown ``dataset_id`` simply
    yields an empty list rather than a 404.
    """
    from app.models import TrainingDataset

    children = (
        db.query(TrainingDataset)
        .filter(TrainingDataset.parent_version_id == dataset_id)
        .all()
    )
    serialise = get_training_dataset_manager(db)._dataset_to_dict
    return list(map(serialise, children))

@router.get('/{dataset_id}/samples')
async def get_samples(dataset_id: int, split: str | None=Query(None), db: Session=Depends(get_db)):
    """Return the samples belonging to a dataset, optionally for one split.

    Args:
        dataset_id: Dataset whose samples are returned.
        split: When truthy, only samples whose ``split`` column equals this
            value are returned; ``None`` or ``''`` disables the filter.
    """
    from app.models import DatasetSample

    criteria = [DatasetSample.dataset_id == dataset_id]
    if split:
        criteria.append(DatasetSample.split == split)
    rows = db.query(DatasetSample).filter(*criteria).all()

    serialise = get_training_dataset_manager(db)._sample_to_dict
    return list(map(serialise, rows))

@router.post('/{dataset_id}/samples')
async def add_sample(dataset_id: int, sample: SampleCreate, db: Session=Depends(get_db), current_user: User=Depends(get_current_user)):
    """Append one sample to a dataset via the dataset manager.

    Requires an authenticated user (``get_current_user``).

    Raises:
        HTTPException: 404 carrying the message of any ``ValueError`` the
            manager raises (exception chaining is suppressed).
    """
    manager = get_training_dataset_manager(db)
    try:
        fields = {
            'content': sample.content,
            'label': sample.label,
            'split': sample.split,
            'source_task_id': sample.source_task_id,
            'metadata': sample.metadata,
        }
        return manager.add_sample(dataset_id=dataset_id, **fields)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from None

@router.post('/{dataset_id}/split')
async def split_dataset(dataset_id: int, train_ratio: float=0.7, validation_ratio: float=0.15, test_ratio: float=0.15, db: Session=Depends(get_db)):
    """Ask the dataset manager to partition a dataset into train/validation/test.

    The three ratios are forwarded positionally; this endpoint does not
    validate them itself.

    Raises:
        HTTPException: 400 carrying the message of any ``ValueError`` the
            manager raises.
    """
    manager = get_training_dataset_manager(db)
    ratios = (train_ratio, validation_ratio, test_ratio)
    try:
        return manager.split_dataset(dataset_id, *ratios)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from None

@router.get('/{dataset_id}/quality')
async def get_quality_metrics(dataset_id: int, db: Session=Depends(get_db)):
    """Return the manager-computed quality metrics for one dataset.

    Raises:
        HTTPException: 404 carrying the message of any ``ValueError`` the
            manager raises.
    """
    metrics_source = get_training_dataset_manager(db)
    try:
        return metrics_source.get_quality_metrics(dataset_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from None

@router.delete('/{dataset_id}')
async def delete_dataset(dataset_id: int, db: Session=Depends(get_db), current_user: User=Depends(get_current_user)):
    """Delete a dataset together with all of its samples.

    Samples are removed with a bulk query delete before the dataset row is
    deleted, and both are committed in a single ``db.commit()``.

    NOTE(review): datasets that reference this one through
    ``parent_version_id`` are not modified here; whether that is allowed
    depends on the schema's FK constraints — confirm.

    Raises:
        HTTPException: 404 when the dataset does not exist.
    """
    from app.models import DatasetSample, TrainingDataset

    target = (
        db.query(TrainingDataset)
        .filter(TrainingDataset.id == dataset_id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail='Dataset not found')

    owned_samples = db.query(DatasetSample).filter(DatasetSample.dataset_id == dataset_id)
    owned_samples.delete()
    db.delete(target)
    db.commit()

    logger.bind(dataset_id=dataset_id).info("Dataset deleted")
    return {'message': 'Dataset deleted'}

@router.get('/{dataset_id}/export')
async def export_dataset(dataset_id: int, format: str=Query('json'), split: str | None=Query(None), db: Session=Depends(get_db), current_user: User=Depends(get_current_user)):
    """Export a dataset (optionally a single split) as a downloadable file.

    Args:
        dataset_id: Dataset to export.
        format: Export format forwarded to the manager. ``json``, ``jsonl``
            and ``csv`` get a matching media type; anything else is served
            as ``text/plain``.
        split: Optional split name forwarded to the manager.

    Returns:
        A ``StreamingResponse`` with a ``Content-Disposition: attachment``
        header naming the file ``dataset_<id>.<ext>``.

    Raises:
        HTTPException: 404 carrying the message of any ``ValueError`` the
            manager raises.
    """
    manager = get_training_dataset_manager(db)
    try:
        content = manager.export_dataset(dataset_id, format, split)
    except ValueError as e:
        # Sibling endpoints translate manager ValueErrors into 4xx instead of
        # letting them surface as a 500. NOTE(review): 404 assumes the usual
        # cause is an unknown dataset (as in get_quality_metrics); if the
        # manager also raises for unsupported formats, 400 may fit better.
        raise HTTPException(status_code=404, detail=str(e)) from None

    media_types = {'json': 'application/json', 'jsonl': 'application/x-ndjson', 'csv': 'text/csv'}

    # `format` is caller-supplied and echoed into a response header: keep only
    # ASCII alphanumerics so characters such as ';', '"', whitespace or
    # non-latin-1 text cannot corrupt the header. Known formats are unchanged.
    extension = ''.join(ch for ch in format if ch.isascii() and ch.isalnum()) or 'txt'
    disposition = f'attachment; filename="dataset_{dataset_id}.{extension}"'

    return StreamingResponse(
        iter([content]),
        media_type=media_types.get(format, 'text/plain'),
        headers={'Content-Disposition': disposition},
    )
