from datetime import date
from pathlib import Path

from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel
from sqlalchemy import func, or_
from sqlalchemy.orm import Session

from app.config import settings
from app.database import get_db
from app.models.tax_data import TaxDocument, TaxDocumentVersion
from app.services.tax_data_processor.category_processor import CategoryProcessor
from app.services.tax_data_processor.knowledge_base_client import KnowledgeBaseClient
from app.services.tax_data_processor.relation_builder import (

    can_apply_legal_relation,
    is_regulation_like,
)
from app.services.tax_data_processor.relation_identity import is_same_legal_reference
from app.tasks.processor_tasks import (
    import_to_knowledge_base_task,
    reprocess_from_local_task,
    strip_ocr_content_task,
)

from common_logging import get_logger

# Module-level logger obtained from the project's common_logging helper.
logger = get_logger(__name__)

# Router collecting every tax-document endpoint defined in this module.
# NOTE(review): the URL prefix is presumably applied where this router is
# included by the application — confirm there.
router = APIRouter()

def _build_timeline(db: Session, doc_id: int, max_nodes: int=50) -> list:
    """Collect the supersession "family" of a regulation-like document.

    Starting at ``doc_id``, walk breadth-first through the ``supersedes``
    relation in both directions:

    * ancestors  - entries of ``current.supersedes`` shaped like
      ``{"doc_id": <id>, ...}``, followed when
      ``can_apply_legal_relation(current, ancestor)`` is true;
    * successors - documents whose own ``supersedes`` JSON array contains
      ``{"doc_id": current.id}``, followed when
      ``can_apply_legal_relation(source, current)`` is true.

    Only documents for which ``is_regulation_like`` is true join the family,
    and the walk stops once ``max_nodes`` documents have been collected.

    Args:
        db: SQLAlchemy session.
        doc_id: ID of the document whose timeline is requested.
        max_nodes: Upper bound on the number of family members collected.

    Returns:
        Summary dicts sorted by ``issue_date`` newest first (undated
        documents last), with ``is_current`` flagging ``doc_id``. Empty when
        the root is missing / not regulation-like, or when the family
        consists of the root alone.
    """
    import json
    from collections import deque

    from sqlalchemy import text as sa_text

    # Per-call cache: each row is fetched at most once even when it is seen
    # first as a neighbour candidate and later popped from the queue.
    loaded: dict = {}

    def _load(target_id):
        if target_id not in loaded:
            loaded[target_id] = db.query(TaxDocument).filter(TaxDocument.id == target_id).first()
        return loaded[target_id]

    # Bound parameter instead of f-string interpolation: ids can originate in
    # stored `supersedes` JSON, so they are not guaranteed to be plain ints.
    # CAST (not `::jsonb`) avoids clashing with the `:name` bind syntax.
    successors_sql = sa_text(
        "SELECT id FROM tax_documents WHERE supersedes @> CAST(:pattern AS jsonb)"
    )

    root = _load(doc_id)
    if not root or not is_regulation_like(root):
        return []

    family: dict = {}
    queue = deque([doc_id])
    while queue and len(family) < max_nodes:
        current_id = queue.popleft()
        if current_id in family:
            continue
        current = _load(current_id)
        if not current or not is_regulation_like(current):
            continue
        family[current_id] = current

        # Ancestors: documents this one declares it supersedes.
        for item in current.supersedes or []:
            anc_id = item.get('doc_id') if isinstance(item, dict) else None
            if not anc_id or anc_id in family:
                continue
            ancestor = _load(anc_id)
            if ancestor and can_apply_legal_relation(current, ancestor):
                queue.append(anc_id)

        # Successors: documents whose `supersedes` array references this one.
        pattern = json.dumps([{'doc_id': current.id}])
        for (source_id,) in db.execute(successors_sql, {'pattern': pattern}).fetchall():
            if source_id in family:
                continue
            source = _load(source_id)
            if source and can_apply_legal_relation(source, current):
                queue.append(source_id)

    if len(family) <= 1:
        return []
    # date.min pushes undated documents to the end of the newest-first order.
    sorted_docs = sorted(family.values(), key=lambda d: d.issue_date or date.min, reverse=True)
    return [
        {
            'id': d.id,
            'title': d.title,
            'doc_number': d.doc_number,
            'doc_status': d.doc_status,
            'issue_date': d.issue_date.isoformat() if d.issue_date else None,
            'is_current': d.id == doc_id,
        }
        for d in sorted_docs
    ]

