import json
from pathlib import Path

from common_logging import get_logger

# Module-level logger obtained from the project's common_logging helper.
logger = get_logger(__name__)
# Import progress is kept in a hidden file named '.import_state.json' that
# lives in the same directory as this module.
STATE_FILE = Path(__file__).parent / '.import_state.json'

class ImportState:

    def __init__(self):
        self._data = self._load()

    def _load(self) -> dict:
        if STATE_FILE.exists():
            try:
                return json.loads(STATE_FILE.read_text())
            except Exception as e:
                logger.warning("Failed to load import state file, starting fresh", error=str(e))
        return {}

    def _save(self):
        STATE_FILE.write_text(json.dumps(self._data, indent=2, ensure_ascii=False))

    def reset(self):
        self._data = {}
        if STATE_FILE.exists():
            STATE_FILE.unlink()

    def get_kb_id(self) -> int | None:
        return self._data.get('kb_id')

    def set_kb_id(self, kb_id: int):
        self._data['kb_id'] = kb_id
        self._save()

    def get_category_map(self) -> dict[int, int]:
        raw = self._data.get('category_map', {})
        return {int(k): v for k, v in raw.items()}

    def set_category_map(self, mapping: dict[int, int]):
        self._data['category_map'] = {str(k): v for k, v in mapping.items()}
        self._save()

    def is_phase_done(self, phase: str) -> bool:
        return self._data.get(f'phase_{phase}_done', False)

    def mark_phase_done(self, phase: str):
        self._data[f'phase_{phase}_done'] = True
        self._save()

    def get_phase2_cursor(self) -> int:
        return self._data.get('phase2_cursor', 0)

    def set_phase2_cursor(self, doc_id: int):
        self._data['phase2_cursor'] = doc_id
        self._save()

    def get_phase3_cursor(self) -> int:
        return self._data.get('phase3_cursor', 0)

    def set_phase3_cursor(self, knowledge_doc_id: int):
        self._data['phase3_cursor'] = knowledge_doc_id
        self._save()

    def get_phase4_cursor(self) -> int:
        return self._data.get('phase4_cursor', 0)

    def set_phase4_cursor(self, knowledge_doc_id: int):
        self._data['phase4_cursor'] = knowledge_doc_id
        self._save()

    def set_source_category(self, source_id: int, kb_cat_id: int):
        if 'source_category_map' not in self._data:
            self._data['source_category_map'] = {}
        self._data['source_category_map'][str(source_id)] = kb_cat_id
        self._save()

    def get_source_category_map(self) -> dict:
        return {int(k): v for k, v in self._data.get('source_category_map', {}).items()}
