import asyncio
import json
import os as _os

from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    Request,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.responses import StreamingResponse
from common_logging import get_logger
from sqlalchemy.orm import Session

# Default quick actions for newly created agents, read once at import time from
# ``default_quick_actions.json`` located next to this module. A missing or
# malformed file makes the module import fail.
with open(
    _os.path.join(_os.path.dirname(__file__), "default_quick_actions.json"), encoding="utf-8"
) as _f:
    DEFAULT_QUICK_ACTIONS = json.load(_f)

from pydantic import BaseModel

from app import crud
from app.api.deps import get_db
from app.api.permissions import (
    require_create,
    require_delete,
    require_execute,
    require_read,
    require_update,
)
from app.core.exceptions import ExternalServiceError
from app.core.i18n import get_translator
from app.db.session import SessionLocal
from app.models import Agent, ChatMessage, Model, User
from app.schemas import AgentCreate, AgentResponse, AgentUpdate
from app.services.agent.conversation_manager import conversation_manager
from app.services.llm.backends.chat_backend_factory import get_chat_factory

# Module logger; call sites below also use ``.bind(...)`` for structured fields.
logger = get_logger(__name__)
# All agent endpoints below are registered on this router under the "agents" tag.
router = APIRouter(tags=["agents"])


class ChatRequest(BaseModel):
    """Request body for ``POST /{agent_id}/chat``."""

    # The user's message text.
    message: str
    # Conversation to continue; when omitted the endpoint creates a new one.
    conversation_id: str | None = None
    # When true the endpoint replies with a text/event-stream instead of JSON.
    stream: bool = False


class ChatResponse(BaseModel):
    """Non-streaming response body of ``POST /{agent_id}/chat``."""

    # The assistant's reply text.
    response: str
    # Conversation the reply belongs to.
    conversation_id: str | None = None
    # Knowledge-base sources used for the answer.
    sources: list[dict] | None = []
    # "table", "markdown" or "text", as classified by _detect_response_format.
    format: str | None = "text"
    # {"headers": [...], "rows": [...]} when format == "table".
    table_data: dict | None = None
    metadata: dict | None = None


def _sse_event(payload: dict) -> str:
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _build_response_style_instruction() -> str:
    """Return the fixed behaviour guidelines appended to every system prompt.

    The Chinese text covers eight rules:
      1. Refuse sensitive topics.
      2. Ask for the user's region and industry before quoting locale-dependent
         tax figures.
      3. Nudge users to anonymise names and avoid requesting confidential data.
      4. Cite concrete policy documents for rates and deductions.
      5. Stay within finance/tax compliance and refuse illegal requests.
      6. Answer style, including Markdown tables for comparison requests.
      7. Prefer the newest rule when sources conflict.
      8. End every answer with a disclaimer.
    """
    return '【财税智能助手行为准则】\n\n一、敏感话题处理规则\n当用户问题涉及以下内容时，请回复"此问题无法回复"：\n- 政治敏感话题\n- 国家安全\n- 暴力与犯罪\n- 色情低俗\n- 身份歧视\n- 群体攻击\n\n二、地域性边界（Local Context）\n税法虽有国家大法，但各地的"税收优惠政策"、"核定征收标准"、"社保缴纳基数"差异极大。\n- 当用户问"小微企业现在要交多少税"等问题时，约束你**不要盲目直接给数字**\n- 必须先反问："请问您的企业注册在哪个省份/城市？属于哪种行业？"\n- 收集足够的地域信息后再提供针对性建议\n\n三、数据脱敏与隐私边界（Data Security）\n财务数据是企业的最高机密。\n- 当看到敏感格式数据（如公司具体名称、张三、某某科技有限公司等）时，进行拦截并提示用户\n- 建议用户将"张三"、"某某科技有限公司"等替换为"客户A"、"企业B"后再提供咨询\n- 不要要求用户提供具体的财务机密数据\n\n四、政策引用要求\n凡涉及税率、税收优惠、扣除标准的回答，**必须**尽可能引用具体的政策依据（如：《财政部 税务总局公告202X年第X号》）。\n- 严禁对政策进行过度解读\n- 如遇政策模糊地带，需明确指出"该问题在实操中存在争议，需与主管税务机关\'单笔单批\'沟通"\n\n五、领域与合规红线\n- **仅限财税与商业合规领域**。对非本领域问题坚决拒绝，话术："作为企业财税专家，我仅专注于财务合规、税务筹划及审计相关事务，无法为您解答[无关领域]的问题。"\n- **拒绝违法操作**。对于"买卖发票"、"隐匿收入"等诉求，直接拒绝并严肃提示相关刑法风险（如虚开增值税专用发票罪）。\n\n六、回答风格要求\n1. 先直接给出简短结论，再用自然语言补充说明，避免机械列出 1、2、3 点\n2. 优先自然表达，避免大段模板化标题或生硬分段\n3. 仅在强调关键词时使用少量 Markdown 粗体，不要堆砌 ** 或连续标题\n4. 除非用户明确要求详细解释，否则回答保持简洁、清晰、可直接阅读\n5. 不要机械追加"参考依据""参考来源"等固定小节；来源会在独立引用区域展示\n6. 如果资料不足，直接说明不确定点，不要编造内容\n7. 如果用户要求"表格/对比/对照"，请用 Markdown 表格输出（包含表头和分隔行），不能仅仅输出表格，其他输出也是必要的。\n\n七、法条时效性处理（当前时间点用户问到再回答，没有问不要主动提时间准确）\n检索结果中若包含同一法条在不同时间的不同处理方式，你必须且只能输出发布时间最新的实操指导，并向用户注明该实操的具体年份来源。\n\n八、强制免责声明\n**所有输出必须以以下内容结尾：**\n\nAI生成的财税建议仅供参考，不作为最终法律、税务处理依据，重大财税决策请咨询专业税务师。'