class BatchImportRequest(BaseModel):
    """Request body for ``POST /batch-import``: the documents to queue for import."""

    doc_ids: list[int]  # primary keys of the TaxDocument rows to import

def _resolve_references(db: Session, refs: list) -> list:
    """Link reference entries to local documents and drop obsolete ones.

    For each reference dict:

    * with a truthy ``doc_id``: look up that document's status and drop the
      reference when it is ``'obsolete'``;
    * otherwise, with a ``source_url``: find a document with that URL; on a
      hit, return a copy carrying the matched ``doc_id`` (the input dict is
      not mutated), or drop it when that document is ``'obsolete'``;
    * anything else is kept unchanged.

    References whose target cannot be found are kept as-is.
    """
    resolved = []
    for ref in refs:
        ref_doc_id = ref.get('doc_id')
        if ref_doc_id:
            status = db.query(TaxDocument.doc_status).filter(TaxDocument.id == ref_doc_id).scalar()
            if status == 'obsolete':
                continue
        elif ref.get('source_url'):
            match = (
                db.query(TaxDocument.id, TaxDocument.doc_status)
                .filter(TaxDocument.source_url == ref['source_url'])
                .first()
            )
            if match:
                ref = {**ref, 'doc_id': match.id}
                if match.doc_status == 'obsolete':
                    continue
        resolved.append(ref)
    return resolved

def _has_images(doc: TaxDocument) -> bool:
    """Return True when the document's Markdown body embeds an image.

    Q&A documents (``doc_type == 'qa'``) never report images. Detection is a
    plain substring test for the Markdown image opener ``![``; a missing body
    counts as empty.
    """
    return doc.doc_type != 'qa' and '![' in (doc.content_markdown or '')

def _doc_summary(doc: TaxDocument) -> dict:
    """Build the compact, JSON-serialisable record used by list and search results.

    Date/datetime columns are rendered with ``isoformat()``; missing values
    become None. ``has_attachments`` / ``has_images`` are derived flags.
    """

    def iso(value):
        return value.isoformat() if value else None

    return {
        'id': doc.id,
        'source_url': doc.source_url,
        'category_id': doc.category_id,
        'doc_number': doc.doc_number,
        'title': doc.title,
        'issuing_authority': doc.issuing_authority,
        'issue_date': iso(doc.issue_date),
        'effective_date': iso(doc.effective_date),
        'processing_status': doc.processing_status,
        'doc_status': doc.doc_status,
        'is_imported': doc.is_imported,
        'has_attachments': bool(doc.attachments),
        'has_images': _has_images(doc),
        'interpretation_form': doc.interpretation_form,
        'doc_type': doc.doc_type,
        'created_at': iso(doc.created_at),
        'updated_at': iso(doc.updated_at),
    }

@router.get('/categories')
async def get_categories(db: Session=Depends(get_db)):
    """List every configured category with the number of stored documents in it.

    Counts come from one grouped query rather than one COUNT per category;
    categories with no documents report 0.

    Returns:
        ``{'categories': [{'id', 'name', 'count', 'type'}, ...]}`` in the
        order produced by ``CategoryProcessor.get_all_categories()``.
    """
    category_processor = CategoryProcessor()
    rows = (
        db.query(TaxDocument.category_id, func.count(TaxDocument.id))
        .group_by(TaxDocument.category_id)
        .all()
    )
    # Keys are compared as strings so a category id from the processor
    # matches whether it is an int or a numeric string (the previous
    # per-category SQL comparison tolerated both).
    counts = {str(category_id): n for category_id, n in rows}
    categories = [
        {
            'id': c['id'],
            'name': c['name'],
            'count': counts.get(str(c['id']), 0),
            'type': c['type'],
        }
        for c in category_processor.get_all_categories()
    ]
    return {'categories': categories}

