import httpx

from app.celery_app import celery_app
from app.config import get_settings
from app.tasks.processor_tasks import DatabaseTask
from common_logging import get_logger
# Module-level logger for the OCR task module.
logger = get_logger(__name__)

# Application settings, loaded once when this module is imported.
settings = get_settings()
# Base URL of the vision-language model server that _ocr_image posts to
# ({VL_BASE_URL}/chat/completions). Hard-coded to localhost.
# NOTE(review): presumably an OpenAI-compatible server; consider moving this
# into settings so it can differ per deployment.
VL_BASE_URL = 'http://localhost:8400/v1'
# Platform endpoint that _switch_mode POSTs {'mode': ...} to in order to change
# the LLM service mode.
SWITCH_MODE_URL = f'{settings.base_platform_url}/internal/switch_mode'

def _switch_mode(mode: str, *, poll_attempts: int = 120, poll_interval: float = 5) -> None:
    """Ask the platform to switch the LLM service into *mode*, then wait for it.

    Posts ``{'mode': mode}`` to ``SWITCH_MODE_URL`` with the internal API token
    in the ``X-Internal-Token`` header. It then polls the Redis key
    ``llm:service:status`` on localhost until the key equals *mode*.

    Waiting is best-effort. A timeout or any Redis error is logged as a warning
    and the function returns normally. Only a failed switch request raises.

    Args:
        mode: Target mode name. This module uses ``'media_processing'`` and
            ``'inference'``.
        poll_attempts: Maximum number of status reads (default 120).
        poll_interval: Seconds to sleep between reads. The default of 5 gives
            a waiting budget of about 10 minutes.

    Raises:
        httpx.HTTPStatusError: The switch request returned a 4xx/5xx status.
        httpx.HTTPError: The switch request failed at the transport level or
            exceeded its 30 s timeout.
    """
    import time

    token = getattr(settings, 'internal_api_token', None) or ''
    httpx.post(SWITCH_MODE_URL, json={'mode': mode}, headers={'X-Internal-Token': token}, timeout=30).raise_for_status()
    logger.info(f'[ocr_tasks] 模式切换已触发: {mode}，等待就绪...')
    try:
        import redis as _redis
        r = _redis.Redis(host='localhost', port=6379, db=0, decode_responses=True)
        try:
            for _ in range(poll_attempts):
                current = r.get('llm:service:status')
                if current == mode:
                    logger.info(f'[ocr_tasks] 已切换到模式: {mode}')
                    return
                time.sleep(poll_interval)
            logger.warning(f"[ocr_tasks] 等待模式切换超时，当前状态: {r.get('llm:service:status')}")
        finally:
            # Close the client so each call does not leave a connection pool behind.
            r.close()
    except Exception as e:
        # Deliberately best-effort: an unreachable Redis must not block the batch.
        logger.warning(f'[ocr_tasks] 无法轮询 Redis，直接继续: {e}')

def _ocr_image(local_path: str) -> str:
    """OCR a single local image with the vision-language model.

    Sends the image as a ``file://`` URL, together with a fixed Chinese
    instruction prompt, to ``{VL_BASE_URL}/chat/completions``. The prompt asks
    for all text, with tables as Markdown, preserving structure and adding no
    explanations. Returns the model's reply with surrounding whitespace
    stripped.

    NOTE(review): a ``file://`` image URL only works if the model server can
    read *local_path* on its own filesystem. Confirm the server is co-located
    and configured to allow local media.

    Args:
        local_path: Filesystem path of the image.

    Returns:
        The extracted text. It may be empty if the model replied with only
        whitespace.

    Raises:
        httpx.HTTPStatusError: The model server returned a 4xx/5xx status.
        httpx.HTTPError: Transport failure or the 120 s timeout.
        Other exceptions (e.g. KeyError, IndexError, AttributeError, ValueError):
            The response body did not have a string at
            ``choices[0].message.content``.
    """
    prompt = '提取图片中所有文字内容，表格用Markdown格式输出，保持原始结构，不要添加任何解释。'
    resp = httpx.post(f'{VL_BASE_URL}/chat/completions', json={'model': 'Qwen3-VL-32B-Instruct', 'messages': [{'role': 'user', 'content': [{'type': 'image_url', 'image_url': {'url': f'file://{local_path}'}}, {'type': 'text', 'text': prompt}]}], 'max_tokens': 2048}, timeout=120)
    # Surface server errors as an HTTP error instead of a confusing KeyError on 'choices'.
    resp.raise_for_status()
    return resp.json()['choices'][0]['message']['content'].strip()