def _compose_system_prompt(base_prompt: str | None) -> str | None:
    """Assemble the final system prompt sent to the model.

    Sections are joined by blank lines, in this order:
      1. the current local date (year/month/day), so relative time words in
         the user's question are resolved against "now";
      2. ``base_prompt`` stripped, when it contains any non-whitespace text;
      3. the fixed guidelines from ``_build_response_style_instruction``.

    Returns the joined text, or ``None`` if it is empty. The None case cannot
    actually occur, because section 1 is never empty.
    """
    from datetime import datetime

    now = datetime.now()
    sections = [
        f'系统当前真实时间是：{now.year}年{now.month}月{now.day}日。当用户提到"今年"、"当前"、"最近"、"现在"等时间词时，你必须且只能基于 {now.year}年{now.month}月 的时间点进行计算和推理。'
    ]
    custom_prompt = base_prompt.strip() if base_prompt else ""
    if custom_prompt:
        sections.append(custom_prompt)
    sections.append(_build_response_style_instruction())
    joined = "\n\n".join(sections).strip()
    return joined if joined else None


# Disclaimer that _append_rag_reference_block appends to replies lacking one.
# NOTE(review): the punctuation differs slightly from the disclaimer requested
# in the system prompt. This is harmless because the presence check only looks
# for the shared prefix "AI生成的财税建议仅供参考".
DISCLAIMER = (
    "\n\nAI生成的财税建议仅供参考，不作为最终法律、税务处理依据。重大财税决策请咨询专业税务师"
)


def _append_rag_reference_block(response_text: str, rag_sources: list[dict] | None) -> str:
    if "AI生成的财税建议仅供参考" in response_text:
        return response_text
    return response_text + DISCLAIMER


def _is_table_request(user_message: str) -> bool:
    normalized_message = (user_message or "").lower()
    table_keywords = [
        "表格",
        "对比",
        "对照",
        "比较",
        "区别",
        "差异",
        "一览表",
        "汇总表",
        "compare",
        "comparison",
        "versus",
        "vs",
        "table",
    ]
    return any(keyword in normalized_message for keyword in table_keywords)


def _rewrite_user_message_for_format(user_message: str) -> str:
    """Append Markdown-table formatting instructions to table/comparison requests.

    Messages that ``_is_table_request`` does not flag are returned unchanged.
    Flagged ones are stripped and followed by instructions demanding a header
    row, a separator row and at least two data rows.
    """
    if _is_table_request(user_message):
        return f"{user_message.strip()}\n\n请直接使用 Markdown 表格回答，必须包含：\n1. 表头行\n2. 分隔行（例如 | --- | --- |）\n3. 至少 2 行对比内容\n除非用户明确要求，否则不要改成普通段落或项目列表。若需要补充说明，请放在表格后面，用一句话简短补充。"
    return user_message


def _normalize_comparison_text(value: str) -> str:
    import re

    cleaned = (value or "").strip()
    cleaned = re.sub("^\\s*(?:#+\\s*)?", "", cleaned)
    cleaned = re.sub("^\\s*(?:[-*•]|\\d+[.)、])\\s*", "", cleaned)
    cleaned = cleaned.strip().strip("*").strip("_").strip()
    cleaned = re.sub("\\s+", " ", cleaned)
    return cleaned


def _parse_heading_candidate(line: str) -> str | None:
    """Return *line* as a short section title, or ``None`` if it cannot be one.

    After Markdown cleanup and removal of trailing colons, a title must be
    non-empty, at most 20 characters, and free of sentence punctuation and of
    any remaining colon.
    """
    import re

    title = _normalize_comparison_text(line).rstrip("：:").strip()
    empty_or_too_long = not title or len(title) > 20
    if (
        empty_or_too_long
        or re.search(r"[。；，,.!?？！]", title)
        or re.search(r"[:：]", title)
    ):
        return None
    return title


def _parse_label_value(line: str) -> tuple[str, str] | None:
    """Split a ``label: value`` line (ASCII or full-width colon) into a pair.

    The label is 1–20 non-colon characters and must not contain sentence
    punctuation. Both label and value must be non-empty after trimming.
    Returns ``None`` when the line does not have that shape.
    """
    import re

    found = re.match(r"^([^:：]{1,20})\s*[:：]\s*(.+)$", _normalize_comparison_text(line))
    if found is None:
        return None
    label, value = (part.strip() for part in found.groups())
    if label and value and not re.search(r"[。；，,.!?？！]", label):
        return (label, value)
    return None


def _parse_dimension_heading(line: str) -> str | None:
    """Return a list-item line as a comparison dimension name, or ``None``.

    The line must start with a list marker followed by whitespace. A line that
    carries an inline value after a colon (``- 税率: 13%``) is rejected unless
    it ends with the colon. The resulting name must be non-empty, at most 20
    characters, and free of sentence punctuation.
    """
    import re

    raw = (line or "").strip()
    if re.match(r"^\s*(?:\d+[.)、]|[-*•])\s+", raw) is None:
        return None
    text = _normalize_comparison_text(raw)
    carries_value = re.search(r"[:：].+\S", text) is not None
    if carries_value and not text.endswith((":", "：")):
        return None
    name = text.rstrip("：:").strip()
    if not name or len(name) > 20 or re.search(r"[。；，,.!?？！]", name):
        return None
    return name