@router.get('')
async def list_documents(
    category_id: int | None=None,
    is_imported: bool | None=None,
    doc_number: str | None=Query(None, description='文号精确或模糊查询'),
    issue_date_from: date | None=Query(None, description='成文日期起始（含）'),
    issue_date_to: date | None=Query(None, description='成文日期截止（含）'),
    issuing_authority: str | None=Query(None, description='发文机关关键词'),
    has_attachments: bool | None=Query(None, description='是否有附件'),
    doc_status: str | None=Query(None, description='文件状态 effective/obsolete/partially_obsolete/amended'),
    attachment_type: str | None=Query(None, description='附件类型 pdf/doc/docx/xls/xlsx/wps/ppt/pptx'),
    has_images: bool | None=Query(None, description='正文是否含图片'),
    source_id: int | None=None,
    doc_type: str | None=None,
    region_code: str | None=None,
    interpretation_form: str | None=None,
    skip: int=Query(0, ge=0),
    limit: int=Query(20, ge=0, le=200),
    db: Session=Depends(get_db),
):
    """List documents matching the given filters, newest ``issue_date`` first.

    Column filters run in SQL. ``attachment_type`` and ``has_images`` are
    evaluated in Python over the full SQL result, which is then paged in
    memory; without them, paging uses OFFSET/LIMIT. Negative ``skip`` or
    ``limit`` are rejected with 422 instead of reaching the database.

    Returns:
        ``{'total', 'skip', 'limit', 'items'}``. ``total`` counts every match
        (after the Python-side filters when they apply); ``items`` holds
        ``_doc_summary`` dicts for the requested page.
    """
    query = db.query(TaxDocument)
    if category_id is not None:
        query = query.filter(TaxDocument.category_id == category_id)
    if is_imported is not None:
        query = query.filter(TaxDocument.is_imported == is_imported)
    if doc_number:
        query = query.filter(TaxDocument.doc_number.ilike(f'%{doc_number}%'))
    if issue_date_from:
        query = query.filter(TaxDocument.issue_date >= issue_date_from)
    if issue_date_to:
        query = query.filter(TaxDocument.issue_date <= issue_date_to)
    if issuing_authority:
        query = query.filter(TaxDocument.issuing_authority.ilike(f'%{issuing_authority}%'))
    if has_attachments is True:
        # Exclude empty arrays as well, so True/False partition the rows the
        # same way as `has_attachments` in the summaries (bool(attachments)).
        # NOTE(review): a JSON `null` stored instead of SQL NULL is treated as
        # "has attachments" by both this and the False branch — confirm.
        query = query.filter(TaxDocument.attachments.isnot(None), TaxDocument.attachments != [])
    elif has_attachments is False:
        query = query.filter(or_(TaxDocument.attachments.is_(None), TaxDocument.attachments == []))
    if doc_status:
        query = query.filter(TaxDocument.doc_status == doc_status)
    if source_id is not None:
        query = query.filter(TaxDocument.source_id == source_id)
    if doc_type is not None:
        query = query.filter(TaxDocument.doc_type == doc_type)
    if region_code is not None:
        query = query.filter(TaxDocument.region_code == region_code)
    if interpretation_form is not None:
        query = query.filter(TaxDocument.interpretation_form == interpretation_form)

    ordering = (TaxDocument.issue_date.desc().nullslast(), TaxDocument.id.desc())
    if attachment_type or has_images is not None:
        # These predicates are evaluated in Python, so every matching row is
        # loaded before paging; the total is the length of the filtered list.
        docs = query.order_by(*ordering).all()
        if attachment_type:
            docs = [d for d in docs if d.attachments and any(a.get('type') == attachment_type for a in d.attachments)]
        if has_images is not None:
            docs = [d for d in docs if _has_images(d) == has_images]
        total = len(docs)
        docs = docs[skip:skip + limit]
    else:
        # COUNT is only needed on this path; the other one counts in Python.
        total = query.count()
        docs = query.order_by(*ordering).offset(skip).limit(limit).all()
    return {'total': total, 'skip': skip, 'limit': limit, 'items': [_doc_summary(doc) for doc in docs]}

