
import httpx
from common_logging import get_logger

logger = get_logger(__name__)

# Process-wide PaddleOCR engine cache. Left as None at import time and filled in
# lazily by _get_ocr(), so importing this module does not import paddleocr.
_ocr = None

def _get_ocr():
    """Return the shared PaddleOCR engine, constructing it on first call.

    The ``paddleocr`` import is deferred to this point so the module can be
    imported without it. The engine is built for Chinese (``lang='ch'``) with
    angle classification disabled, and cached in the module-global ``_ocr``.

    Raises:
        ImportError: if importing or constructing the engine raises it; the
            problem is logged with an install hint before being re-raised.
    """
    global _ocr
    if _ocr is not None:
        return _ocr
    try:
        from paddleocr import PaddleOCR
        engine = PaddleOCR(lang='ch', use_angle_cls=False)
    except ImportError:
        logger.error('[ImageOCR] paddleocr 未安装，请执行: pip install paddlepaddle paddleocr')
        raise
    _ocr = engine
    return _ocr

def ocr_image_bytes(image_bytes: bytes, min_confidence: float=0.6) -> str:
    """OCR an in-memory encoded image and render the recognised text as HTML.

    Args:
        image_bytes: Encoded image data; decoded with ``cv2.imdecode`` as a
            colour image.
        min_confidence: Recognition results scoring below this are dropped.

    Returns:
        An HTML fragment, or ``''`` when the input is empty, cannot be decoded,
        yields no usable text, or any step fails. Failures are logged and never
        raised to the caller.
    """
    if not image_bytes:
        return ''
    try:
        import cv2
        import numpy as np
        nparr = np.frombuffer(image_bytes, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            logger.warning('[ImageOCR] 无法解码图片（可能格式不支持）')
            return ''
        ocr = _get_ocr()
        # Newer PaddleOCR exposes ``predict``; older releases only have ``ocr``.
        # A missing attribute raises AttributeError, not TypeError, so probe for
        # it explicitly instead of relying on the TypeError fallback alone.
        predict = getattr(ocr, 'predict', None)
        if callable(predict):
            try:
                result = predict(img)
            except TypeError:
                result = ocr.ocr(img)
        else:
            result = ocr.ocr(img)
        return _result_to_html(result, min_confidence)
    except Exception as e:
        logger.error(f'[ImageOCR] ocr_image_bytes 失败: {e}')
        return ''

async def ocr_image_url(url: str, headers: dict | None=None, min_confidence: float=0.6) -> str:
    """Download an image and OCR it into an HTML fragment.

    The download uses a 30-second timeout and follows redirects; a non-success
    HTTP status is treated as a failure via ``raise_for_status``. The OCR step
    itself is the synchronous ``ocr_image_bytes`` call, so it runs inside this
    coroutine and blocks the event loop while it executes.

    Args:
        url: Image URL to fetch.
        headers: Optional extra request headers.
        min_confidence: Forwarded to ``ocr_image_bytes``.

    Returns:
        The OCR HTML, or ``''`` if downloading or OCR fails (logged as a warning).
    """
    try:
        async with httpx.AsyncClient(timeout=30, follow_redirects=True) as http:
            response = await http.get(url, headers=headers or {})
            response.raise_for_status()
            payload = response.content
        return ocr_image_bytes(payload, min_confidence)
    except Exception as exc:
        logger.warning(f'[ImageOCR] 下载或 OCR 失败: {url} — {exc}')
        return ''

def _result_to_html(result, min_confidence: float) -> str:
    if not result:
        return ''
    items = _extract_items(result, min_confidence)
    if not items:
        return ''
    row_items = _cluster_row_items(items)
    if _looks_like_table(list(row_items)):
        matrix = _align_rows_to_columns(row_items)
        return _rows_to_html_table(matrix)
    lines = [text for text, *_ in items]
    return '<p>' + '<br>'.join(lines) + '</p>'

def _extract_items(result, min_confidence: float) -> list[tuple]:
    items = []
    page = result[0]
    if isinstance(page, dict):
        texts = page.get('rec_texts', [])
        scores = page.get('rec_scores', [])
        boxes = page.get('dt_polys', page.get('rec_boxes', [None] * len(texts)))
        for text, score, box in zip(texts, scores, boxes, strict=False):
            if not text or not text.strip() or score < min_confidence:
                continue
            x1, y1, x2, y2 = _box_to_xyxy(box)
            items.append((text.strip(), x1, y1, x2, y2))
    elif isinstance(page, list):
        for item in page:
            if not item or len(item) < 2:
                continue
            box, text_info = (item[0], item[1])
            if isinstance(text_info, list | tuple) and len(text_info) >= 2:
                text, score = (text_info[0], text_info[1])
            else:
                text, score = (str(text_info), 1.0)
            if not text or not text.strip() or score < min_confidence:
                continue
            x1, y1, x2, y2 = _box_to_xyxy(box)
            items.append((text.strip(), x1, y1, x2, y2))
    items.sort(key=lambda t: (t[2], t[1]))
    return items

def _box_to_xyxy(box) -> tuple[float, float, float, float]:
    if box is None:
        return (0.0, 0.0, 0.0, 0.0)
    try:
        import numpy as np
        arr = np.array(box, dtype=float).reshape(-1, 2)
        xs = arr[:, 0]
        ys = arr[:, 1]
        return (float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max()))
    except Exception:
        flat = []
        try:
            for v in box:
                try:
                    flat.extend(float(x) for x in v)
                except TypeError:
                    flat.append(float(v))
            if len(flat) >= 4:
                xs = flat[0::2]
                ys = flat[1::2]
                return (min(xs), min(ys), max(xs), max(ys))
        except Exception:
            pass
        return (0.0, 0.0, 0.0, 0.0)