def _detect_dimension_comparison_table(text: str) -> dict | None:
    """Build a comparison table from "dimension heading + label: value" text.

    Every list-item line accepted by ``_parse_dimension_heading`` opens a new
    dimension, and the ``label: value`` lines that follow fill it. Lines
    before the first heading are ignored. A dimension is kept only if it
    collected at least two labels. Labels become columns in first-seen order,
    and a row is kept only if at least two of its cells are non-empty.

    Returns ``{"headers": ["对比项", *labels], "rows": [...]}`` when at least
    two labels and two rows survive, otherwise ``None``.
    """
    entities: list[str] = []
    dimensions: list[tuple[str, dict]] = []
    active_dimension = None
    active_values: dict = {}

    for line in (raw.strip() for raw in text.splitlines()):
        if not line:
            continue
        heading = _parse_dimension_heading(line)
        if heading:
            # Close the previous dimension before opening the new one.
            if active_dimension and len(active_values) >= 2:
                dimensions.append((active_dimension, active_values))
            active_dimension, active_values = heading, {}
            continue
        if not active_dimension:
            continue
        pair = _parse_label_value(line)
        if not pair:
            continue
        label, value = pair
        if label not in entities:
            entities.append(label)
        active_values[label] = value

    if active_dimension and len(active_values) >= 2:
        dimensions.append((active_dimension, active_values))
    if len(entities) < 2 or len(dimensions) < 2:
        return None

    table_rows = []
    for name, values in dimensions:
        cells = [values.get(entity, "") for entity in entities]
        if len([cell for cell in cells if cell]) >= 2:
            table_rows.append([name, *cells])
    if len(table_rows) < 2:
        return None
    return {"headers": ["对比项", *entities], "rows": table_rows}


def _next_nonempty_line(lines: list[str], start_index: int) -> str:
    for idx in range(start_index + 1, len(lines)):
        candidate = lines[idx].strip()
        if candidate:
            return candidate
    return ""


def _detect_section_comparison_table(text: str) -> dict | None:
    """Build a comparison table from text split into titled sections.

    A section starts at a short heading line, meaning one that:
      - is accepted by ``_parse_heading_candidate``;
      - is not itself a ``label: value`` line;
      - is either explicitly marked (starts with ``#``, ``**`` or ``__``) or
        directly followed by a list item.

    Inside a section, ``label: value`` lines become items keyed by label, and
    other list items become items keyed ``要点1``, ``要点2``, … by position.
    Sections become columns and item keys become rows, in first-seen order.
    A row is kept only if at least two sections have a value for it.

    Returns ``{"headers": ["对比项", *titles], "rows": [...]}`` when there are
    at least two non-empty sections, two distinct keys and two kept rows,
    otherwise ``None``.
    """
    import re

    bullet_pattern = r"^\s*(?:[-*•]|\d+[.)、])\s+"
    all_lines = text.splitlines()
    sections: list[dict] = []
    section = None

    for position, raw in enumerate(all_lines):
        line = raw.strip()
        if not line:
            continue
        pair = _parse_label_value(line)
        title = _parse_heading_candidate(line)
        looks_like_heading = line.startswith(("#", "**", "__")) or bool(
            re.match(bullet_pattern, _next_nonempty_line(all_lines, position))
        )
        if title and not pair and looks_like_heading:
            if section and section["items"]:
                sections.append(section)
            section = {"title": title, "items": []}
            continue
        if not section:
            continue
        if pair:
            section["items"].append(pair)
        elif re.match(bullet_pattern, line):
            point = _normalize_comparison_text(line)
            if point:
                section["items"].append((f"要点{len(section['items']) + 1}", point))

    if section and section["items"]:
        sections.append(section)
    if len(sections) < 2:
        return None

    # dict.fromkeys de-duplicates while keeping first-seen order.
    keys = list(dict.fromkeys(key for sec in sections for key, _ in sec["items"]))
    if len(keys) < 2:
        return None

    rows = []
    for key in keys:
        column_values = [
            next((val for item_key, val in sec["items"] if item_key == key), "")
            for sec in sections
        ]
        if sum(1 for val in column_values if val) >= 2:
            rows.append([key, *column_values])
    if len(rows) < 2:
        return None
    return {"headers": ["对比项", *[sec["title"] for sec in sections]], "rows": rows}


def _detect_response_format(
    text: str, user_message: str | None = None
) -> tuple[str, dict | None]:
    import re

    lines = text.split("\n")
    table_start = -1
    separator_line = -1
    for i, line in enumerate(lines):
        line = line.strip()
        if not line:
            continue
        if "|" in line:
            if table_start == -1:
                table_start = i
            if re.match("^\\s*\\|[\\s:-]+\\|\\s*$", line) or re.match(
                "^\\s*\\|[\\s:-]+\\|[\\s:-]+\\|", line
            ):
                separator_line = i
                break
    if table_start != -1 and separator_line != -1 and (separator_line == table_start + 1):
        header_line = lines[table_start].strip()
        headers = [h.strip().strip("*") for h in header_line.split("|") if h.strip()]
        rows = []
        for i in range(separator_line + 1, len(lines)):
            line = lines[i].strip()
            if not line or "|" not in line:
                break
            cells = [c.strip().strip("*") for c in line.split("|") if c.strip()]
            if cells and len(cells) >= len(headers):
                rows.append(cells[: len(headers)])
        if headers and rows:
            logger.info(f"Detected table format: {len(headers)} columns, {len(rows)} rows")
            return ("table", {"headers": headers, "rows": rows})
    if _is_table_request(user_message or ""):
        comparison_table = _detect_dimension_comparison_table(
            text
        ) or _detect_section_comparison_table(text)
        if comparison_table:
            logger.info(
                f"Detected comparison fallback table: {len(comparison_table['headers'])} columns, {len(comparison_table['rows'])} rows"
            )
            return ("table", comparison_table)
    list_pattern = "^\\d+\\.\\s+.+$"
    if re.search(list_pattern, text, re.MULTILINE):
        return ("markdown", None)
    return ("text", None)