@router.get('/search')
async def search_documents(
    q: str=Query(..., min_length=1, description='全文关键词'),
    category_id: int | None=None,
    issue_date_from: date | None=Query(None),
    issue_date_to: date | None=Query(None),
    attachment_type: str | None=Query(None),
    has_images: bool | None=Query(None),
    skip: int=Query(0, ge=0),
    limit: int=Query(20, ge=0, le=200),
    db: Session=Depends(get_db),
):
    """Keyword search over title, document number and issuing authority.

    If ``q`` contains something shaped like a document number (e.g.
    ``财税〔2019〕13号`` or ``国务院令第123号``), the search switches to a
    prefix match on ``doc_number`` only. Otherwise it does a substring match on
    title / doc_number / issuing_authority.

    Paging follows ``list_documents``: ``attachment_type`` and ``has_images``
    filter in Python, then the page is sliced in memory.

    Returns:
        ``{'total', 'skip', 'limit', 'is_doc_number_search', 'items'}``.
    """
    import re

    is_doc_number = bool(re.search(r'[\u4e00-\u9fa5]{2,10}[〔\[\（(]\d{4}[〕\]\）)]\d+号|(?:国务院令|主席令)第\d+号', q))
    if is_doc_number:
        query = db.query(TaxDocument).filter(TaxDocument.doc_number.ilike(f'{q}%'))
    else:
        query = db.query(TaxDocument).filter(or_(TaxDocument.title.ilike(f'%{q}%'), TaxDocument.doc_number.ilike(f'%{q}%'), TaxDocument.issuing_authority.ilike(f'%{q}%')))
    if category_id is not None:
        query = query.filter(TaxDocument.category_id == category_id)
    if issue_date_from:
        query = query.filter(TaxDocument.issue_date >= issue_date_from)
    if issue_date_to:
        query = query.filter(TaxDocument.issue_date <= issue_date_to)

    ordering = (TaxDocument.issue_date.desc().nullslast(), TaxDocument.id.desc())
    if attachment_type or has_images is not None:
        # Python-side predicates: load every match, filter, then page in memory.
        docs = query.order_by(*ordering).all()
        if attachment_type:
            docs = [d for d in docs if d.attachments and any(a.get('type') == attachment_type for a in d.attachments)]
        if has_images is not None:
            docs = [d for d in docs if _has_images(d) == has_images]
        total = len(docs)
        docs = docs[skip:skip + limit]
    else:
        # COUNT is only needed on this path; the other one counts in Python.
        total = query.count()
        docs = query.order_by(*ordering).offset(skip).limit(limit).all()
    return {'total': total, 'skip': skip, 'limit': limit, 'is_doc_number_search': is_doc_number, 'items': [_doc_summary(doc) for doc in docs]}

@router.post('/strip-ocr-content')
async def strip_ocr_content(db: Session=Depends(get_db)):
    """Dispatch ``strip_ocr_content_task`` to Celery and return its task id.

    The work runs in the background task (see its implementation for what it
    removes). ``db`` is unused here.
    """
    async_result = strip_ocr_content_task.delay()
    payload = {'message': 'OCR 内容清理任务已启动'}
    payload['celery_task_id'] = async_result.id
    return payload

