import json
import math
from typing import Any

from openai import OpenAI

from app.config import settings

from common_logging import get_logger

# Module-level logger obtained from the project's common_logging helper.
logger = get_logger(__name__)


# Prompt template (in Chinese) for entity/relation extraction.
# It is filled with str.format(title=..., content=...), so the literal JSON braces
# of the example schema are doubled as {{ / }}.
# It asks the model to return only a JSON object with:
#   "entities":  name, type, description, salience (0-1)
#   "relations": source, target, type, confidence (0-1)
# It also enumerates the allowed entity and relation type labels.
# NOTE(review): keep the type labels listed here in sync with the sets that
# EntityExtractor validates against; unknown labels are remapped there.
ENTITY_EXTRACTION_PROMPT = '你是一个专业的知识图谱构建助手。请从以下文档中提取关键实体和关系。\n\n文档标题：{title}\n文档内容：\n{content}\n\n请按照以下JSON格式返回：\n{{\n  "entities": [\n    {{\n      "name": "实体名称",\n      "type": "PERSON|ORG|LOCATION|CONCEPT|TERM|LAW|DATE|MONEY|PERCENT",\n      "description": "简短描述（可选）",\n      "salience": 0.0-1.0\n    }}\n  ],\n  "relations": [\n    {{\n      "source": "实体1名称",\n      "target": "实体2名称",\n      "type": "RELATED_TO|PART_OF|INSTANCE_OF|DEFINED_BY|REGULATES",\n      "confidence": 0.0-1.0\n    }}\n  ]\n}}\n\n要求：\n1. 只提取重要实体（至少出现2次或具有关键意义）\n2. 每个文档提取5-20个实体\n3. 关系要有明确语义\n4. salience表示实体在文档中的重要性（0-1）\n5. confidence表示关系的置信度（0-1）\n6. 实体类型说明：\n   - PERSON: 人名\n   - ORG: 组织机构\n   - LOCATION: 地点\n   - CONCEPT: 概念术语\n   - TERM: 专业术语\n   - LAW: 法律法规\n   - DATE: 日期时间\n   - MONEY: 金额\n   - PERCENT: 百分比\n\n请只返回JSON，不要包含其他文字。'


