import time
from decimal import Decimal
from typing import Any

from document.tax_document_splitter import TaxDocumentSplitter
from tax_calculation_engine import TaxpayerType, get_tax_calculation_engine
from workflows.declaration_workflow import DeclarationWorkflow

from common_logging import get_logger

logger = get_logger(__name__)


def benchmark_tax_calculation() -> dict[str, Any]:
    engine = get_tax_calculation_engine()
    results = {}
    start = time.time()
    for _ in range(100):
        engine.calculate_vat(sales_amount=Decimal('100000'), taxpayer_type=TaxpayerType.GENERAL, rate_type='general_13', input_tax=Decimal('5000'))
    vat_time = (time.time() - start) / 100 * 1000
    results['vat_calculation_ms'] = round(vat_time, 2)
    start = time.time()
    for _ in range(100):
        engine.calculate_corporate_income_tax(taxable_income=Decimal('500000'), rate_type='standard')
    corp_time = (time.time() - start) / 100 * 1000
    results['corporate_income_tax_ms'] = round(corp_time, 2)
    start = time.time()
    for _ in range(100):
        engine.calculate_personal_income_tax(annual_income=Decimal('200000'), special_deductions=Decimal('20000'))
    personal_time = (time.time() - start) / 100 * 1000
    results['personal_income_tax_ms'] = round(personal_time, 2)
    start = time.time()
    for _ in range(100):
        engine.calculate_vat(sales_amount=Decimal('100000'), taxpayer_type=TaxpayerType.GENERAL, rate_type='general_13', input_tax=Decimal('5000'))
    cached_time = (time.time() - start) / 100 * 1000
    results['vat_cached_ms'] = round(cached_time, 2)
    cache_stats = engine.get_cache_stats()
    results['cache_stats'] = cache_stats
    return results

def benchmark_document_splitting() -> dict[str, Any]:
    splitter = TaxDocumentSplitter(granularity='article')
    results = {}
    sample_text = '\n第一章 总则\n\n第一条 为了规范税收征收管理，保障国家税收收入，保护纳税人的合法权益，促进经济和社会发展，制定本法。\n\n第二条 凡依法由税务机关征收的各种税收的征收管理，均适用本法。\n\n第三条 税收的开征、停征以及减税、免税、退税、补税，依照法律的规定执行；法律授权国务院规定的，依照国务院制定的行政法规的规定执行。\n\n第四条 法律、行政法规规定负有纳税义务的单位和个人为纳税人。\n\n第五条 国务院税务主管部门主管全国税收征收管理工作。各地国家税务局和地方税务局应当按照国务院规定的税收征收管理范围分别进行征收管理。\n' * 10
    start = time.time()
    chunks = splitter.split_text(sample_text)
    split_time = (time.time() - start) * 1000
    results['document_splitting_ms'] = round(split_time, 2)
    results['chunks_count'] = len(chunks)
    results['text_length'] = len(sample_text)
    texts = [sample_text] * 5
    start = time.time()
    splitter.split_texts_batch(texts)
    batch_time = (time.time() - start) * 1000
    results['batch_splitting_ms'] = round(batch_time, 2)
    results['batch_size'] = len(texts)
    return results

def benchmark_workflow_state() -> dict[str, Any]:
    workflow = DeclarationWorkflow('wf_001', tenant_id=1, user_id=1)
    results = {}
    start = time.time()
    workflow.get_status()
    first_time = (time.time() - start) * 1000
    results['first_status_get_ms'] = round(first_time, 2)
    start = time.time()
    for _ in range(100):
        workflow.get_status()
    cached_time = (time.time() - start) / 100 * 1000
    results['cached_status_get_ms'] = round(cached_time, 2)
    start = time.time()
    workflow.execute_step('data_collection', {'data': 'test'})
    exec_time = (time.time() - start) * 1000
    results['step_execution_ms'] = round(exec_time, 2)
    return results

def run_all_benchmarks() -> dict[str, Any]:
    print('=' * 60)
    print('税务行业包性能基准测试')
    print('=' * 60)
    results = {}
    print('\n[1/3] 税务计算性能测试...')
    calc_results = benchmark_tax_calculation()
    results['tax_calculation'] = calc_results
    print(f"  增值税计算: {calc_results['vat_calculation_ms']}ms (目标: <100ms)")
    print(f"  企业所得税: {calc_results['corporate_income_tax_ms']}ms (目标: <100ms)")
    print(f"  个人所得税: {calc_results['personal_income_tax_ms']}ms (目标: <100ms)")
    print(f"  缓存命中: {calc_results['vat_cached_ms']}ms")
    print(f"  缓存大小: {calc_results['cache_stats']['cache_size']}/{calc_results['cache_stats']['cache_max_size']}")
    print('\n[2/3] 文档分割性能测试...')
    split_results = benchmark_document_splitting()
    results['document_splitting'] = split_results
    print(f"  10页文档分割: {split_results['document_splitting_ms']}ms (目标: <2000ms)")
    print(f"  生成块数: {split_results['chunks_count']}")
    print(f"  批处理(5文档): {split_results['batch_splitting_ms']}ms")
    print('\n[3/3] 工作流状态性能测试...')
    workflow_results = benchmark_workflow_state()
    results['workflow'] = workflow_results
    print(f"  首次状态获取: {workflow_results['first_status_get_ms']}ms")
    print(f"  缓存状态获取: {workflow_results['cached_status_get_ms']}ms (目标: <500ms)")
    print(f"  步骤执行: {workflow_results['step_execution_ms']}ms (目标: <500ms)")
    print('\n' + '=' * 60)
    print('性能测试总结')
    print('=' * 60)
    passed = 0
    failed = 0
    checks = [('税务计算 < 100ms', calc_results['vat_calculation_ms'] < 100), ('文档分割 < 2000ms', split_results['document_splitting_ms'] < 2000), ('工作流状态 < 500ms', workflow_results['cached_status_get_ms'] < 500)]
    for check_name, passed_check in checks:
        status = '✓ PASS' if passed_check else '✗ FAIL'
        print(f'  {status}: {check_name}')
        if passed_check:
            passed += 1
        else:
            failed += 1
    print(f'\n总计: {passed} 通过, {failed} 失败')
    results['summary'] = {'passed': passed, 'failed': failed}
    logger.bind(test_name="all_benchmarks").info("benchmark completed", passed=passed, failed=failed)
    return results
if __name__ == '__main__':
    results = run_all_benchmarks()