@router.get('/{doc_id}')
async def get_document_detail(doc_id: int, db: Session=Depends(get_db)):
    """Return the full record of one document, enriched with related data.

    Besides the stored columns, the response includes:

    * the title of the locally stored superseding document, if any;
    * external supersession fields, omitted (None) when
      ``is_same_legal_reference`` says they describe that same local
      document;
    * resolved, non-obsolete ``references``;
    * the supersession ``timeline``.

    For imported documents the knowledge base is queried best-effort. Keys it
    returns (``doc_status``, ``superseded_by_doc_id``, ``version_number``)
    override the local values. Keys it omits leave the local values untouched,
    and failures are logged rather than raised.

    Raises:
        HTTPException: 404 when the document does not exist.
    """
    doc = db.query(TaxDocument).filter(TaxDocument.id == doc_id).first()
    if not doc:
        raise HTTPException(status_code=404, detail='文档不存在')

    superseded_by_title: str | None = None
    suppress_external_superseded_by = False
    if doc.superseded_by_doc_id:
        superseder = db.query(TaxDocument.title, TaxDocument.doc_number).filter(TaxDocument.id == doc.superseded_by_doc_id).first()
        if superseder:
            superseded_by_title = superseder.title
            suppress_external_superseded_by = is_same_legal_reference(doc.superseded_by_doc_number, doc.superseded_by_title, superseder.doc_number, superseder.title)

    def iso(value):
        return value.isoformat() if value else None

    result = {
        'id': doc.id,
        'source_url': doc.source_url,
        'category_id': doc.category_id,
        'doc_number': doc.doc_number,
        'title': doc.title,
        'issuing_authority': doc.issuing_authority,
        'issue_date': iso(doc.issue_date),
        'effective_date': iso(doc.effective_date),
        'content_markdown': doc.content_markdown,
        'content_html': doc.content_html,
        'file_path': doc.file_path,
        'attachments': doc.attachments or [],
        'has_attachments': bool(doc.attachments),
        'has_images': _has_images(doc),
        'processing_status': doc.processing_status,
        'content_hash': doc.content_hash,
        'is_imported': doc.is_imported,
        'knowledge_doc_id': doc.knowledge_doc_id,
        'doc_type': doc.doc_type,
        'qa_question': doc.qa_question,
        'qa_answer': doc.qa_answer,
        'source_id': doc.source_id,
        'region_code': doc.region_code,
        'interpretation_form': doc.interpretation_form,
        'created_at': iso(doc.created_at),
        'updated_at': iso(doc.updated_at),
        'doc_status': doc.doc_status,
        'superseded_by_doc_id': doc.superseded_by_doc_id,
        'superseded_by_title': superseded_by_title,
        'superseded_by_doc_number': None if suppress_external_superseded_by else doc.superseded_by_doc_number,
        'superseded_by_ext_title': None if suppress_external_superseded_by else doc.superseded_by_title,
        'superseded_by_source_url': None if suppress_external_superseded_by else doc.superseded_by_source_url,
        'version_number': None,
        'inline_images': doc.inline_images,
        'inline_videos': doc.inline_videos,
        'supersedes': doc.supersedes or [],
        'references': _resolve_references(db, doc.references or []),
        'timeline': _build_timeline(db, doc_id),
    }

    if doc.is_imported and doc.knowledge_doc_id:
        try:
            client = KnowledgeBaseClient()
            kb_resp = await client.get_document_status(doc.knowledge_doc_id)
            if kb_resp and kb_resp.get('success') and kb_resp.get('data'):
                kb_data = kb_resp['data']
                # Only override with keys the KB actually returned, so a partial
                # payload does not wipe the locally known status with None.
                # NOTE(review): confirm the KB's superseded_by_doc_id uses the
                # same id space as the local TaxDocument ids.
                for key in ('doc_status', 'superseded_by_doc_id', 'version_number'):
                    if key in kb_data:
                        result[key] = kb_data[key]
        except Exception:
            # Best-effort enrichment: keep serving the local data, but leave a trace.
            logger.warning('Knowledge-base status lookup failed for doc %s (kb id %s)', doc_id, doc.knowledge_doc_id, exc_info=True)
    return result