def _rebuild_content_text(doc, cleaner) -> None:
    from bs4 import BeautifulSoup
    if not doc.content_html:
        return
    soup = BeautifulSoup(doc.content_html, 'lxml')
    for img_info in doc.inline_images or []:
        if not img_info.get('ocr_text'):
            continue
        tag = soup.find('img', src=lambda s, src=img_info['src_original']: s and src in s)
        if tag:
            tag.insert_after(BeautifulSoup(f"""<p class="ocr-text">{img_info['ocr_text']}</p>""", 'html.parser'))
    doc.content_text = cleaner.markdown_to_rag_text(cleaner.html_to_markdown(str(soup))).replace('\x00', '')

def _ocr_pending_images(doc) -> bool:
    """OCR every image of *doc* that has a ``path`` but no ``ocr_text``; return True if any succeeded."""
    updated = False
    for img in doc.inline_images or []:
        if img.get('ocr_text') or not img.get('path'):
            continue
        try:
            img['ocr_text'] = _ocr_image(img['path'])
            updated = True
        except Exception as e:
            # One bad image must not stop the rest of the document.
            logger.warning(f"[ocr_tasks] doc_id={doc.id} OCR 失败: {img['path']} — {e}")
    return updated


def _sync_downstream(doc) -> None:
    """Push a committed *doc* to the multimodal index and, if imported, to the knowledge base."""
    import asyncio as _asyncio

    from app.services.multimodal_indexer import index_document_media
    index_document_media(doc.id, doc.inline_images, doc.inline_videos)
    if doc.is_imported and doc.knowledge_doc_id:
        from app.services.tax_data_processor.knowledge_base_client import (
            KnowledgeBaseClient,
        )

        _client = KnowledgeBaseClient(base_url=settings.base_platform_url, api_key=getattr(settings, 'base_platform_api_key', None))
        _asyncio.run(_client.update_document_content(doc.knowledge_doc_id, doc.content_text))


@celery_app.task(bind=True, base=DatabaseTask, name='app.tasks.ocr_tasks.run_media_batch')
def run_media_batch(self):
    """Celery task: OCR all pending inline images and refresh the affected documents.

    A document is pending when at least one ``inline_images`` entry has a
    ``path`` but no ``ocr_text``. If nothing is pending, the task returns
    immediately without touching the LLM service mode.

    Otherwise the service is switched to ``'media_processing'``. Each pending
    document then goes through these steps:

    1. Its images are OCR'd. Per-image failures are logged and skipped.
    2. ``content_text`` is rebuilt and the document is committed.
    3. The document is pushed to the multimodal index, and also to the
       knowledge base if it is imported with a ``knowledge_doc_id``.

    A commit failure rolls back and skips the document. A downstream-sync
    failure is logged separately; the document is already committed and
    still counts as processed.

    The service is switched back to ``'inference'`` on every exit path. If the
    batch itself raised, a failure to switch back is logged so it does not
    replace the original exception.

    Returns:
        ``{'success': True, 'processed': 0}`` when nothing was pending.
        Otherwise ``{'success': True, 'processed': <committed docs>,
        'total': <pending docs>}``.
    """
    from sqlalchemy.orm.attributes import flag_modified

    from app.models.tax_data import TaxDocument
    from app.services.tax_data_processor.document_cleaner import DocumentCleaner
    db = self.db
    docs = db.query(TaxDocument).filter(TaxDocument.inline_images.isnot(None)).all()
    pending = [d for d in docs if any(not img.get('ocr_text') and img.get('path') for img in d.inline_images or [])]
    if not pending:
        logger.info('[ocr_tasks] 无待处理文档，跳过')
        return {'success': True, 'processed': 0}
    logger.info(f'[ocr_tasks] 待处理文档: {len(pending)} 条')
    cleaner = DocumentCleaner()
    processed = 0
    try:
        # Inside the try so a partially applied switch is still reverted.
        _switch_mode('media_processing')
        for doc in pending:
            if not _ocr_pending_images(doc):
                continue
            _rebuild_content_text(doc, cleaner)
            # inline_images was mutated in place; tell SQLAlchemy the JSON changed.
            flag_modified(doc, 'inline_images')
            try:
                db.commit()
            except Exception as e:
                db.rollback()
                logger.error(f'[ocr_tasks] doc_id={doc.id} 写库失败: {e}')
                continue
            processed += 1
            try:
                _sync_downstream(doc)
            except Exception as e:
                # The DB write already succeeded; only the downstream push failed.
                # The rollback just resets the session if the failure left it dirty.
                db.rollback()
                logger.error(f'[ocr_tasks] doc_id={doc.id} 下游同步失败: {e}')
    except BaseException:
        try:
            _switch_mode('inference')
        except Exception as restore_err:
            logger.error(f'[ocr_tasks] 切回 inference 模式失败: {restore_err}')
        raise
    # Success path: let a failed switch-back surface as a task failure.
    _switch_mode('inference')
    logger.info(f'[ocr_tasks] 完成: processed={processed}/{len(pending)}')
    return {'success': True, 'processed': processed, 'total': len(pending)}