def _cluster_row_items(items: list[tuple], row_tolerance_ratio: float=0.6) -> list[list[tuple]]:
    if not items:
        return []
    rows: list[list[tuple]] = []
    current_row: list[tuple] = [items[0]]
    for item in items[1:]:
        _, x1, y1, x2, y2 = item
        cy = (y1 + y2) / 2
        h = max(y2 - y1, 1)
        tol = h * row_tolerance_ratio
        row_cy = sum((t[2] + t[4]) / 2 for t in current_row) / len(current_row)
        if abs(cy - row_cy) <= tol:
            current_row.append(item)
        else:
            current_row.sort(key=lambda t: t[1])
            rows.append(current_row)
            current_row = [item]
    if current_row:
        current_row.sort(key=lambda t: t[1])
        rows.append(current_row)
    return rows

def _looks_like_table(rows: list[list], min_rows: int=3) -> bool:
    if len(rows) < min_rows:
        return False
    multi = sum(1 for r in rows if len(r) >= 2)
    return multi >= max(1, len(rows) * 0.3)

def _align_rows_to_columns(row_items: list[list[tuple]]) -> list[list[str]]:
    if not row_items:
        return []
    from collections import Counter
    multi_rows = [row for row in row_items if len(row) >= 2]
    if not multi_rows:
        return [[row[0][0]] if row else [''] for row in row_items]
    length_counts = Counter(len(r) for r in multi_rows)
    mode_cols = length_counts.most_common(1)[0][0]
    max_cols = max(length_counts.keys())
    n_cols = max_cols if max_cols - mode_cols <= 1 and length_counts[max_cols] >= 2 else mode_cols
    template_rows = [row for row in multi_rows if len(row) == n_cols]
    col_xcenters: list[list[float]] = [[] for _ in range(n_cols)]
    for row in template_rows:
        for k, item in enumerate(row):
            _, x1, y1, x2, y2 = item
            col_xcenters[k].append((x1 + x2) / 2)

    def _median(vals: list[float]) -> float:
        s = sorted(vals)
        n = len(s)
        if not n:
            return 0.0
        return s[n // 2] if n % 2 else (s[n // 2 - 1] + s[n // 2]) / 2.0
    col_centers = [_median(xcs) for xcs in col_xcenters]

    def nearest_col(xc: float) -> int:
        return min(range(n_cols), key=lambda i: abs(col_centers[i] - xc))
    matrix: list[list[str]] = []
    for row in row_items:
        cells: list[str] = [''] * n_cols
        for item in row:
            _, x1, y1, x2, y2 = item
            xc = (x1 + x2) / 2
            col_idx = nearest_col(xc)
            if cells[col_idx]:
                cells[col_idx] += ' ' + item[0]
            else:
                cells[col_idx] = item[0]
        matrix.append(cells)
    non_empty_cols = [c for c in range(n_cols) if any(row[c] for row in matrix)]
    if not non_empty_cols:
        return [[t[0] for t in row] for row in row_items]
    return [[row[c] for c in non_empty_cols] for row in matrix]

def _rows_to_html_table(rows: list[list[str]]) -> str:
    if not rows:
        return ''
    html_rows = []
    for i, row in enumerate(rows):
        tag = 'th' if i == 0 else 'td'
        cells = ''.join(f'<{tag}>{c}</{tag}>' for c in row)
        html_rows.append(f'<tr>{cells}</tr>')
    return "<table border='1' style='border-collapse:collapse;width:100%'><tbody>" + ''.join(html_rows) + '</tbody></table>'