@router.get('/{doc_id}/attachments/{att_index}/download')
async def download_attachment(doc_id: int, att_index: int, db: Session=Depends(get_db)):
    """Send attachment ``att_index`` of a document as a binary download.

    The attachment's ``path`` is resolved against ``settings.project_root``.

    Raises:
        HTTPException: 404 when the document, the attachment slot, or the
            local file is missing (a failed stat counts as missing).
    """
    doc = db.query(TaxDocument).filter(TaxDocument.id == doc_id).first()
    if not doc:
        raise HTTPException(status_code=404, detail='文档不存在')

    attachments = doc.attachments or []
    if not 0 <= att_index < len(attachments):
        raise HTTPException(status_code=404, detail='附件不存在')

    relative = attachments[att_index].get('path', '')
    if not relative:
        raise HTTPException(status_code=404, detail='本地文件不存在，请重新下载')

    target = Path(settings.project_root) / relative
    try:
        exists = target.is_file()
    except OSError:
        # e.g. an invalid or over-long stored path: report it as missing.
        exists = False
    if not exists:
        raise HTTPException(status_code=404, detail='本地文件不存在，请重新下载')

    return FileResponse(path=str(target), filename=target.name, media_type='application/octet-stream')

@router.get('/{doc_id}/attachments')
async def get_document_attachments(doc_id: int, db: Session=Depends(get_db)):
    """Return the stored attachment metadata list of a document (empty if none).

    Raises:
        HTTPException: 404 when the document does not exist.
    """
    doc = db.query(TaxDocument).filter(TaxDocument.id == doc_id).first()
    if not doc:
        raise HTTPException(status_code=404, detail='文档不存在')
    attachments = doc.attachments or []
    return {'doc_id': doc_id, 'attachments': attachments}

@router.get('/{doc_id}/videos/{vid_index}/stream')
async def stream_inline_video(doc_id: int, vid_index: int, db: Session=Depends(get_db)):
    """Serve inline video ``vid_index`` of a document.

    Relative paths are resolved against ``settings.project_root``; absolute
    paths are used as-is. The media type is chosen by extension (.mp4 /
    .webm), falling back to ``application/octet-stream``.

    Raises:
        HTTPException: 404 when the document, the video slot (including a
            negative index), the stored path, or the file is missing.
    """
    doc = db.query(TaxDocument).filter(TaxDocument.id == doc_id).first()
    if not doc:
        raise HTTPException(status_code=404, detail='文档不存在')
    videos = doc.inline_videos or []
    # Reject negative indexes too: Python indexing would otherwise serve a
    # video counted from the end, or raise IndexError (-> 500).
    if vid_index < 0 or vid_index >= len(videos):
        raise HTTPException(status_code=404, detail='视频不存在')
    local_path = videos[vid_index].get('path', '')
    if not local_path:
        raise HTTPException(status_code=404, detail='视频文件路径为空')
    candidate = Path(local_path)
    abs_path = candidate if candidate.is_absolute() else Path(settings.project_root) / candidate
    try:
        is_file = abs_path.is_file()
    except OSError:
        # e.g. an invalid or over-long stored path: report it as missing, not 500.
        is_file = False
    if not is_file:
        raise HTTPException(status_code=404, detail='视频文件不存在')
    media_type = {'.mp4': 'video/mp4', '.webm': 'video/webm'}.get(abs_path.suffix.lower(), 'application/octet-stream')
    return FileResponse(path=str(abs_path), media_type=media_type)