@router.get("/{agent_id}/conversations", status_code=status.HTTP_200_OK)
def get_agent_conversations(
    agent_id: int,
    skip: int = 0,
    limit: int = 20,
    request: Request = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("agents")),
):
    """List an agent's conversations, most recently active first.

    Each entry holds:
      - the conversation id;
      - the latest ``updated_at`` among its messages;
      - its message count;
      - a preview (first 100 characters) of its newest message.

    ``total`` is the size of the returned page, not the overall number of
    conversations.

    Raises:
        HTTPException: 404 if the agent does not exist or is not returned for
            ``current_user`` by the same lookup ``GET /{agent_id}`` uses.
    """
    from sqlalchemy import desc, func

    t = get_translator(request)
    # Use the user-aware lookup (as get_agent does) instead of an unscoped query,
    # so chat history is only exposed for agents the caller can see.
    db_agent = crud.agent.get(db, id=agent_id, current_user=current_user)
    if not db_agent:
        raise HTTPException(status_code=404, detail=t.t("agent.agent_not_found"))

    conversations = (
        db.query(
            ChatMessage.conversation_id,
            func.max(ChatMessage.updated_at).label("last_message_time"),
            func.count(ChatMessage.id).label("message_count"),
        )
        .filter(ChatMessage.agent_id == agent_id)
        .group_by(ChatMessage.conversation_id)
        .order_by(desc("last_message_time"))
        .offset(skip)
        .limit(limit)
        .all()
    )
    result = []
    for conv in conversations:
        # MAX(content) is the lexicographically largest text, not the newest
        # message, so look up the latest message of each listed conversation.
        latest = (
            db.query(ChatMessage.content)
            .filter(
                ChatMessage.agent_id == agent_id,
                ChatMessage.conversation_id == conv.conversation_id,
            )
            .order_by(ChatMessage.id.desc())
            .first()
        )
        last_message = (latest.content if latest else None) or ""
        result.append(
            {
                "conversation_id": conv.conversation_id,
                "last_message_time": (
                    conv.last_message_time.isoformat() if conv.last_message_time else None
                ),
                "message_count": conv.message_count,
                "last_message": (
                    last_message[:100] + "..." if len(last_message) > 100 else last_message
                ),
            }
        )
    return {"conversations": result, "total": len(result), "skip": skip, "limit": limit}


@router.get("/{agent_id}/conversations/{conversation_id}/messages", status_code=status.HTTP_200_OK)
def get_conversation_messages(
    agent_id: int,
    conversation_id: str,
    before: int = None,
    limit: int = 20,
    request: Request = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("agents")),
):
    """Return one page of a conversation's messages in chronological order.

    Pagination is cursor based. Pass the smallest ``id`` of the previous page
    as ``before`` to get the ``limit`` messages preceding it. ``has_more`` is
    a heuristic: true whenever the page is full.

    Raises:
        HTTPException: 404 if the agent does not exist or is not returned for
            ``current_user`` by the same lookup ``GET /{agent_id}`` uses.
    """
    t = get_translator(request)
    # User-aware lookup (as get_agent does) so other users' history is not exposed.
    db_agent = crud.agent.get(db, id=agent_id, current_user=current_user)
    if not db_agent:
        raise HTTPException(status_code=404, detail=t.t("agent.agent_not_found"))
    query = db.query(ChatMessage).filter(
        ChatMessage.agent_id == agent_id, ChatMessage.conversation_id == conversation_id
    )
    if before is not None:
        query = query.filter(ChatMessage.id < before)
    # Order by the same column the ``before`` cursor filters on. Otherwise pages
    # can skip or repeat messages, and messages sharing a timestamp (e.g. written
    # in one transaction) come back in arbitrary order.
    messages = query.order_by(ChatMessage.id.desc()).limit(limit).all()
    logger.info(
        f"查询对话消息: agent_id={agent_id}, conversation_id={conversation_id}, before={before}, limit={limit}"
    )
    logger.info(f"找到 {len(messages)} 条消息")
    messages.reverse()
    result = [
        {
            "id": msg.id,
            "role": msg.role,
            "content": msg.content,
            "sources": msg.sources or [],
            "timestamp": msg.created_at.isoformat(),
            "status": "read",
        }
        for msg in messages
    ]
    return {
        "conversation_id": conversation_id,
        "messages": result,
        "total": len(result),
        "has_more": len(result) == limit,
    }


@router.get("/", response_model=list[AgentResponse])
def get_agents(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("agents")),
):
    """List agents, paginated by ``skip``/``limit``.

    Which agents are included is decided by ``crud.agent.get_multi`` for
    ``current_user``.
    """
    return crud.agent.get_multi(db, skip=skip, limit=limit, current_user=current_user)


@router.get("/{agent_id}", response_model=AgentResponse)
def get_agent(
    agent_id: int,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("agents")),
):
    """Return one agent, or 404 if ``crud.agent.get`` finds none for this user."""
    translator = get_translator(request)
    found = crud.agent.get(db, id=agent_id, current_user=current_user)
    if found:
        return found
    raise HTTPException(status_code=404, detail=translator.t("agent.agent_not_found"))


@router.post("/", response_model=AgentResponse, status_code=status.HTTP_201_CREATED)
def create_agent(
    agent: AgentCreate,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_create("agents")),
):
    """Create an agent owned by ``current_user`` and their tenant.

    If ``model_id`` is given it must reference an existing model. An agent
    created without quick actions receives a copy of ``DEFAULT_QUICK_ACTIONS``.

    Raises:
        HTTPException: 400 if ``model_id`` does not exist; 500 (after rollback)
            for any other failure.
    """
    import copy

    t = get_translator(request)
    try:
        if agent.model_id:
            db_model = db.query(Model).filter(Model.id == agent.model_id).first()
            if not db_model:
                raise HTTPException(status_code=400, detail=t.t("agent.model_not_found"))
        if not agent.quick_actions:
            # Deep-copy so the module-level default is never shared with, or
            # mutated through, a persisted agent.
            agent.quick_actions = copy.deepcopy(DEFAULT_QUICK_ACTIONS)
        db_agent = crud.agent.create(
            db, obj_in=agent, created_by=current_user.id, tenant_id=current_user.tenant_id
        )
        return db_agent
    except HTTPException:
        # Keep deliberate client errors (e.g. the 400 above) from being turned
        # into a 500 by the generic handler below.
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=t.t("agent.create_failed", error=str(e)),
        ) from e