class EntityExtractor:
    """Extract knowledge-graph entities and relations from documents with a chat LLM.

    The model is prompted with ``ENTITY_EXTRACTION_PROMPT`` and asked for a JSON
    object holding ``entities`` and ``relations`` lists. That output is untrusted,
    so every item is validated and normalized here. Malformed items are dropped
    one by one instead of discarding the whole response.
    """

    # Allowed type labels; items with an unknown label fall back to the defaults below.
    # Keep these in sync with the type lists spelled out in ENTITY_EXTRACTION_PROMPT.
    VALID_ENTITY_TYPES = frozenset(
        {"PERSON", "ORG", "LOCATION", "CONCEPT", "TERM", "LAW", "DATE", "MONEY", "PERCENT"}
    )
    VALID_RELATION_TYPES = frozenset(
        {"RELATED_TO", "PART_OF", "INSTANCE_OF", "DEFINED_BY", "REGULATES"}
    )
    DEFAULT_ENTITY_TYPE = "CONCEPT"
    DEFAULT_RELATION_TYPE = "RELATED_TO"
    # Score used when the model omits salience/confidence (or sends null).
    DEFAULT_SCORE = 0.5
    # Maximum number of document characters sent to the model.
    DEFAULT_MAX_LENGTH = 4000

    def _get_model_config(self, db_session=None):
        """Resolve an OpenAI-compatible client and model name for extraction.

        When ``db_session`` is given, the first enabled ``chat`` model whose
        provider is enabled, configured and has an API key is used. If that
        lookup fails or finds nothing, this falls back to
        ``settings.OPENAI_API_KEY`` / ``settings.OPENAI_BASE_URL`` with
        ``settings.GRAPH_ENTITY_EXTRACTION_MODEL``.

        Returns:
            ``(client, model_name)``, or ``(None, None)`` when nothing is configured.
        """
        if db_session:
            try:
                # Imported lazily so the ORM models are only needed on this path.
                from app.models.provider import Model, ModelProvider

                # NOTE(review): there is no ORDER BY, so when several models qualify
                # the one returned is database-defined; confirm that is intended.
                model = (
                    db_session.query(Model)
                    .join(ModelProvider)
                    .filter(
                        Model.type == "chat",
                        Model.enabled,
                        ModelProvider.enabled,
                        ModelProvider.configured,
                        ModelProvider.api_key.isnot(None),
                    )
                    .first()
                )
                if model and model.provider:
                    client = OpenAI(
                        api_key=model.provider.get_api_key(),
                        base_url=model.provider.base_url or model.provider.default_base_url,
                    )
                    return (client, model.code)
            except Exception as e:
                # Best effort: any DB/client problem falls through to the settings config.
                logger.warning(f"Failed to get model from database: {e}")
        if settings.OPENAI_API_KEY:
            client = OpenAI(api_key=settings.OPENAI_API_KEY, base_url=settings.OPENAI_BASE_URL)
            return (client, settings.GRAPH_ENTITY_EXTRACTION_MODEL)
        return (None, None)

    def extract_entities(
        self,
        title: str,
        content: str,
        max_length: int = DEFAULT_MAX_LENGTH,
        db_session=None,
    ) -> dict[str, list[dict[str, Any]]]:
        """Extract entities and relations from one document.

        Args:
            title: Document title, inserted into the prompt and log messages.
            content: Document text; ``None`` is treated as empty. Text longer than
                ``max_length`` characters is truncated and suffixed with "...".
            max_length: Maximum number of content characters sent to the model.
            db_session: Optional database session used to pick a configured chat
                model; without it (or if that lookup fails) the settings-based
                OpenAI configuration is used.

        Returns:
            ``{"entities": [...], "relations": [...]}``. Both lists are empty when no
            model is configured or when the LLM call / response parsing fails.
            Those failures are logged rather than raised.
        """
        client, model_name = self._get_model_config(db_session)
        if not client:
            logger.warning("No model configuration available, skipping entity extraction")
            return self._empty_result()
        return self._extract_with_client(client, model_name, title, content, max_length)

    def _extract_with_client(self, client, model_name, title, content, max_length):
        """Run one extraction with an already-resolved client (see ``extract_entities``)."""
        try:
            # Everything stays inside the try: extraction is best-effort and must
            # degrade to an empty result instead of propagating.
            content = content or ""
            if len(content) > max_length:
                content = content[:max_length] + "..."
                logger.info(f"Content truncated to {max_length} characters")
            prompt = ENTITY_EXTRACTION_PROMPT.format(title=title or "", content=content)
            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {
                        "role": "system",
                        "content": "你是一个专业的知识图谱构建助手，擅长从文本中提取实体和关系。",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,
                max_tokens=2000,
                response_format={"type": "json_object"},
            )
            # message.content is optional in the SDK; "" makes json.loads raise
            # JSONDecodeError, which is reported below as a parse failure.
            result = json.loads(response.choices[0].message.content or "")
            if not isinstance(result, dict):
                logger.error(f"LLM response is not a JSON object: {type(result).__name__}")
                return self._empty_result()
            entities = self._validate_entities(result.get("entities"))
            relations = self._validate_relations(result.get("relations"))
            logger.info(
                f"Extracted {len(entities)} entities and {len(relations)} relations from '{title}'"
            )
            return {"entities": entities, "relations": relations}
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse LLM response as JSON: {e}")
            return self._empty_result()
        except Exception as e:
            logger.error(f"Entity extraction failed: {e}")
            return self._empty_result()

    @staticmethod
    def _empty_result() -> dict[str, list[dict[str, Any]]]:
        """Return a new empty extraction result (fresh lists on every call)."""
        return {"entities": [], "relations": []}

    @staticmethod
    def _clean_str(value: Any) -> str:
        """Return ``value`` as a stripped string.

        Numbers (but not booleans) are stringified, e.g. a DATE named ``2023``.
        Any other non-string value yields ``""``.
        """
        if isinstance(value, str):
            return value.strip()
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return str(value).strip()
        return ""

    @classmethod
    def _coerce_score(cls, value: Any) -> "float | None":
        """Parse an LLM-provided 0-1 score.

        ``None`` (missing or null) maps to ``DEFAULT_SCORE``. Finite numbers and
        numeric strings are clamped into ``[0, 1]``. Anything else, such as
        non-numeric text, NaN or infinities (``json.loads`` accepts ``NaN`` and
        ``Infinity``), yields ``None`` so the caller can drop the item.
        """
        if value is None:
            return cls.DEFAULT_SCORE
        try:
            score = float(value)
        except (TypeError, ValueError):
            return None
        if not math.isfinite(score):
            # NaN would otherwise slip through "score < threshold" checks.
            return None
        return min(max(score, 0.0), 1.0)

    def _validate_entities(
        self, entities: Any, min_confidence: "float | None" = None
    ) -> list[dict[str, Any]]:
        """Normalize raw LLM entities and drop malformed or low-salience ones.

        Args:
            entities: The ``entities`` value from the LLM JSON; any non-list value
                yields ``[]``. Non-dict items are skipped.
            min_confidence: Minimum salience to keep; defaults to
                ``settings.GRAPH_MIN_ENTITY_CONFIDENCE``.

        Returns:
            Dicts with ``name`` (stripped), ``type`` (upper-case; unknown types map
            to ``DEFAULT_ENTITY_TYPE``), ``description`` (string, possibly empty)
            and ``salience`` (float in [0, 1]).
        """
        if not isinstance(entities, list):
            return []
        if min_confidence is None:
            min_confidence = settings.GRAPH_MIN_ENTITY_CONFIDENCE
        validated = []
        for entity in entities:
            if not isinstance(entity, dict):
                continue
            name = self._clean_str(entity.get("name"))
            raw_type = self._clean_str(entity.get("type"))
            # Checked after stripping so whitespace-only names/types are rejected.
            if not name or not raw_type:
                continue
            entity_type = raw_type.upper()
            if entity_type not in self.VALID_ENTITY_TYPES:
                entity_type = self.DEFAULT_ENTITY_TYPE
            salience = self._coerce_score(entity.get("salience"))
            if salience is None or salience < min_confidence:
                continue
            validated.append(
                {
                    "name": name,
                    "type": entity_type,
                    "description": self._clean_str(entity.get("description")),
                    "salience": salience,
                }
            )
        return validated

    def _validate_relations(
        self, relations: Any, min_confidence: "float | None" = None
    ) -> list[dict[str, Any]]:
        """Normalize raw LLM relations and drop malformed or low-confidence ones.

        Args:
            relations: The ``relations`` value from the LLM JSON; any non-list value
                yields ``[]``. Non-dict items are skipped.
            min_confidence: Minimum confidence to keep; defaults to
                ``settings.GRAPH_MIN_ENTITY_CONFIDENCE``.

        Returns:
            Dicts with ``source``/``target`` (stripped), ``type`` (upper-case;
            unknown types map to ``DEFAULT_RELATION_TYPE``) and ``confidence``
            (float in [0, 1]).
        """
        if not isinstance(relations, list):
            return []
        if min_confidence is None:
            # NOTE(review): relations reuse the *entity* threshold setting; confirm
            # a separate relation threshold is not intended.
            min_confidence = settings.GRAPH_MIN_ENTITY_CONFIDENCE
        validated = []
        for relation in relations:
            if not isinstance(relation, dict):
                continue
            source = self._clean_str(relation.get("source"))
            target = self._clean_str(relation.get("target"))
            raw_type = self._clean_str(relation.get("type"))
            if not (source and target and raw_type):
                continue
            rel_type = raw_type.upper()
            if rel_type not in self.VALID_RELATION_TYPES:
                rel_type = self.DEFAULT_RELATION_TYPE
            confidence = self._coerce_score(relation.get("confidence"))
            if confidence is None or confidence < min_confidence:
                continue
            validated.append(
                {
                    "source": source,
                    "target": target,
                    "type": rel_type,
                    "confidence": confidence,
                }
            )
        return validated

    def extract_entities_batch(
        self, documents: list[dict[str, str]], max_concurrent: int = 5, db_session=None
    ) -> list[dict[str, Any]]:
        """Extract entities for each document, sequentially and in input order.

        The model/client is resolved once for the whole batch instead of once per
        document.

        Args:
            documents: Dicts with optional ``id``, ``title`` and ``content`` keys.
            max_concurrent: Currently unused; documents are processed one at a
                time. Kept for interface compatibility.
            db_session: Forwarded to model resolution so batches can use the
                database-configured model, as ``extract_entities`` does.

        Returns:
            One ``{"document_id", "entities", "relations"}`` dict per document.
        """
        if not documents:
            return []
        client, model_name = self._get_model_config(db_session)
        if not client:
            logger.warning("No model configuration available, skipping entity extraction")
        results = []
        for doc in documents:
            if client:
                extracted = self._extract_with_client(
                    client,
                    model_name,
                    doc.get("title", ""),
                    doc.get("content", ""),
                    self.DEFAULT_MAX_LENGTH,
                )
            else:
                extracted = self._empty_result()
            results.append(
                {
                    "document_id": doc.get("id"),
                    "entities": extracted["entities"],
                    "relations": extracted["relations"],
                }
            )
        return results


# Shared module-level instance for callers to import; EntityExtractor keeps no
# per-instance state, so one instance can be reused.
entity_extractor = EntityExtractor()
