import argparse
import os
import sys
from pathlib import Path

from common_logging import get_logger

logger = get_logger(__name__)

# Put an ancestor directory of this file at the front of sys.path so the
# `app.services.import_kb...` imports below resolve — presumably so the script
# can be run directly rather than as an installed package.
# NOTE(review): parents[4] is assumed to be the directory that contains the
# `app` package; confirm against the repository layout if files are moved.
sys.path.insert(0, str(Path(__file__).resolve().parents[4]))
from app.services.import_kb.phase1_init import login, run_phase1
from app.services.import_kb.phase2_documents import run_phase2
from app.services.import_kb.phase2b_relations import run_phase2b
from app.services.import_kb.phase3_vectorize import run_phase3
from app.services.import_kb.phase4_autotag import run_phase4
from app.services.import_kb.phase5_graph import run_phase5
from app.services.import_kb.state import ImportState

# Every known phase, in the order `--phase all` executes them.
ALL_PHASES = [
    'taxonomy',
    '1',
    '2',
    '2b',
    '3',
    '4',
    '5',
    'media',
]

def parse_args(argv: 'list[str] | None' = None) -> argparse.Namespace:
    """Parse the importer's command-line options.

    Connection and credential options fall back to the environment variables
    BASE_PLATFORM_URL, IMPORT_EMAIL, IMPORT_PASSWORD and IMPORT_TENANT_ID.

    Args:
        argv: Arguments to parse. ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='税法知识库导入工具')
    parser.add_argument('--phase', default='all', help='执行阶段：all | taxonomy | 1 | 2 | 2b | 3 | 4 | 5 | media | 逗号分隔组合（如 1,2,2b）')
    parser.add_argument('--base-url', default=os.getenv('BASE_PLATFORM_URL', 'http://localhost:8000'))
    parser.add_argument('--email', default=os.getenv('IMPORT_EMAIL', 'admin@hellotax.cn'))
    # SECURITY NOTE(review): an admin password is hard-coded as the fallback
    # default. Prefer requiring IMPORT_PASSWORD / --password, remove this
    # literal, and rotate the credential if it is live.
    parser.add_argument('--password', default=os.getenv('IMPORT_PASSWORD', 'Hellotax@2026#Admin'))
    # Keep the default as a string: argparse runs string defaults through
    # ``type``, so a malformed IMPORT_TENANT_ID yields a normal usage error
    # instead of a ValueError raised while the parser is being built.
    parser.add_argument('--tenant-id', type=int, default=os.getenv('IMPORT_TENANT_ID', '0'))
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--reset-state', action='store_true')
    return parser.parse_args(argv)

def resolve_phases(phase_arg: str) -> list[str]:
    """Turn the ``--phase`` value into an ordered list of phase names.

    ``'all'`` (surrounding whitespace ignored) expands to every phase in
    ``ALL_PHASES``. Any other value is split on commas; each item is trimmed
    and empty items are dropped. Names are not validated here.

    Always returns a new list, so callers may modify it without mutating the
    module-level ``ALL_PHASES``.
    """
    if phase_arg.strip() == 'all':
        # Copy so the shared constant cannot be mutated through the result.
        return list(ALL_PHASES)
    return [p.strip() for p in phase_arg.split(',') if p.strip()]

def _seed_taxonomy(tenant_id: int) -> None:
    """Run ``seed.taxonomy.seed_taxonomy(tenant_id)`` in the base_platform venv's interpreter.

    The seeding code is executed in a child process using
    ``base_platform/venv/bin/python`` rather than the current interpreter.
    Raises ``subprocess.CalledProcessError`` if the child exits non-zero.
    """
    import subprocess

    # NOTE(review): parents[5] is assumed to be the directory that contains both
    # base_platform/ and industry_accelerator/ — confirm against the layout.
    repo_root = Path(__file__).resolve().parents[5]
    bp_python = str(repo_root / 'base_platform' / 'venv' / 'bin' / 'python')
    tax_scripts = str(repo_root / 'industry_accelerator' / 'tax' / 'scripts')
    # Hand the path and tenant id to the child through argv (sys.argv[1], [2])
    # instead of interpolating them into the source string, so quotes or
    # backslashes in the path cannot break or inject into the snippet.
    snippet = (
        'import sys; '
        'sys.path.insert(0, sys.argv[1]); '
        'from seed.taxonomy import seed_taxonomy; '
        'seed_taxonomy(int(sys.argv[2]))'
    )
    subprocess.run([bp_python, '-c', snippet, tax_scripts, str(tenant_id)], check=True)


def _run_media_batches() -> None:
    """Trigger the OCR batch job, then the ASR batch job, from ``app.tasks``."""
    # Imported lazily, presumably so the task modules (and their dependencies)
    # are only loaded when the media phase is actually requested.
    from app.tasks.asr_tasks import run_asr_batch
    from app.tasks.ocr_tasks import run_media_batch
    logger.info('触发 OCR 批处理...')
    run_media_batch()
    logger.info('触发 ASR 批处理...')
    run_asr_batch()


def main():
    """Parse CLI arguments and run the requested import phases in order.

    Optionally clears the checkpoint state first (``--reset-state``). A login
    token is fetched once, and only if at least one token-using phase was
    requested. Unknown phase names are logged and skipped. An exception from
    any phase propagates and aborts the remaining phases.
    """
    args = parse_args()
    state = ImportState()
    if args.reset_state:
        state.reset()
        logger.info('checkpoint 已清除')
    # NOTE(review): ``args.resume`` is parsed but never read in this script;
    # confirm whether ImportState resumes implicitly, or wire the flag through.
    phases = resolve_phases(args.phase)
    logger.info(f'执行阶段: {phases}')
    logger.info(f'base_url: {args.base_url}, tenant_id: {args.tenant_id}')
    # Phases whose runners take a login token; 'taxonomy' and 'media' do not.
    PHASES_NEEDING_TOKEN = {'1', '2', '2b', '3', '4', '5'}
    token = None
    if any(p in PHASES_NEEDING_TOKEN for p in phases):
        token = login(args.base_url, args.email, args.password)
    for phase in phases:
        logger.info(f'━━━ Phase {phase} 开始 ━━━')
        if phase == 'taxonomy':
            _seed_taxonomy(args.tenant_id)
            logger.info('标签体系导入完成')
        elif phase == '1':
            run_phase1(state, args.base_url, token, args.tenant_id)
        elif phase == '2':
            run_phase2(state, args.base_url, token)
        elif phase == '2b':
            run_phase2b(state, args.base_url, token)
        elif phase == '3':
            run_phase3(state, args.base_url, token)
        elif phase == '4':
            run_phase4(state, args.base_url, token)
        elif phase == '5':
            run_phase5(state, args.base_url, token)
        elif phase == 'media':
            _run_media_batches()
        else:
            logger.warning(f'未知阶段: {phase}，跳过')
    logger.info('✅ 所有阶段执行完毕')


if __name__ == "__main__":
    # Run the importer only when executed as a script, not when imported.
    main()