@router.put("/{agent_id}", response_model=AgentResponse)
def update_agent(
    agent_id: int,
    agent: AgentUpdate,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_update("agents")),
):
    """Update an agent that ``crud.agent.get`` returns for ``current_user``.

    Raises:
        HTTPException: 404 if the agent is not found; 400 if a given
            ``model_id`` does not exist; 500 (after rollback) for any other
            failure.
    """
    t = get_translator(request)
    db_agent = crud.agent.get(db, id=agent_id, current_user=current_user)
    if not db_agent:
        raise HTTPException(status_code=404, detail=t.t("agent.agent_not_found"))
    try:
        if agent.model_id:
            db_model = db.query(Model).filter(Model.id == agent.model_id).first()
            if not db_model:
                raise HTTPException(status_code=400, detail=t.t("agent.model_not_found"))
        db_agent = crud.agent.update(db, db_obj=db_agent, obj_in=agent)
        return db_agent
    except HTTPException:
        # Keep the 400 above from being converted into a 500 below.
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=t.t("agent.update_failed", error=str(e)),
        ) from e


@router.delete("/{agent_id}", status_code=status.HTTP_200_OK)
def delete_agent(
    agent_id: int,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_delete("agents")),
):
    """Delete an agent that ``crud.agent.get`` returns for ``current_user``.

    Raises:
        HTTPException: 404 if the agent is not found (or not visible to the
            caller); 500 (after rollback) if deletion fails.
    """
    t = get_translator(request)
    # crud.agent.delete() receives no user and so cannot enforce access itself.
    # Resolve the agent through the same user-aware lookup update_agent uses.
    if not crud.agent.get(db, id=agent_id, current_user=current_user):
        raise HTTPException(status_code=404, detail=t.t("agent.agent_not_found"))
    try:
        deleted_agent = crud.agent.delete(db, id=agent_id)
        if not deleted_agent:
            raise HTTPException(status_code=404, detail=t.t("agent.agent_not_found"))
        return {"success": True, "message": t.t("agent.agent_deleted")}
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=t.t("agent.delete_failed", error=str(e)),
        ) from e


@router.post("/{agent_id}/toggle", status_code=status.HTTP_200_OK)
def toggle_agent(
    agent_id: int,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_update("agents")),
):
    """Flip an agent's ``running`` flag and return the persisted value.

    Users other than ``platform_admin`` may only toggle agents of their own
    tenant.

    Raises:
        HTTPException: 404 if the agent does not exist, 403 on a tenant
            mismatch, 500 if the commit fails (the session is rolled back).
    """
    translator = get_translator(request)
    target = db.query(Agent).filter(Agent.id == agent_id).first()
    if not target:
        raise HTTPException(status_code=404, detail=translator.t("agent.agent_not_found"))
    foreign_tenant = (
        current_user.role != "platform_admin" and target.tenant_id != current_user.tenant_id
    )
    if foreign_tenant:
        raise HTTPException(status_code=403, detail=translator.t("agent.access_denied"))
    try:
        flipped = not target.running
        target.running = flipped
        db.commit()
        db.refresh(target)
        return {"success": True, "running": target.running}
    except Exception as exc:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=translator.t("agent.toggle_failed", error=str(exc)),
        ) from exc


@router.patch("/{agent_id}/status", status_code=status.HTTP_200_OK)
def update_agent_status(
    agent_id: int,
    status: str,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_update("agents")),
):
    """Set an agent's ``status`` to ``"online"`` or ``"offline"``.

    Users other than ``platform_admin`` may only change agents of their own
    tenant.

    NOTE: the ``status`` query parameter shadows the ``fastapi.status`` module
    inside this function, so HTTP status codes are written as literals here.

    Raises:
        HTTPException: 400 for an unsupported status, 404 if the agent does
            not exist, 403 on a tenant mismatch, 500 (after rollback) if the
            update fails.
    """
    t = get_translator(request)
    if status not in ["online", "offline"]:
        raise HTTPException(status_code=400, detail=t.t("agent.invalid_status"))
    # Authorise before writing, so a rejected request never modifies the row.
    db_agent = db.query(Agent).filter(Agent.id == agent_id).first()
    if not db_agent:
        raise HTTPException(status_code=404, detail=t.t("agent.agent_not_found"))
    if current_user.role != "platform_admin" and db_agent.tenant_id != current_user.tenant_id:
        raise HTTPException(status_code=403, detail=t.t("agent.access_denied"))
    try:
        updated_agent = crud.agent.update_status(db, agent_id=agent_id, status=status)
        if not updated_agent:
            raise HTTPException(status_code=404, detail=t.t("agent.agent_not_found"))
        return {"success": True, "status": updated_agent.status}
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(
            status_code=500,
            detail=t.t("agent.update_status_failed", error=str(e)),
        ) from e


