import httpx
from common_logging import get_logger
logger = get_logger(__name__)

# HTTP endpoint of the embedding service that _embed() posts to.
EMBEDDING_URL: str = 'http://localhost:8200/v1/embeddings'
# Milvus collection that stores the image / video-transcript chunks.
MILVUS_COLLECTION: str = 'tax_knowledge_multimodal'
# Dimension of the `embedding` FLOAT_VECTOR field in the collection schema;
# vectors returned by the embedding model must have exactly this length.
# NOTE(review): confirm against the deployed Qwen3-VL-Embedding-8B output size.
EMBEDDING_DIM: int = 3584

def _embed(text: str=None, image_path: str=None, timeout: float=60) -> list:
    """Request one embedding vector from the embedding service at EMBEDDING_URL.

    Input selection:
      * ``image_path`` and ``text``: one multimodal input, image part first.
      * ``image_path`` only: the image alone.
      * ``text`` only: the text sent as a plain string input.

    Args:
        text: Text to embed, or ``None``.
        image_path: Path of the image, sent as a ``file://`` URL. The URL is
            resolved by the embedding service, so the path must be readable
            there. Relative paths are made absolute against this process's
            working directory first.
        timeout: HTTP timeout in seconds (defaults to the previous fixed 60).

    Returns:
        The ``embedding`` of the first entry in the response's ``data`` list.

    Raises:
        ValueError: if neither ``text`` nor ``image_path`` is provided.
        httpx.HTTPStatusError: if the service answers with a non-2xx status.
        httpx.HTTPError, KeyError, IndexError: on transport failure or an
            unexpected response body.
    """
    import os

    if not text and not image_path:
        # Fail fast instead of posting `input: None` / `""` to the service.
        raise ValueError('_embed requires text or image_path')
    payload = {'model': 'Qwen3-VL-Embedding-8B'}
    if image_path:
        # `file://` + a relative path would turn its first segment into the
        # URL host (file://imgs/a.png -> host "imgs"), so use an absolute path.
        # Absolute paths are passed through unchanged.
        image_url = f'file://{os.path.abspath(image_path)}'
        parts = [{'type': 'image_url', 'image_url': {'url': image_url}}]
        if text:
            parts.append({'type': 'text', 'text': text})
        payload['input'] = parts
    else:
        payload['input'] = text
    resp = httpx.post(EMBEDDING_URL, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()['data'][0]['embedding']

def _get_milvus_collection():
    """Return the MILVUS_COLLECTION collection, creating it on first use.

    When the collection does not exist yet, it is created with the fixed schema
    below, and an HNSW index (COSINE metric) is built on ``embedding``. The
    collection is not loaded here.

    NOTE(review): no ``connections.connect`` call happens in this function, so
    it presumably relies on a pymilvus connection opened elsewhere; confirm.
    """
    # Imported lazily, presumably so this module can be imported without pymilvus.
    from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, utility

    if utility.has_collection(MILVUS_COLLECTION):
        return Collection(MILVUS_COLLECTION)

    schema_fields = [
        FieldSchema('id', DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema('doc_id', DataType.INT64),
        FieldSchema('chunk_type', DataType.VARCHAR, max_length=20),
        FieldSchema('chunk_index', DataType.INT64),
        FieldSchema('content', DataType.VARCHAR, max_length=4096),
        FieldSchema('media_path', DataType.VARCHAR, max_length=500),
        FieldSchema('embedding', DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
    ]
    hnsw_index = {
        'index_type': 'HNSW',
        'metric_type': 'COSINE',
        'params': {'M': 16, 'efConstruction': 200},
    }
    created = Collection(MILVUS_COLLECTION, CollectionSchema(schema_fields))
    created.create_index('embedding', hnsw_index)
    return created

# Limits mirroring the `content` / `media_path` VARCHAR max_length values in
# the _get_milvus_collection() schema; keep them in sync. Measuring in UTF-8
# bytes is safe whether Milvus counts max_length in bytes or in characters,
# because a string never has more characters than UTF-8 bytes.
_CONTENT_MAX_BYTES = 4096
_MEDIA_PATH_MAX_BYTES = 500


def _truncate_utf8(text: str, max_bytes: int) -> str:
    """Return the longest prefix of *text* whose UTF-8 encoding fits in *max_bytes*."""
    encoded = text.encode('utf-8')
    if len(encoded) <= max_bytes:
        return text
    # errors='ignore' drops a multi-byte character that the cut split in half.
    return encoded[:max_bytes].decode('utf-8', errors='ignore')


def _build_row(doc_id: int, chunk_type: str, chunk_index: int, content: str, media_path: str, embedding: list) -> dict:
    """Assemble one insert row, rejecting values the collection schema would refuse.

    The insert in index_document_media is a single batch, so one out-of-schema
    value can fail the whole document. Problems are therefore raised here, per
    item, where the caller logs and skips just that item.

    Raises:
        ValueError: if the embedding length differs from EMBEDDING_DIM, or if
            media_path does not fit in the media_path field.
    """
    if len(embedding) != EMBEDDING_DIM:
        raise ValueError(f'embedding has {len(embedding)} dims, expected {EMBEDDING_DIM}')
    if len(media_path.encode('utf-8')) > _MEDIA_PATH_MAX_BYTES:
        # A truncated path would be a dangling reference, so reject instead.
        raise ValueError(f'media_path longer than {_MEDIA_PATH_MAX_BYTES} bytes')
    return {
        'doc_id': doc_id,
        'chunk_type': chunk_type,
        'chunk_index': chunk_index,
        # The vector was computed from the full text. The stored copy is only
        # context, so truncating it is better than losing the whole row.
        'content': _truncate_utf8(content, _CONTENT_MAX_BYTES),
        'media_path': media_path,
        'embedding': embedding,
    }


def index_document_media(doc_id: int, inline_images: list, inline_videos: list=None) -> dict:
    """Embed a document's inline images and video transcripts and write them to Milvus.

    Each image with a non-empty ``path`` is embedded together with its OCR text,
    when present, and stored as chunk_type ``'image'``. Each video with a
    non-empty ``transcript`` is embedded from the transcript alone and stored as
    ``'video_asr'``. ``chunk_index`` is the item's position in its input list.
    An item whose embedding or row validation fails is logged as a warning and
    skipped. All remaining rows are inserted in one batch, followed by a flush.

    Args:
        doc_id: ID of the source document, stored on every row.
        inline_images: Image dicts; reads the ``path`` and ``ocr_text`` keys.
            ``None`` is treated as empty.
        inline_videos: Video dicts; reads the ``transcript`` and ``path`` keys.
            ``None`` is treated as empty.

    Returns:
        ``{'success': True, 'inserted': n}`` on success (``n`` may be 0), or
        ``{'success': False, 'error': message}`` if obtaining the collection or
        writing to it fails.
    """
    try:
        col = _get_milvus_collection()
    except Exception as e:
        logger.error(f'[multimodal_indexer] Milvus 连接失败: {e}')
        return {'success': False, 'error': str(e)}
    rows = []
    for idx, img in enumerate(inline_images or []):
        # `or ''` also covers keys that are present but None; Milvus would
        # reject None in these non-nullable VARCHAR fields.
        path = img.get('path') or ''
        ocr_text = img.get('ocr_text') or ''
        if not path:
            continue
        try:
            embedding = _embed(text=ocr_text or None, image_path=path)
            rows.append(_build_row(doc_id, 'image', idx, ocr_text, path, embedding))
        except Exception as e:
            logger.warning(f'[multimodal_indexer] 图片向量化失败 doc_id={doc_id} idx={idx}: {e}')
    for idx, vid in enumerate(inline_videos or []):
        transcript = vid.get('transcript') or ''
        if not transcript:
            continue
        try:
            embedding = _embed(text=transcript)
            rows.append(_build_row(doc_id, 'video_asr', idx, transcript, vid.get('path') or '', embedding))
        except Exception as e:
            logger.warning(f'[multimodal_indexer] 视频 ASR 向量化失败 doc_id={doc_id} idx={idx}: {e}')
    if not rows:
        return {'success': True, 'inserted': 0}
    try:
        # Column-ordered insert: the order must match the schema fields,
        # excluding the auto_id primary key `id`.
        col.insert([[r['doc_id'] for r in rows], [r['chunk_type'] for r in rows], [r['chunk_index'] for r in rows], [r['content'] for r in rows], [r['media_path'] for r in rows], [r['embedding'] for r in rows]])
        col.flush()
        logger.info(f'[multimodal_indexer] doc_id={doc_id} 写入 {len(rows)} 条')
        return {'success': True, 'inserted': len(rows)}
    except Exception as e:
        logger.error(f'[multimodal_indexer] Milvus 写入失败 doc_id={doc_id}: {e}')
        return {'success': False, 'error': str(e)}
