# Copyright 2021 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from casbin import util, config
from . import Assertion
from .policy import Policy

DEFAULT_DOMAIN = ""
DEFAULT_SEPARATOR = "::"
PARAMS_REGEX = re.compile(r"\((.*?)\)")


class Model(Policy):
    section_name_map = {
        "r": "request_definition",
        "p": "policy_definition",
        "g": "role_definition",
        "e": "policy_effect",
        "m": "matchers",
    }

    def _load_assertion(self, cfg, sec, key):
        value = cfg.get(self.section_name_map[sec] + "::" + key)

        return self.add_def(sec, key, value)

    def get_params_token(self, value):
        """get_params_token Get params_token from Assertion.value"""
        # Find the matching string using the regular expression
        params_string = PARAMS_REGEX.search(value)

        if params_string is None:
            return []

        # Extract the captured group (inside parentheses) and split it by commas
        params_string = params_string.group(1)
        return [param.strip() for param in params_string.split(",")]

    def add_def(self, sec, key, value):
        if value == "":
            return

        ast = Assertion()
        ast.key = key
        ast.value = value

        if "r" == sec or "p" == sec:
            ast.tokens = ast.value.split(",")
            for i, token in enumerate(ast.tokens):
                ast.tokens[i] = key + "_" + token.strip()
        elif "g" == sec:
            ast.params_tokens = self.get_params_token(ast.value)
            ast.tokens = ast.value.split(",")
            ast.tokens = ast.tokens[: len(ast.tokens) - len(ast.params_tokens)]
        else:
            ast.value = util.remove_comments(util.escape_assertion(ast.value))

        if sec not in self.keys():
            self[sec] = {}

        self[sec][key] = ast

        return True

    def _get_key_suffix(self, i):
        if i == 1:
            return ""

        return str(i)

    def _load_section(self, cfg, sec):
        i = 1
        while True:
            if not self._load_assertion(cfg, sec, sec + self._get_key_suffix(i)):
                break
            else:
                i = i + 1

    def load_model(self, path):
        cfg = config.Config.new_config(path)

        self._load_section(cfg, "r")
        self._load_section(cfg, "p")
        self._load_section(cfg, "e")
        self._load_section(cfg, "m")

        self._load_section(cfg, "g")

    def load_model_from_text(self, text):
        cfg = config.Config.new_config_from_text(text)

        self._load_section(cfg, "r")
        self._load_section(cfg, "p")
        self._load_section(cfg, "e")
        self._load_section(cfg, "m")

        self._load_section(cfg, "g")

    def print_model(self):
        self.logger.info("Model:")
        for k, v in self.items():
            for i, j in v.items():
                self.logger.info("%s.%s: %s", k, i, j.value)

    def sort_policies_by_priority(self):
        for ptype, assertion in self["p"].items():
            for index, token in enumerate(assertion.tokens):
                if token == f"{ptype}_priority":
                    assertion.priority_index = index
                    break

            if assertion.priority_index == -1:
                continue

            assertion.policy = sorted(
                assertion.policy,
                key=lambda x: int(x[assertion.priority_index])
                if x[assertion.priority_index].isdigit()
                else x[assertion.priority_index],
            )

            for i, policy in enumerate(assertion.policy):
                assertion.policy_map[",".join(policy)] = i

        return None

    def sort_policies_by_subject_hierarchy(self):
        if self["e"]["e"].value != "subjectPriority(p_eft) || deny":
            return

        sub_index = 0
        domain_index = -1
        for ptype, assertion in self["p"].items():
            for index, token in enumerate(assertion.tokens):
                if token == "{}_dom".format(ptype):
                    domain_index = index
                    break

            subject_hierarchy_map = self.get_subject_hierarchy_map(self["g"]["g"].policy)

            def compare_policy(policy):
                domain = DEFAULT_DOMAIN
                if domain_index != -1:
                    domain = policy[domain_index]
                name = self.get_name_with_domain(domain, policy[sub_index])
                return subject_hierarchy_map.get(name, 0)

            assertion.policy = sorted(assertion.policy, key=compare_policy)
            for i, policy in enumerate(assertion.policy):
                assertion.policy_map[",".join(policy)] = i

    def get_subject_hierarchy_map(self, policies):
        """
        Get the subject hierarchy from the policy.
        Select the lowest level subject in multiple rounds until all subjects are selected.
        Return the subject hierarchy dictionary, the subject is the key, and the level is the value.
        The level starts from 0 and increases in turn. The smaller the level, the higher the priority.
        """
        # Init unsorted policy, and subject
        unsorted_policy = []
        unsorted_sub = set()
        for policy in policies:
            if len(policy) < 2:
                raise RuntimeError("policy g expect 2 more params")
            domain = DEFAULT_DOMAIN
            if len(policy) != 2:
                domain = policy[2]
            child = self.get_name_with_domain(domain, policy[0])
            parent = self.get_name_with_domain(domain, policy[1])
            unsorted_policy.append([child, parent])
            unsorted_sub.add(child)
            unsorted_sub.add(parent)
        # sort policy,and update sorted_sub_list
        sorted_sub_list = []
        while len(unsorted_policy) > 0:
            # get all parent subject
            parent_sub = {p[1] for p in unsorted_policy if p[1] != ""}
            # remove parent subject from unsorted_sub
            sorted_sub = unsorted_sub - parent_sub
            if not sorted_sub:
                raise RuntimeError("cycle dependency in subject hierarchy.subjects: {}".format(unsorted_sub))
            # update sorted_sub_list
            sorted_sub_list.append(sorted_sub)
            # remove sorted subject, and update unsorted_policy
            unsorted_policy = [p for p in unsorted_policy if p[0] not in sorted_sub]
            # update unsorted_sub
            unsorted_sub = unsorted_sub - sorted_sub
        if len(unsorted_sub) > 0:
            sorted_sub_list.append(unsorted_sub)
        # Tree structure of subject
        return {sub: i for i, subs in enumerate(sorted_sub_list) for sub in subs}

    def get_name_with_domain(self, domain, name):
        return "{}{}{}".format(domain, DEFAULT_SEPARATOR, name)

    def to_text(self):
        s = []

        def write_string(sec):
            for p_type in self[sec]:
                value = self[sec][p_type].value
                s.append("{} = {}\n".format(sec, value.replace("p_", "p.").replace("r_", "r.")))

        s.append("[request_definition]\n")
        write_string("r")
        s.append("[policy_definition]\n")
        write_string("p")
        if "g" in self.keys():
            s.append("[role_definition]\n")
            for p_type in self["g"]:
                s.append("{} = {}\n".format(p_type, self["g"][p_type].value))
        s.append("[policy_effect]\n")
        write_string("e")
        s.append("[matchers]\n")
        write_string("m")

        # remove last \n
        s[-1] = s[-1].strip()

        return "".join(s)

    def get_field_index(self, ptype, field):
        """get_field_index gets the index of the field for a ptype in a policy,
        return -1 if the field does not exist."""
        assertion = self["p"][ptype]
        if field in assertion.field_index_map:
            return assertion.field_index_map[field]

        pattern = f"{ptype}_{field}"
        index = -1
        for i, token in enumerate(assertion.tokens):
            if token == pattern:
                index = i
                break

        if index == -1:
            return index

        assertion.field_index_map[field] = index
        return index