@router.post("/{agent_id}/chat", response_model=ChatResponse)
async def chat_with_agent(
    agent_id: int,
    chat_request: ChatRequest,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_execute("agents")),
):
    """Run one chat turn against an agent.

    Processing steps:
      1. Check that the agent exists, is ``online`` and has a model whose
         provider is configured.
      2. Optionally retrieve knowledge-base context.
      3. Build the message list: system prompt, up to 10 history messages,
         optional RAG context, then the user message.
      4. Call the chat backend.

    With ``chat_request.stream`` set, the reply is a ``text/event-stream``
    emitting ``start``, ``delta`` and then ``done`` (or ``error``) events.
    Otherwise a ``ChatResponse`` is returned. The user message and the
    assistant reply are both recorded in ``conversation_manager`` and stored
    as ``ChatMessage`` rows.

    Raises:
        HTTPException:
            - 404 if the agent is missing;
            - 400 if the agent is offline or the model/provider is
              misconfigured;
            - 502 on ``ExternalServiceError``;
            - 500 on any other failure before streaming starts.
    """
    t = get_translator(request)
    current_user_id = current_user.id
    logger.info(f"User {current_user.name} (ID: {current_user.id}) initiated chat")
    db_agent = db.query(Agent).filter(Agent.id == agent_id).first()
    if not db_agent:
        raise HTTPException(status_code=404, detail=t.t("agent.agent_not_found"))
    if db_agent.status != "online":
        logger.warning(f"❌ Agent {agent_id} is offline")
        raise HTTPException(status_code=400, detail=t.t("agent.agent_offline"))
    if not db_agent.model_id:
        logger.warning(f"❌ Agent {agent_id} has no model configured")
        raise HTTPException(status_code=400, detail=t.t("agent.model_not_configured"))
    logger.info(f"✅ Agent {agent_id} validation passed, model_id: {db_agent.model_id}")
    try:
        import time

        from app.services.agent.rag_service import get_agent_rag_service

        chat_factory = get_chat_factory()
        logger.info(f"🟢 Using provider-based chat routing, model_id: {db_agent.model_id}")
        model = db.query(Model).filter(Model.id == db_agent.model_id).first()
        if not model:
            logger.error(f"Model {db_agent.model_id} not found in database")
            raise HTTPException(status_code=400, detail=t.t("agent.model_not_found"))
        provider = model.provider
        if not provider:
            raise HTTPException(status_code=400, detail=t.t("agent.provider_not_found"))
        if not provider.configured:
            raise HTTPException(
                status_code=400, detail=t.t("agent.provider_not_configured", provider=provider.name)
            )
        if provider.auth_type != "none" and (not provider.get_api_key()):
            raise HTTPException(
                status_code=400, detail=t.t("agent.provider_not_configured", provider=provider.name)
            )
        conversation_id = chat_request.conversation_id
        if not conversation_id:
            conversation_id = conversation_manager.create_conversation()
            logger.info(f"Created new conversation: {conversation_id}")
        else:
            logger.info(f"Using existing conversation: {conversation_id}")
        chat_client = chat_factory.get_client(model.id, db)
        request_start_time = time.time()

        user_message = chat_request.message
        formatted_user_message = _rewrite_user_message_for_format(user_message)
        rag_context = None
        rag_sources = []
        rag_service = get_agent_rag_service(db, db_agent)
        if rag_service.should_retrieve(user_message):
            rag_start = time.time()
            logger.debug("RAG retrieval started")
            try:
                rag_result = await rag_service.retrieve_context(
                    query=user_message,
                    top_k=5,
                    threshold=0.3,
                    mode="hybrid",
                    tenant_id=current_user.tenant_id,
                    user_role=current_user.role,
                    chat_client=chat_client,
                    chat_model=model.remote_model_id or model.code,
                    use_reranker=True,
                )
                rag_context = rag_result["context_text"]
                rag_sources = rag_result["sources"]
                rag_elapsed = time.time() - rag_start
                logger.bind(elapsed=round(rag_elapsed, 2), doc_count=rag_result['retrieved_count']).debug("RAG retrieval completed")
            except Exception as e:
                logger.error(
                    f"Knowledge base retrieval failed, continuing with normal conversation: {e}"
                )
        else:
            logger.debug("Skipping RAG retrieval")
        messages = []
        enhanced_prompt = _compose_system_prompt(
            rag_service.enhance_system_prompt(
                db_agent.system_prompt or "", has_context=bool(rag_context)
            )
        )
        if enhanced_prompt:
            messages.append({"role": "system", "content": enhanced_prompt})
        history_messages = conversation_manager.get_messages(conversation_id, max_messages=10)
        messages.extend(history_messages)
        if rag_context:
            messages.append({"role": "system", "content": rag_context})
        messages.append({"role": "user", "content": formatted_user_message})
        conversation_manager.add_message(conversation_id, "user", user_message)
        user_chat_message = ChatMessage(
            agent_id=agent_id,
            user_id=current_user_id,
            conversation_id=conversation_id,
            role="user",
            content=user_message,
        )
        db.add(user_chat_message)
        db.flush()
        # A temperature of 0 is a legitimate setting; fall back only when unset.
        agent_temperature = 0.7 if db_agent.temperature is None else db_agent.temperature
        agent_max_tokens = db_agent.max_tokens or 2048
        logger.bind(model=model.code, stream=chat_request.stream).debug("Calling chat API")
        logger.info(f"Calling chat API - Model: {model.code}, Stream: {chat_request.stream}")
        if chat_request.stream:
            # Nothing else commits the request-scoped session on the streaming
            # path; the assistant reply is written through a separate session.
            # Without this, the flushed user message would be discarded.
            db.commit()

            async def chat_stream():
                response_chunks = []
                yield _sse_event(
                    {"type": "start", "conversation_id": conversation_id, "sources": rag_sources}
                )
                try:
                    from app.services.llm.backends.chat_backend_factory import (
                        LocalAssetChatClient,
                        OpenAICompatibleChatClient,
                    )

                    if isinstance(chat_client, OpenAICompatibleChatClient | LocalAssetChatClient):
                        async for text_chunk in chat_client.async_stream_completion(
                            messages=messages,
                            temperature=agent_temperature,
                            max_tokens=agent_max_tokens,
                        ):
                            response_chunks.append(text_chunk)
                            yield _sse_event({"type": "delta", "content": text_chunk})
                        response_text = "".join(response_chunks)
                        response_text = _append_rag_reference_block(response_text, rag_sources)
                    else:
                        api_response = await chat_client.chat_completion(
                            messages=messages,
                            temperature=agent_temperature,
                            max_tokens=agent_max_tokens,
                            stream=False,
                        )
                        response_text = chat_client.extract_response_text(api_response)
                        response_text = _append_rag_reference_block(response_text, rag_sources)
                        yield _sse_event({"type": "delta", "content": response_text})
                    conversation_manager.add_message(conversation_id, "assistant", response_text)

                    def save_to_db():
                        stream_db = SessionLocal()
                        try:
                            assistant_chat_message = ChatMessage(
                                agent_id=agent_id,
                                user_id=None,
                                conversation_id=conversation_id,
                                role="assistant",
                                content=response_text,
                                sources=rag_sources or [],
                            )
                            stream_db.add(assistant_chat_message)
                            stream_db.commit()
                            stream_db.refresh(assistant_chat_message)
                            logger.info(
                                f"✅ 助手回复已保存到数据库: agent_id={agent_id}, conversation_id={conversation_id}, message_id={assistant_chat_message.id}"
                            )
                        except Exception as e:
                            logger.error(f"Failed to save assistant message to DB: {e}")
                        finally:
                            stream_db.close()

                    # Await the blocking DB write in a worker thread before sending
                    # "done". This keeps it off the event loop, avoids a
                    # fire-and-forget task that could be garbage-collected, and
                    # persists the reply even if the client disconnects after "done".
                    await asyncio.to_thread(save_to_db)
                    format_type, table_data = _detect_response_format(response_text, user_message)
                    yield _sse_event(
                        {
                            "type": "done",
                            "conversation_id": conversation_id,
                            "response": response_text,
                            "sources": rag_sources,
                            "format": format_type,
                            "table_data": table_data,
                        }
                    )
                except Exception as e:
                    logger.exception(f"Chat streaming failed: {str(e)}")
                    yield _sse_event({"type": "error", "error": str(e)})

            elapsed_before_stream = time.time() - request_start_time
            logger.bind(elapsed=round(elapsed_before_stream, 2)).debug("Returning StreamingResponse")
            return StreamingResponse(
                chat_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )
        logger.info("📞 Starting NON-STREAMING response")
        api_response = await chat_client.chat_completion(
            messages=messages,
            temperature=agent_temperature,
            max_tokens=agent_max_tokens,
        )
        response_text = chat_client.extract_response_text(api_response)
        response_text = _append_rag_reference_block(response_text, rag_sources)
        usage_info = chat_client.get_usage_info(api_response)
        logger.info(f"Token usage: {usage_info}")
        conversation_manager.add_message(conversation_id, "assistant", response_text)
        assistant_chat_message = ChatMessage(
            agent_id=agent_id,
            user_id=None,
            conversation_id=conversation_id,
            role="assistant",
            content=response_text,
            sources=rag_sources or [],
        )
        db.add(assistant_chat_message)
        db.commit()
        # Pass the user's message so the comparison-table fallback can apply.
        format_type, table_data = _detect_response_format(response_text, user_message)
        return ChatResponse(
            response=response_text,
            conversation_id=conversation_id,
            sources=rag_sources,
            format=format_type,
            table_data=table_data,
        )
    except HTTPException:
        raise
    except ExternalServiceError as e:
        logger.error(f"External service error: {str(e)}")
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"对话失败: {str(e)}") from None
    except Exception as e:
        # logger.exception records the traceback. The former
        # "...".opt(exception=True) called .opt on a str and raised
        # AttributeError inside this handler.
        logger.exception(f"Chat failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=t.t("agent.chat_failed", error=str(e)),
        ) from e


