import sqlite3
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path

from common_logging import get_logger

logger = get_logger(__name__)

# Directory that holds the checkpoint database when CheckpointManager is
# constructed without an explicit db_dir.
# NOTE(review): /tmp is commonly shared and may be wiped on reboot; confirm it
# is an acceptable home for tax-document crawl progress.
DEFAULT_CHECKPOINT_DIR = Path('/tmp', 'tax_documents', 'index')
# File name of the SQLite database created inside the checkpoint directory.
CHECKPOINT_DB_NAME = 'checkpoint.db'
# Number of buffered completed URLs that triggers a batch write to the database.
CHECKPOINT_INTERVAL = 50

class CheckpointManager:

    def __init__(self, task_id: str, category_id: int, db_dir: Path | None=None):
        self.task_id = task_id
        self.category_id = category_id
        self.db_dir = db_dir or DEFAULT_CHECKPOINT_DIR
        self.db_path = self.db_dir / CHECKPOINT_DB_NAME
        self._lock = threading.Lock()
        self._pending_urls: list[str] = []
        self._completed_urls: set[str] = set()
        self._ensure_db()

    def load_completed_urls(self) -> set[str]:
        with self._connect() as conn:
            rows = conn.execute('SELECT doc_url FROM checkpoints WHERE task_id = ? AND category_id = ?', (self.task_id, self.category_id)).fetchall()
        self._completed_urls = {row[0] for row in rows}
        logger.info(f'已恢复 {len(self._completed_urls)} 条断点记录')
        return self._completed_urls

    def mark_done(self, doc_url: str) -> None:
        self._completed_urls.add(doc_url)
        with self._lock:
            self._pending_urls.append(doc_url)
            if len(self._pending_urls) >= CHECKPOINT_INTERVAL:
                self._flush()

    def flush(self) -> None:
        with self._lock:
            self._flush()

    def is_done(self, doc_url: str) -> bool:
        return doc_url in self._completed_urls

    def save_page_cursor(self, page: int) -> None:
        with self._connect() as conn:
            conn.execute('\n                INSERT INTO page_cursors (task_id, category_id, last_page, updated_at)\n                VALUES (?, ?, ?, ?)\n                ON CONFLICT(task_id, category_id)\n                DO UPDATE SET last_page = excluded.last_page, updated_at = excluded.updated_at\n                ', (self.task_id, self.category_id, page, datetime.utcnow().isoformat()))

    def load_page_cursor(self) -> int:
        with self._connect() as conn:
            row = conn.execute('SELECT last_page FROM page_cursors WHERE task_id = ? AND category_id = ?', (self.task_id, self.category_id)).fetchone()
        page = row[0] if row else 0
        if page > 0:
            logger.info(f'[CheckpointManager] task={self.task_id} cat={self.category_id} 从第 {page + 1} 页继续（已完成第 1~{page} 页）')
        return page

    def clear(self) -> None:
        with self._connect() as conn:
            conn.execute('DELETE FROM checkpoints WHERE task_id = ? AND category_id = ?', (self.task_id, self.category_id))
            conn.execute('DELETE FROM page_cursors WHERE task_id = ? AND category_id = ?', (self.task_id, self.category_id))
        self._completed_urls.clear()
        logger.info(f'[CheckpointManager] task={self.task_id} cat={self.category_id} 检查点已清理')

    def stats(self) -> dict:
        return {'task_id': self.task_id, 'category_id': self.category_id, 'completed_count': len(self._completed_urls), 'pending_flush': len(self._pending_urls)}

    def _ensure_db(self) -> None:
        self.db_dir.mkdir(parents=True, exist_ok=True)
        with self._connect() as conn:
            conn.execute("\n                CREATE TABLE IF NOT EXISTS checkpoints (\n                    id          INTEGER PRIMARY KEY AUTOINCREMENT,\n                    task_id     TEXT    NOT NULL,\n                    category_id INTEGER NOT NULL,\n                    doc_url     TEXT    NOT NULL,\n                    created_at  TEXT    NOT NULL DEFAULT (datetime('now')),\n                    UNIQUE(task_id, category_id, doc_url)\n                )\n                ")
            conn.execute('\n                CREATE INDEX IF NOT EXISTS idx_cp_task_cat\n                ON checkpoints(task_id, category_id)\n                ')
            conn.execute('\n                CREATE TABLE IF NOT EXISTS page_cursors (\n                    task_id     TEXT    NOT NULL,\n                    category_id INTEGER NOT NULL,\n                    last_page   INTEGER NOT NULL DEFAULT 0,\n                    updated_at  TEXT    NOT NULL,\n                    PRIMARY KEY (task_id, category_id)\n                )\n                ')

    def _connect(self) -> sqlite3.Connection:
        conn = sqlite3.connect(str(self.db_path), timeout=10)
        conn.execute('PRAGMA journal_mode=WAL')
        conn.execute('PRAGMA synchronous=NORMAL')
        return conn

    def _flush(self) -> None:
        if not self._pending_urls:
            return
        now = datetime.utcnow().isoformat()
        rows = [(self.task_id, self.category_id, url, now) for url in self._pending_urls]
        try:
            with self._connect() as conn:
                conn.executemany('\n                    INSERT OR IGNORE INTO checkpoints\n                        (task_id, category_id, doc_url, created_at)\n                    VALUES (?, ?, ?, ?)\n                    ', rows)
            logger.debug(f'[CheckpointManager] 写入 {len(self._pending_urls)} 条断点记录')
        except Exception as e:
            logger.error(f'[CheckpointManager] 写入断点失败: {e}')
        finally:
            self._pending_urls.clear()