@router.post('/{doc_id}/reprocess')
async def reprocess_document(doc_id: int, db: Session=Depends(get_db)):
    """Queue ``reprocess_from_local_task`` for a document and describe the job.

    ``from_local`` reports whether the document already has stored
    ``content_html``.

    Raises:
        HTTPException: 404 when the document does not exist.
    """
    doc = db.query(TaxDocument).filter(TaxDocument.id == doc_id).first()
    if not doc:
        raise HTTPException(status_code=404, detail='文档不存在')
    async_result = reprocess_from_local_task.delay(doc_id)
    return {
        'message': '文档重处理任务已启动',
        'doc_id': doc_id,
        'source_url': doc.source_url,
        'category_id': doc.category_id,
        'from_local': bool(doc.content_html),
        'celery_task_id': async_result.id,
    }

@router.post('/{doc_id}/import')
async def import_document(doc_id: int, db: Session=Depends(get_db)):
    """Queue a knowledge-base import for one document unless it is already imported.

    Already-imported documents return their ``knowledge_doc_id`` and no task
    is dispatched.

    Raises:
        HTTPException: 404 when the document does not exist.
    """
    doc = db.query(TaxDocument).filter(TaxDocument.id == doc_id).first()
    if not doc:
        raise HTTPException(status_code=404, detail='文档不存在')
    if not doc.is_imported:
        async_result = import_to_knowledge_base_task.delay(doc_id)
        return {'message': '文档导入任务已启动', 'doc_id': doc_id, 'celery_task_id': async_result.id}
    return {'message': '文档已导入', 'doc_id': doc_id, 'knowledge_doc_id': doc.knowledge_doc_id}