async def _ws_send_error(websocket: WebSocket, message: str) -> None:
    """Send a single ``{"type": "error"}`` frame to the client.

    The socket is deliberately left open: ``chat_websocket`` closes it in its
    ``finally`` clause so that every exit path closes it exactly once.
    """
    await websocket.send_json({"type": "error", "error": message})


def _find_quick_action_reply(quick_actions, user_message: str) -> str | None:
    """Return the canned reply for *user_message*, or ``None``.

    Scans *quick_actions* in order and returns the ``reply`` of the first entry
    whose ``prompt`` equals *user_message* and whose ``reply`` is non-empty.
    Entries that are not dicts are ignored rather than crashing the turn.
    """
    for action in quick_actions:
        if not isinstance(action, dict):
            continue
        if action.get("prompt") == user_message and action.get("reply"):
            return action["reply"]
    return None


async def _stream_reply_to_ws(websocket: WebSocket, chat_client, messages: list[dict], db_agent) -> str:
    """Run the completion and forward its text to the client as ``delta`` frames.

    Clients that are ``OpenAICompatibleChatClient`` or ``LocalAssetChatClient``
    are streamed chunk by chunk via ``async_stream_completion``. Any other
    backend gets one non-streaming ``chat_completion`` call, and its full text
    is sent as a single ``delta``. Returns the concatenated response text.
    """
    from app.services.llm.backends.chat_backend_factory import (
        LocalAssetChatClient,
        OpenAICompatibleChatClient,
    )

    temperature = db_agent.temperature or 0.7
    max_tokens = db_agent.max_tokens or 2048
    chunks: list[str] = []
    if isinstance(chat_client, OpenAICompatibleChatClient | LocalAssetChatClient):
        async for text_chunk in chat_client.async_stream_completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        ):
            chunks.append(text_chunk)
            await websocket.send_json({"type": "delta", "content": text_chunk})
    else:
        api_response = await chat_client.chat_completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
        )
        full_text = chat_client.extract_response_text(api_response)
        await websocket.send_json({"type": "delta", "content": full_text})
        chunks.append(full_text)
    return "".join(chunks)


