import asyncio


from app.services.tax_data_processor.knowledge_base_client import KnowledgeBaseClient

from .state import ImportState
from common_logging import get_logger
# Module logger from common_logging.get_logger, keyed by this module's __name__.
logger = get_logger(__name__)


def _fetch_imported_doc_ids(source_id: int) -> list:
    """Return the ``knowledge_doc_id`` of every imported document of *source_id*.

    Selects rows from ``tax_documents`` with the given ``source_id`` whose
    ``is_imported`` flag is true and whose ``knowledge_doc_id`` is not NULL.
    """
    # Imported lazily (as originally): the DB layer is only needed when the
    # build is scoped to a single source.
    from sqlalchemy import text

    from app.database import SessionLocal

    query = text('\n                SELECT knowledge_doc_id FROM tax_documents\n                WHERE source_id = :sid AND is_imported = true\n                  AND knowledge_doc_id IS NOT NULL\n            ')
    with SessionLocal() as db:
        rows = db.execute(query, {'sid': source_id}).fetchall()
    return [row[0] for row in rows]


def run_phase5(state: ImportState, base_url: str, token: str, source_id: int = None) -> None:
    """Phase 5: trigger a knowledge-graph build for the import's knowledge base.

    Returns immediately if ``state`` already records phase ``'5'`` as done;
    on a successful trigger, records phase ``'5'`` as done.

    Args:
        state: Import progress state; supplies ``kb_id`` and records phase
            completion.
        base_url: Base URL handed to ``KnowledgeBaseClient``.
        token: Handed to ``KnowledgeBaseClient`` as ``api_key``.
        source_id: When given, the build is restricted to the
            ``knowledge_doc_id`` values of documents imported from this
            source (looked up in ``tax_documents``). When ``None``,
            ``document_ids=None`` is sent to the client.

    Raises:
        RuntimeError: If no ``kb_id`` is recorded (Phase 1 not run), or if
            the build trigger does not report ``success``.
    """
    if state.is_phase_done('5'):
        logger.info('Phase 5 已完成，跳过')
        return

    kb_id = state.get_kb_id()
    if not kb_id:
        raise RuntimeError('kb_id 未找到，请先执行 Phase 1')

    client = KnowledgeBaseClient(base_url=base_url, api_key=token)

    # ``is not None`` so an explicit source_id of 0 still scopes the build
    # instead of silently falling through to the unscoped (None) request.
    doc_ids = None
    if source_id is not None:
        doc_ids = _fetch_imported_doc_ids(source_id)
        if not doc_ids:
            # NOTE(review): an empty list is still sent, as before. Whether the
            # client treats [] as "no documents" or as "all documents" is not
            # visible here — confirm against KnowledgeBaseClient.
            logger.warning(f'source_id={source_id} 没有已导入的知识文档，document_ids 为空')

    result = asyncio.run(client.trigger_graph_build(kb_id=kb_id, document_ids=doc_ids))

    if result and result.get('success'):
        logger.info(f'✓ Phase 5 完成：知识图谱构建任务已触发 (kb_id={kb_id})')
        state.mark_phase_done('5')
        return

    # Fall back to 'unknown' also when the response carries no (or an empty)
    # 'error' field, instead of reporting "Phase 5 失败: None".
    error = (result.get('error') if result else None) or 'unknown'
    logger.error(f'知识图谱构建触发失败: {error}')
    raise RuntimeError(f'Phase 5 失败: {error}')
