import uuid

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict
from sqlalchemy import func
from sqlalchemy.orm import Session

from app.celery_app import celery_app
from app.database import get_db
from app.models.tax_data import DataSource, TaxDocument
from common_logging import get_logger

# Module-level logger obtained from the project's common_logging helper,
# keyed by this module's import name.
logger = get_logger(__name__)

# Every route declared below is registered on this router, i.e. under the
# '/sources' URL prefix and the 'sources' tag.
router = APIRouter(prefix='/sources', tags=['sources'])

class SourceResponse(BaseModel):
    """Serialized view of a data source as returned by the /sources endpoints.

    Instances are built explicitly by ``list_sources`` and ``create_source``
    from ``DataSource`` rows plus aggregated ``TaxDocument`` counts.
    """

    # Pydantic v2 configuration. This replaces the nested ``class Config``,
    # which is deprecated in v2; the setting itself is unchanged.
    model_config = ConfigDict(from_attributes=True)

    id: int
    code: str
    name: str
    source_type: str
    region_level: str
    region_code: str | None
    has_categories: bool
    is_active: bool
    # ISO-8601 string produced via ``datetime.isoformat()`` by list_sources,
    # or None when the source has no recorded crawl time.
    last_crawled_at: str | None
    # Per-category breakdown; list_sources fills it with dicts carrying the
    # keys 'id', 'name' and 'count'.
    categories: list[dict] = []
    # Number of TaxDocument rows attached to this source.
    doc_count: int = 0

class SourceCreate(BaseModel):
    """Request body accepted by ``create_source``.

    The validated fields are dumped with ``model_dump()`` and passed as keyword
    arguments to the ``DataSource`` constructor, so the field names here are
    expected to correspond to ``DataSource`` constructor keywords.
    """
    # Identifier that create_source checks for duplicates before inserting.
    code: str
    name: str
    source_type: str
    region_level: str
    region_code: str | None = None
    # NOTE(review): presumably names the crawler adapter implementation to
    # load for this source -- confirm the expected format with the consumer.
    adapter_class: str
    adapter_config: dict | None = None
    has_categories: bool = False
    # NOTE(review): schedule format (cron expression?) is not visible here.
    crawl_schedule: str | None = None
    # NOTE(review): units (likely seconds) and the min <= max relationship are
    # not validated by this model -- confirm against the crawler.
    request_delay_min: float = 1.0
    request_delay_max: float = 5.0
    max_retries: int = 3

def _source_categories(db: Session, source: DataSource, doc_count: int) -> list[dict]:
    """Build the category breakdown for a single data source.

    Args:
        db: Open SQLAlchemy session used for the aggregate queries.
        source: The data source row to describe.
        doc_count: Total number of ``TaxDocument`` rows of ``source``,
            already computed by the caller.

    Returns:
        A list of ``{'id', 'name', 'count'}`` dicts; empty when the source
        does not have categories.
    """
    if not source.has_categories:
        return []

    if source.code == 'shenzhen':
        # One fixed category spanning every document of the source, so the
        # caller's total is reused instead of re-running the identical count.
        return [{'id': 'local_policy', 'name': '地方政策法规库', 'count': doc_count}]

    if source.code == 'gdtax':
        rows = (
            db.query(TaxDocument.doc_type, func.count(TaxDocument.id))
            .filter(TaxDocument.source_id == source.id)
            .group_by(TaxDocument.doc_type)
            .all()
        )
        type_counts = dict(rows)
        # Fixed set of document types, always reported (with 0 when absent).
        gdtax_types = (
            ('local_policy', '地方政策法规'),
            ('normative', '规范性文件'),
            ('interpretation', '政策解读'),
        )
        return [
            {'id': doc_type, 'name': label, 'count': type_counts.get(doc_type, 0)}
            for doc_type, label in gdtax_types
        ]

    # Kept as a function-level import, as in the original code, so the
    # processor module is only loaded when a category lookup is needed.
    from app.services.tax_data_processor.category_processor import CategoryProcessor
    processor = CategoryProcessor()

    if source.code == 'chinatax':
        # Counts are aggregated across every active source whose code starts
        # with 'chinatax', not only the row being rendered.
        chinatax_source_ids = [
            sid
            for sid, in db.query(DataSource.id)
            .filter(DataSource.code.like('chinatax%'), DataSource.is_active)
            .all()
        ]
        rows = (
            db.query(TaxDocument.category_id, func.count(TaxDocument.id))
            .filter(TaxDocument.source_id.in_(chinatax_source_ids))
            .group_by(TaxDocument.category_id)
            .all()
        )
        cat_counts = dict(rows)
        # Only category ids 1..8 are exposed for this source.
        return [
            {'id': cat['id'], 'name': cat['name'], 'count': cat_counts.get(cat['id'], 0)}
            for cat in processor.get_all_categories()
            if cat['id'] in range(1, 9)
        ]

    # Generic source: report whichever categories actually have documents,
    # skipping rows whose category id is empty (None / 0).
    rows = (
        db.query(TaxDocument.category_id, func.count(TaxDocument.id))
        .filter(TaxDocument.source_id == source.id)
        .group_by(TaxDocument.category_id)
        .all()
    )
    return [
        {'id': cat_id, 'name': processor.get_category_name(cat_id), 'count': cnt}
        for cat_id, cnt in rows
        if cat_id
    ]