@router.websocket("/{agent_id}/chat/ws")
async def chat_websocket(websocket: WebSocket, agent_id: int, db: Session = Depends(get_db)):
    """Serve one chat turn with an agent over a WebSocket.

    Protocol (all frames are JSON):

    * client -> server: a single object with ``message`` (required),
      ``conversation_id`` (optional) and ``token`` (used for authentication).
    * server -> client, normal path: ``start`` (conversation id and RAG
      sources), then one or more ``delta`` frames, then ``done`` carrying the
      full response, sources, detected format and table data.
    * server -> client, quick-action hit: ``content`` followed by ``done``.
    * server -> client, on failure: a single ``error`` frame.

    Both the user message and the assistant reply are persisted as
    ``ChatMessage`` rows. The connection is closed when the turn ends,
    whatever the outcome.
    """
    await websocket.accept()
    logger.info(f"WebSocket connection established for agent {agent_id}")
    try:
        data = await websocket.receive_json()
        # receive_json() can yield any JSON value. Treat non-objects as empty
        # so they fall through to the "Message is required" error instead of
        # crashing on .get().
        if not isinstance(data, dict):
            data = {}
        user_message = data.get("message")
        conversation_id = data.get("conversation_id")
        token = data.get("token")
        if not user_message:
            await _ws_send_error(websocket, "Message is required")
            return
        from app.api.deps import get_current_user_from_token

        try:
            current_user = get_current_user_from_token(token, db)
        except Exception:
            logger.warning(f"WebSocket authentication failed for agent {agent_id}")
            await _ws_send_error(websocket, "Authentication failed")
            return

        db_agent = db.query(Agent).filter(Agent.id == agent_id).first()
        if not db_agent:
            await _ws_send_error(websocket, "Agent not found")
            return
        if db_agent.status != "online":
            await _ws_send_error(websocket, "Agent is offline")
            return
        if not db_agent.model_id:
            await _ws_send_error(websocket, "Model not configured")
            return

        # Quick actions short-circuit the LLM with a canned reply.
        canned_reply = _find_quick_action_reply(db_agent.quick_actions or [], user_message)
        if canned_reply is not None:
            conv_id = conversation_id or conversation_manager.create_conversation()
            # NOTE(review): quick-action turns are persisted to the DB but not
            # added to conversation_manager's in-memory history, unlike the LLM
            # path below. Confirm whether that is intended.
            for role, content in (("user", user_message), ("assistant", canned_reply)):
                db.add(
                    ChatMessage(
                        agent_id=agent_id,
                        user_id=current_user.id,
                        conversation_id=conv_id,
                        role=role,
                        content=content,
                    )
                )
            db.commit()
            await websocket.send_json({"type": "content", "content": canned_reply})
            await websocket.send_json(
                {
                    "type": "done",
                    "conversation_id": conv_id,
                    "response": canned_reply,
                    "sources": [],
                }
            )
            return

        from app.services.agent.rag_service import get_agent_rag_service

        rag_service = get_agent_rag_service(db, db_agent)
        chat_factory = get_chat_factory()
        model = db.query(Model).filter(Model.id == db_agent.model_id).first()
        if not model:
            await _ws_send_error(websocket, "Model not found")
            return
        chat_client = chat_factory.get_client(model.id, db)
        if not conversation_id:
            conversation_id = conversation_manager.create_conversation()
        formatted_user_message = _rewrite_user_message_for_format(user_message)

        # Retrieval is best-effort: a failure is logged and the turn continues
        # without RAG context.
        rag_context = None
        rag_sources = []
        if rag_service.should_retrieve(user_message):
            try:
                rag_result = await rag_service.retrieve_context(
                    query=user_message,
                    top_k=5,
                    threshold=0.3,
                    mode="hybrid",
                    tenant_id=current_user.tenant_id,
                    user_role=current_user.role,
                    chat_client=chat_client,
                    chat_model=model.remote_model_id or model.code,
                    use_reranker=True,
                )
                rag_context = rag_result["context_text"]
                rag_sources = rag_result["sources"]
            except Exception as e:
                logger.error(f"RAG retrieval failed: {e}")

        # Prompt order: system prompt, prior history, RAG context, new user turn.
        messages = []
        enhanced_prompt = _compose_system_prompt(
            rag_service.enhance_system_prompt(
                db_agent.system_prompt or "", has_context=bool(rag_context)
            )
        )
        if enhanced_prompt:
            messages.append({"role": "system", "content": enhanced_prompt})
        messages.extend(conversation_manager.get_messages(conversation_id, max_messages=10))
        if rag_context:
            messages.append({"role": "system", "content": rag_context})
        messages.append({"role": "user", "content": formatted_user_message})

        conversation_manager.add_message(conversation_id, "user", user_message)
        db.add(
            ChatMessage(
                agent_id=agent_id,
                user_id=current_user.id,
                conversation_id=conversation_id,
                role="user",
                content=user_message,
            )
        )
        # Flush only. The commit happens together with the assistant reply, so
        # a failed turn is rolled back as a whole.
        db.flush()

        await websocket.send_json(
            {"type": "start", "conversation_id": conversation_id, "sources": rag_sources}
        )
        response_text = await _stream_reply_to_ws(websocket, chat_client, messages, db_agent)
        response_text = _append_rag_reference_block(response_text, rag_sources)

        conversation_manager.add_message(conversation_id, "assistant", response_text)
        db.add(
            ChatMessage(
                agent_id=agent_id,
                user_id=None,
                conversation_id=conversation_id,
                role="assistant",
                content=response_text,
                sources=rag_sources or [],
            )
        )
        db.commit()

        format_type, table_data = _detect_response_format(response_text, user_message)
        await websocket.send_json(
            {
                "type": "done",
                "conversation_id": conversation_id,
                "response": response_text,
                "sources": rag_sources,
                "format": format_type,
                "table_data": table_data,
            }
        )
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected for agent {agent_id}")
    except Exception as e:
        # logger.exception records the traceback. Calling .opt() on the message
        # string would itself raise AttributeError here and skip the error frame.
        logger.exception(f"WebSocket error: {e}")
        # Discard the flushed-but-uncommitted user message (and anything else
        # pending) so the session is left in a clean state.
        try:
            db.rollback()
        except Exception:
            logger.warning("WebSocket: session rollback failed")
        try:
            await websocket.send_json({"type": "error", "error": str(e)})
        except Exception:
            # The client may already be gone; nothing more to report.
            pass
    finally:
        # Single close point for every exit path. Closing an already-closed
        # socket raises, which is safe to ignore here.
        try:
            await websocket.close()
        except Exception:
            pass