@router.post('/batch-import')
async def batch_import_documents(request: BatchImportRequest=Body(...), db: Session=Depends(get_db)):
    """Queue a knowledge-base import for every requested, not-yet-imported document.

    Tasks are dispatched once per distinct id, in request order. Only
    documents whose ``is_imported`` is exactly False are queued; NULL counts
    as already imported, matching an SQL ``IS false`` test.

    Response counters:
        requested: raw length of ``doc_ids`` (duplicates included).
        queued: number of import tasks dispatched.
        skipped_already_imported: existing documents that were not queued.
        skipped_not_found: distinct ids with no matching document.
        skipped_duplicates: repeated ids in the request.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order.
    unique_ids = list(dict.fromkeys(request.doc_ids))
    import_flags: dict = {}
    if unique_ids:
        rows = db.query(TaxDocument.id, TaxDocument.is_imported).filter(TaxDocument.id.in_(unique_ids)).all()
        import_flags = {row.id: row.is_imported for row in rows}

    valid_doc_ids = [i for i in unique_ids if i in import_flags and import_flags[i] is False]
    not_found = sum(1 for i in unique_ids if i not in import_flags)

    celery_task_ids = [import_to_knowledge_base_task.delay(doc_id).id for doc_id in valid_doc_ids]
    return {
        'message': f'批量导入任务已启动，共 {len(valid_doc_ids)} 个文档',
        'requested': len(request.doc_ids),
        'queued': len(valid_doc_ids),
        'skipped_already_imported': len(import_flags) - len(valid_doc_ids),
        'skipped_not_found': not_found,
        'skipped_duplicates': len(request.doc_ids) - len(unique_ids),
        'celery_task_ids': celery_task_ids,
    }

class UpdateContentRequest(BaseModel):
    """Request body for ``PUT /{doc_id}/content``: the edited Markdown text."""

    content_markdown: str  # full replacement for the document's Markdown body

@router.put('/{doc_id}/content')
async def update_document_content(doc_id: int, body: UpdateContentRequest, db: Session=Depends(get_db)):
    """Replace a document's Markdown body and record the new text as a version.

    The version number is one past the highest existing version for the
    document (1 when none exist). If the document was imported, its import
    flag and ``knowledge_doc_id`` are cleared. This endpoint does not itself
    contact the knowledge base.

    Raises:
        HTTPException: 404 when the document does not exist.
    """
    doc = db.query(TaxDocument).filter(TaxDocument.id == doc_id).first()
    if not doc:
        raise HTTPException(status_code=404, detail='文档不存在')

    # NOTE(review): max()+1 is not atomic; two concurrent saves for the same
    # document could pick the same version_number unless a DB constraint
    # prevents it — confirm.
    latest = db.query(func.max(TaxDocumentVersion.version_number)).filter(TaxDocumentVersion.doc_id == doc_id).scalar()
    new_version = (latest or 0) + 1
    db.add(TaxDocumentVersion(doc_id=doc_id, version_number=new_version, content_markdown=body.content_markdown))

    doc.content_markdown = body.content_markdown
    if doc.is_imported:
        doc.is_imported = False
        doc.knowledge_doc_id = None
    db.commit()
    return {'message': '保存成功', 'doc_id': doc_id, 'version_number': new_version}

@router.get('/{doc_id}/versions')
async def list_document_versions(doc_id: int, db: Session=Depends(get_db)):
    """List a document's saved versions, highest version number first (no content).

    Raises:
        HTTPException: 404 when the document does not exist.
    """
    doc = db.query(TaxDocument).filter(TaxDocument.id == doc_id).first()
    if not doc:
        raise HTTPException(status_code=404, detail='文档不存在')
    rows = (
        db.query(TaxDocumentVersion)
        .filter(TaxDocumentVersion.doc_id == doc_id)
        .order_by(TaxDocumentVersion.version_number.desc())
        .all()
    )
    entries = []
    for row in rows:
        created = row.created_at.isoformat() if row.created_at else None
        entries.append({'id': row.id, 'version_number': row.version_number, 'created_at': created})
    return {'doc_id': doc_id, 'versions': entries}

@router.get('/{doc_id}/versions/{version_number}')
async def get_document_version(doc_id: int, version_number: int, db: Session=Depends(get_db)):
    """Return the stored Markdown of one specific version of a document.

    Raises:
        HTTPException: 404 when no such (doc_id, version_number) exists.
    """
    ver = (
        db.query(TaxDocumentVersion)
        .filter(TaxDocumentVersion.doc_id == doc_id, TaxDocumentVersion.version_number == version_number)
        .first()
    )
    if not ver:
        raise HTTPException(status_code=404, detail='版本不存在')
    created = ver.created_at.isoformat() if ver.created_at else None
    return {
        'doc_id': doc_id,
        'version_number': ver.version_number,
        'content_markdown': ver.content_markdown,
        'created_at': created,
    }

@router.get('/{doc_id}/images/{img_index}')
async def get_document_image(doc_id: int, img_index: int, db: Session=Depends(get_db)):
    """Serve inline image ``img_index`` of a document.

    The stored ``path`` is resolved against ``settings.project_root``. The
    media type is picked by file extension, defaulting to ``image/jpeg``.

    Raises:
        HTTPException: 404 when the document, the image slot, the stored path,
            or the file is missing (a failed stat counts as missing).
    """
    doc = db.query(TaxDocument).filter(TaxDocument.id == doc_id).first()
    if not doc:
        raise HTTPException(status_code=404, detail='文档不存在')
    images = doc.inline_images or []
    if img_index < 0 or img_index >= len(images):
        raise HTTPException(status_code=404, detail='图片不存在')
    local_path = images[img_index].get('path', '')
    if not local_path:
        raise HTTPException(status_code=404, detail='图片路径为空')
    abs_path = Path(settings.project_root) / local_path
    try:
        is_file = abs_path.is_file()
    except OSError:
        # e.g. an invalid or over-long stored path: report it as missing, not 500.
        is_file = False
    if not is_file:
        raise HTTPException(status_code=404, detail='图片文件不存在')
    media_types = {
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'png': 'image/png',
        'gif': 'image/gif',
        'webp': 'image/webp',
        'bmp': 'image/bmp',
    }
    media_type = media_types.get(abs_path.suffix.lower().lstrip('.'), 'image/jpeg')
    return FileResponse(path=str(abs_path), media_type=media_type)