@router.get('/', response_model=list[SourceResponse])
def list_sources(db: Session=Depends(get_db)):
    """List all active data sources with document and category counts.

    Args:
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        One ``SourceResponse`` per active ``DataSource``, each carrying its
        total document count and, when the source has categories, a
        per-category breakdown.
    """
    sources = db.query(DataSource).filter(DataSource.is_active).all()
    if not sources:
        return []

    # One grouped query for all totals instead of one COUNT per source.
    doc_counts = dict(
        db.query(TaxDocument.source_id, func.count(TaxDocument.id))
        .filter(TaxDocument.source_id.in_([s.id for s in sources]))
        .group_by(TaxDocument.source_id)
        .all()
    )

    result = []
    for s in sources:
        # Sources without any document are absent from the grouped result.
        doc_count = doc_counts.get(s.id, 0)
        result.append(
            SourceResponse(
                id=s.id,
                code=s.code,
                name=s.name,
                source_type=s.source_type,
                region_level=s.region_level,
                region_code=s.region_code,
                has_categories=s.has_categories,
                is_active=s.is_active,
                last_crawled_at=s.last_crawled_at.isoformat() if s.last_crawled_at else None,
                categories=_source_categories(db, s, doc_count),
                doc_count=doc_count,
            )
        )
    return result

@router.post('/', response_model=SourceResponse, status_code=201)
def create_source(data: SourceCreate, db: Session=Depends(get_db)):
    """Register a new data source.

    Args:
        data: Validated creation payload; its fields are passed as keyword
            arguments to ``DataSource``.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        The newly stored source, with no crawl timestamp and no categories.

    Raises:
        HTTPException: 400 when a source with the same ``code`` already exists.
    """
    duplicate = db.query(DataSource).filter(DataSource.code == data.code).first()
    if duplicate:
        raise HTTPException(400, detail='数据源 code 已存在')

    new_source = DataSource(**data.model_dump())
    db.add(new_source)
    db.commit()
    # Reload the row so database-generated values (e.g. the id) are populated.
    db.refresh(new_source)

    source_logger = logger.bind(source_id=new_source.id)
    source_logger.info(f"source created: {new_source.code}")

    payload = {
        'id': new_source.id,
        'code': new_source.code,
        'name': new_source.name,
        'source_type': new_source.source_type,
        'region_level': new_source.region_level,
        'region_code': new_source.region_code,
        'has_categories': new_source.has_categories,
        'is_active': new_source.is_active,
        'last_crawled_at': None,
        'categories': [],
    }
    return SourceResponse(**payload)

@router.post('/{source_id}/trigger')
def trigger_source(
    source_id: int,
    mode: str = 'full',
    start_page: int | None = None,
    end_page: int | None = None,
    db: Session = Depends(get_db),
):
    """Queue a background processing run for one data source.

    Args:
        source_id: Primary key of the ``DataSource`` to process (path param).
        mode: Processing mode forwarded unchanged to the task; defaults to
            ``'full'``.
        start_page: Optional first page forwarded to the task; ``None`` when
            the caller does not restrict the range.
        end_page: Optional last page forwarded to the task; ``None`` when
            the caller does not restrict the range.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        A dict echoing the generated ``task_id`` and the request parameters.

    Raises:
        HTTPException: 404 when no data source has the given id.
    """
    source = db.query(DataSource).filter(DataSource.id == source_id).first()
    if not source:
        raise HTTPException(404, detail='数据源不存在')

    # Application-level identifier handed to the task as its first positional
    # argument; it is not passed to Celery as the message's own task id.
    task_id = str(uuid.uuid4())
    celery_app.send_task(
        'app.tasks.processor_tasks.process_source_task',
        # NOTE(review): the meaning of the trailing positional False is defined
        # by the task signature, which is not visible here.
        args=[task_id, source_id, mode, False],
        kwargs={'start_page': start_page, 'end_page': end_page},
    )
    logger.bind(source_id=source_id).info(f"source triggered: mode={mode}, task_id={task_id}")
    return {
        'task_id': task_id,
        'source_id': source_id,
        'mode': mode,
        'start_page': start_page,
        'end_page': end_page,
    }
