# Copyright (C) 2019-2023 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import logging
import threading
import time
import uuid
from pathlib import Path
from threading import Lock, Thread
from typing import Callable, Optional

from pymilvus.orm.schema import CollectionSchema

from .bulk_writer import BulkWriter
from .constants import (
    MB,
    BulkFileType,
)

# Module-level logger used by everything in this file. Note that the level of the
# "local_bulk_writer" logger is forced to DEBUG here, as a side effect of importing
# this module.
logger = logging.getLogger("local_bulk_writer")
logger.setLevel(logging.DEBUG)


class LocalBulkWriter(BulkWriter):
    """Bulk writer that persists buffered rows below a local directory.

    Every instance works inside its own ``<local_path>/<uuid>`` sub-directory, which
    is created on construction. When the buffer grows beyond the chunk size (or when
    ``commit()`` is called) the buffer is handed to a flush thread, which persists it
    using a numbered target path below that directory and records the returned file
    list in ``batch_files``.

    The writer can be used as a context manager: leaving the context waits for the
    running flush thread(s) and removes the uuid directory if it is still empty.
    """

    def __init__(
        self,
        schema: CollectionSchema,
        local_path: str,
        chunk_size: int = 128 * MB,
        file_type: BulkFileType = BulkFileType.PARQUET,
        **kwargs,
    ):
        """Create the writer and its working directory.

        Args:
            schema: Collection schema, forwarded to ``BulkWriter``.
            local_path: Root directory; a ``<uuid>`` sub-directory is created in it.
            chunk_size: Forwarded to ``BulkWriter``; ``append_row()`` triggers an
                asynchronous commit once the buffer size exceeds the chunk size.
            file_type: Output file type, forwarded to ``BulkWriter``.
            **kwargs: Forwarded to ``BulkWriter`` unchanged.
        """
        super().__init__(schema, chunk_size, file_type, **kwargs)
        self._local_path = local_path
        self._uuid = str(uuid.uuid4())
        self._flush_count = 0
        # Flush threads that are currently running, keyed by thread name.
        self._working_thread = {}
        self._working_thread_lock = Lock()
        # One entry per flush that persisted data: the value returned by persist().
        self._local_files = []

        self._make_dir()

    @property
    def uuid(self) -> str:
        """Random identifier of this writer; also the name of its sub-directory."""
        return self._uuid

    def __enter__(self):
        return self

    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object):
        self._exit()

    def __del__(self):
        # __init__ may have raised before the attributes used by _exit() were
        # assigned; in that case there is nothing to wait for or to clean up.
        if hasattr(self, "_working_thread"):
            self._exit()

    def _exit(self) -> None:
        """Wait for running flush threads, then remove the data directory if empty."""
        # Iterate over a snapshot: every flush thread removes its own entry from
        # self._working_thread when it finishes, and a dict that changes size while
        # being iterated raises RuntimeError.
        for name, th in list(self._working_thread.items()):
            logger.info(f"Wait flush thread '{name}' to finish")
            th.join()

        self._rm_dir()

    def _make_dir(self) -> None:
        """Create ``<local_path>/<uuid>`` and make it the writer's data path."""
        # parents=True so that a nested local_path whose parents do not exist yet
        # is accepted as well.
        Path(self._local_path).mkdir(parents=True, exist_ok=True)
        logger.info(f"Data path created: {self._local_path}")

        uidir = Path(self._local_path).joinpath(self._uuid)
        # From here on self._local_path is the per-writer directory (a Path).
        self._local_path = uidir
        uidir.mkdir(exist_ok=True)
        logger.info(f"Data path created: {uidir}")

    def _rm_dir(self) -> None:
        """Remove the uuid directory, but only if it exists and is empty."""
        data_dir = Path(self._local_path)
        if data_dir.exists() and not any(data_dir.iterdir()):
            data_dir.rmdir()
            logger.info(f"Delete local directory '{self._local_path}'")

    def append_row(self, row: dict, **kwargs):
        """Append one row and start an asynchronous flush when the buffer is full.

        Args:
            row: The row to append; forwarded to ``BulkWriter.append_row``.
            **kwargs: Forwarded to ``BulkWriter.append_row``.
        """
        super().append_row(row, **kwargs)

        # only one thread can enter this section to persist data,
        # in the _flush() method, the buffer will be swapped to a new one.
        # in async mode, the flush thread runs asynchronously, other threads can
        # continue to append if the new buffer size is less than target size
        with self._working_thread_lock:
            if super().buffer_size > super().chunk_size:
                self.commit(_async=True)

    def commit(self, **kwargs):
        """Flush the current buffer in a separate thread.

        Blocks (polling once per second) until no previous flush thread is running.

        Keyword Args:
            _async: If true, return without waiting for the flush thread to finish.
                Defaults to False, i.e. the call waits for the flush.
            call_back: Optional callable; the flush thread calls it with the file
                list returned by the buffer's ``persist()``.
        """
        # _async=True, the flush thread is asynchronously
        while self._working_thread:
            logger.info(
                f"Previous flush action is not finished, {threading.current_thread().name} is waiting..."
            )
            time.sleep(1.0)

        logger.info(
            f"Prepare to flush buffer, row_count: {super().buffer_row_count}, size: {super().buffer_size}"
        )
        _async = kwargs.get("_async", False)
        call_back = kwargs.get("call_back")

        x = Thread(target=self._flush, args=(call_back,))
        logger.info(f"Flush thread begin, name: {x.name}")
        # Register before start() so that the entry exists when the thread removes it.
        self._working_thread[x.name] = x
        x.start()
        if not _async:
            logger.info("Wait flush to finish")
            x.join()

        # NOTE(review): in async mode this may run while the flush thread is still
        # reading buffer_size / buffer_row_count -- confirm against BulkWriter.
        super().commit()  # reset the buffer size
        logger.info(f"Commit done with async={_async}")

    def _flush(self, call_back: Optional[Callable] = None) -> None:
        """Thread target: swap the buffer and persist the old one.

        Args:
            call_back: Optional callable invoked with the file list returned by
                ``persist()``; only called when the old buffer contained rows.

        Raises:
            Exception: Whatever the flush raised is logged and re-raised. Since this
                runs as a thread target, it surfaces in that thread, not in the
                caller of ``commit()``.
        """
        try:
            self._flush_count = self._flush_count + 1
            target_path = Path(self._local_path) / str(self._flush_count)

            # Swap in a new buffer; the returned one holds the rows to persist.
            old_buffer = super()._new_buffer()
            if old_buffer.row_count > 0:
                file_list = old_buffer.persist(
                    local_path=str(target_path),
                    buffer_size=self.buffer_size,
                    buffer_row_count=self.buffer_row_count,
                )
                self._local_files.append(file_list)
                if call_back:
                    call_back(file_list)
        except Exception as e:
            logger.error(f"Failed to fulsh, error: {e}")
            # Bare raise keeps the original traceback; "raise e from e" would make
            # the exception its own __cause__.
            raise
        finally:
            # pop() instead of del: a missing entry must not raise KeyError here
            # and hide the real flush error.
            self._working_thread.pop(threading.current_thread().name, None)
            logger.info(f"Flush thread finished, name: {threading.current_thread().name}")

    @property
    def data_path(self):
        """The ``<local_path>/<uuid>`` directory this writer works in."""
        return self._local_path

    @property
    def batch_files(self) -> list:
        """File lists returned by ``persist()``, one entry per flush that wrote data."""
        return self._local_files
