# Copyright © 2023-2024 Apple Inc.

import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration values consumed by the modules in this file.

    Fields without defaults must be supplied by the caller; the remaining
    fields fall back to the defaults declared here.
    """

    model_type: str
    # Width of the residual stream; also the embedding dimension.
    hidden_size: int
    # Period used by Attention to decide which layers run block-sparse
    # attention and which run dense attention.
    dense_attention_every_n_layers: int
    # Output width of the MLP activation (the up projection is twice this).
    ff_intermediate_size: int
    # Clipping bound applied to the activation inputs in gegelu_impl.
    gegelu_limit: float
    num_hidden_layers: int
    num_attention_heads: int
    # eps passed to every nn.LayerNorm in the model.
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int
    # When mup_use_scaling is true, attention scores are scaled by
    # mup_attn_multiplier / head_dim instead of 1 / sqrt(head_dim).
    mup_attn_multiplier: float = 1.0
    mup_use_scaling: bool = True
    # Token embeddings are multiplied by this value (skipped when falsy).
    mup_embedding_multiplier: float = 10.0
    # Output logits are divided by this value (skipped when falsy).
    mup_width_multiplier: float = 8.0
    # Forwarded to nn.RoPE as `base` and `scale` respectively.
    rope_embedding_base: float = 1000000
    rope_position_scale: float = 1.0
    # Block-sparse layout: block edge length (Attention accepts 32 or 64),
    # number of local blocks each query block may look back over, and the
    # stride of the per-head vertical (strided) key blocks.
    blocksparse_block_size: int = 64
    blocksparse_num_local_blocks: int = 16
    blocksparse_vert_stride: int = 8


@partial(mx.compile, shapeless=True)
def gegelu_impl(a_gelu, a_linear, limit):
    """Compute the clipped GeGELU product of a gate half and a linear half.

    The gate input is clipped from above at ``limit`` and the linear input is
    clipped to ``[-limit, limit]``; infinite entries bypass the clipping and
    are propagated unchanged. The gate then goes through the sigmoid-based
    GELU approximation ``x * sigmoid(1.702 * x)`` and is multiplied by
    ``a_linear + 1``.
    """

    def clamp_finite(values, lower, upper):
        # Leave +/-inf entries as they are; clip everything else.
        return mx.where(
            mx.isinf(values),
            values,
            mx.clip(values, a_min=lower, a_max=upper),
        )

    gate = clamp_finite(a_gelu, None, limit)
    linear = clamp_finite(a_linear, -limit, limit)

    activated = gate * mx.sigmoid(1.702 * gate)
    return activated * (linear + 1.0)


def gegelu(x, limit):
    """Apply GeGELU to ``x`` whose last axis interleaves gate and linear parts.

    Even indices of the last axis feed the GELU gate and odd indices feed the
    linear branch, so the result has half as many entries on the last axis.
    """
    gate_part = x[..., ::2]
    linear_part = x[..., 1::2]
    return gegelu_impl(gate_part, linear_part, limit)

class Attention(nn.Module):
    """Self-attention with grouped key/value heads and rotary embeddings.

    Depending on ``layer_idx`` the layer either uses a block-sparse pattern
    (a local window of blocks plus per-head strided "vertical" key blocks,
    implemented here with a dense mask) or the shared
    ``scaled_dot_product_attention`` helper.
    """

    def __init__(self, args: ModelArgs, layer_idx):
        """Build the projections, RoPE, and the sparse/dense selection.

        Args:
            args: Model configuration.
            layer_idx: Zero-based index of the layer this module belongs to.

        Raises:
            ValueError: If the layer is block-sparse and the configured block
                size is not 32 or 64.
        """
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.n_q_per_kv = n_heads // n_kv_heads

        self.head_dim = head_dim = args.hidden_size // n_heads

        # Fused projection. __call__ reshapes its output so that every KV
        # group holds n_q_per_kv query heads followed by one key head and one
        # value head.
        self.query_key_value = nn.Linear(
            dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim
        )
        self.dense = nn.Linear(dim, dim)

        # With muP scaling the scores are divided by head_dim (scaled by the
        # multiplier) rather than by sqrt(head_dim).
        if args.mup_use_scaling:
            norm_factor = head_dim / args.mup_attn_multiplier
        else:
            norm_factor = math.sqrt(head_dim)
        self.scale = 1.0 / norm_factor

        self.rope = nn.RoPE(
            head_dim,
            traditional=False,
            base=args.rope_embedding_base,
            scale=args.rope_position_scale,
        )

        # "Dense attention every n layers": the n-th, 2n-th, ... layers
        # (1-based) are dense and all others are block-sparse. For n == 2 this
        # selects the same layers as testing `layer_idx % n == 0` for
        # sparsity, but it stays correct for other periods. A falsy period
        # means no layer is dense (and avoids a modulo by zero).
        every_n = args.dense_attention_every_n_layers
        use_dense = bool(every_n) and (layer_idx + 1) % every_n == 0
        self.block_sparse = not use_dense

        if self.block_sparse:
            self.blocksparse_block_size = args.blocksparse_block_size
            if self.blocksparse_block_size not in (32, 64):
                raise ValueError(
                    f"Unsupported block size {self.blocksparse_block_size}"
                )
            self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks
            self.blocksparse_vert_stride = args.blocksparse_vert_stride

    def _block_sparse_mask(self, q_len, kv_len):
        """Build the block-level and token-level sparsity masks.

        The queries are taken to be the last ``q_len`` positions of the
        ``kv_len`` key positions.

        Args:
            q_len: Number of query positions.
            kv_len: Number of key positions.

        Returns:
            A tuple ``(block_mask, dense_mask)`` of boolean arrays with leading
            axes ``(n_kv_heads, n_q_per_kv)``. ``block_mask`` is indexed by
            (query block, key block) for the blocks that contain the queries;
            ``dense_mask`` is indexed by (query token, key token) and has
            trailing shape ``(q_len, kv_len)``. ``True`` means "may attend".
        """
        vert_stride = self.blocksparse_vert_stride
        local_blocks = self.blocksparse_num_local_blocks
        block_size = self.blocksparse_block_size
        n_heads = self.n_heads

        kv_blocks = (kv_len + block_size - 1) // block_size

        # Locate the queries by their absolute position so that each query
        # row lands in the block it really belongs to, even when kv_len (or
        # the query start) is not a multiple of the block size.
        q_start = kv_len - q_len
        first_q_block = q_start // block_size
        q_pos = mx.arange(first_q_block, kv_blocks)[None, :, None]
        k_pos = mx.arange(kv_blocks)[None, None]

        # Per-head strided key blocks: head h (1-based) keeps key block k
        # whenever (k + h) is a multiple of vert_stride.
        mask_vert_strided = (
            mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None]
        ) % vert_stride
        mask_vert_strided = (mask_vert_strided == 0)[:, None, :]

        # Block-causal, and either inside the local window or strided.
        block_mask = (q_pos >= k_pos) & (
            (q_pos - k_pos < local_blocks) | mask_vert_strided
        )
        block_mask = block_mask.reshape(
            self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:]
        )

        # Expand every block to block_size x block_size tokens, then cut out
        # exactly the rows of the query tokens and the columns of real keys.
        dense_mask = mx.repeat(
            mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2
        )
        row_start = q_start - first_q_block * block_size
        return block_mask, dense_mask[..., row_start : row_start + q_len, :kv_len]

    def _block_sparse_attention(self, queries, keys, values, scale, mask):
        """Attention restricted by the block-sparse pattern.

        Args:
            queries: Array indexed as ``(B, n_heads, L, head_dim)``.
            keys: Array indexed as ``(B, n_kv_heads, S, head_dim)``.
            values: Array with the same layout as ``keys``.
            scale: Factor applied to the queries before the dot product.
            mask: Optional extra mask. A boolean mask marks allowed positions
                with ``True``; any other dtype is added to the scores.

        Returns:
            Array reshaped to ``(B, n_heads, L, -1)``.
        """
        queries = scale * queries
        B = queries.shape[0]
        L = queries.shape[2]
        # Group query heads by the KV head they share so one matmul
        # broadcasts each key/value head over its query heads.
        queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1))
        keys = mx.expand_dims(keys, 2)
        values = mx.expand_dims(values, 2)

        # TODO get rid of dense mask if we have a fill value
        block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2])
        scores = queries @ mx.swapaxes(keys, -1, -2)
        # TODO, uncomment when faster
        # scores = mx.block_masked_mm(
        #   queries,
        #   mx.swapaxes(keys, -1, -2),
        #   mask_out=block_mask,
        #   block_size=self.blocksparse_block_size,
        # )

        neg_inf = mx.array(-float("inf"), scores.dtype)
        if mask is not None:
            if mask.dtype == mx.bool_:
                # Adding a boolean mask would only shift scores by 0/1 and
                # mask nothing, so select instead.
                scores = mx.where(mask, scores, neg_inf)
            else:
                scores = scores + mask
        scores = scores + mx.where(dense_mask, mx.array(0, scores.dtype), neg_inf)
        scores = mx.softmax(scores, axis=-1, precise=True)

        output = scores @ values
        # TODO, uncomment when faster
        # output = mx.block_masked_mm(
        #    scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size
        # )
        return mx.reshape(output, (B, self.n_heads, L, -1))

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run self-attention over ``x``.

        Args:
            x: Input array of shape ``(B, L, D)``.
            mask: Optional attention mask forwarded to the attention kernel.
            cache: Optional key/value cache. When given, its ``offset`` is
                used as the RoPE position offset and ``update_and_fetch``
                supplies the keys and values to attend over.

        Returns:
            The output of the ``dense`` projection, one row per input position.
        """
        B, L, D = x.shape

        # Split the fused projection: per KV group the last two slots are the
        # key and value heads, everything before them are query heads.
        qkv = self.query_key_value(x)
        qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim)
        queries = qkv[..., :-2, :].flatten(-3, -2)
        keys = qkv[..., -2, :]
        values = qkv[..., -1, :]

        # Prepare the queries, keys and values for the attention computation
        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values = values.transpose(0, 2, 1, 3)

        if cache is not None:
            # Rotate with the pre-update offset, then append to the cache.
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        if self.block_sparse:
            output = self._block_sparse_attention(
                queries, keys, values, scale=self.scale, mask=mask
            )
        else:
            output = scaled_dot_product_attention(
                queries, keys, values, cache=cache, scale=self.scale, mask=mask
            )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.dense(output)


class MLP(nn.Module):
    """Feed-forward block: up projection, GeGELU, down projection."""

    def __init__(self, args):
        super().__init__()
        model_width = args.hidden_size
        inner_width = args.ff_intermediate_size
        self.gegelu_limit = args.gegelu_limit
        # The up projection is twice as wide because gegelu consumes
        # interleaved (gate, linear) pairs and halves the last axis.
        self.up_proj = nn.Linear(model_width, 2 * inner_width)
        self.down_proj = nn.Linear(inner_width, model_width)

    def __call__(self, x) -> mx.array:
        """Return the feed-forward transformation of ``x``."""
        projected = self.up_proj(x)
        activated = gegelu(projected, self.gegelu_limit)
        return self.down_proj(activated)


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual."""

    def __init__(self, args: ModelArgs, layer_idx):
        super().__init__()
        width = args.hidden_size
        eps = args.layer_norm_epsilon

        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = width
        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args)
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Apply the attention and feed-forward sub-layers to ``x``."""
        # Attention sub-layer on the normalized input, added back to x.
        attn_out = self.self_attn(self.input_layernorm(x), mask, cache)
        hidden = x + attn_out
        # Feed-forward sub-layer on the normalized hidden state.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        return hidden + mlp_out


class Phi3Model(nn.Module):
    """Token embedding, the stack of transformer blocks, and the final norm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.mup_embedding_multiplier = args.mup_embedding_multiplier

        width = args.hidden_size
        self.embed_tokens = nn.Embedding(args.vocab_size, width)
        self.layers = [
            TransformerBlock(args=args, layer_idx=idx)
            for idx in range(args.num_hidden_layers)
        ]
        self.final_layernorm = nn.LayerNorm(width, eps=args.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs``, run every layer, and return the normalized states.

        ``cache`` is an optional sequence with one entry per layer; when it is
        omitted every layer runs without a cache.
        """
        h = self.embed_tokens(inputs)
        multiplier = self.mup_embedding_multiplier
        if multiplier:
            h = multiplier * h

        layer_caches = [None] * len(self.layers) if cache is None else cache

        # A single mask, built from the first layer's cache, is shared by
        # every layer.
        mask = create_attention_mask(h, layer_caches[0], return_array=True)

        for block, block_cache in zip(self.layers, layer_caches):
            h = block(h, mask, block_cache)

        return self.final_layernorm(h)


class Model(nn.Module):
    """Language-model wrapper: backbone plus tied-embedding output head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = Phi3Model(args)
        self.args = args
        self.mup_width_multiplier = args.mup_width_multiplier
        # Vocabulary ids whose logits are always forced to -inf in __call__.
        # NOTE(review): presumably placeholder tokenizer entries, going by the
        # attribute name; the list is hard-coded, so confirm it matches the
        # tokenizer in use.
        self._dummy_tokenizer_ids = mx.array(
            [100256, 100258, 100259, 100260, 100264, 100265]
            + list(range(100267, 100352))
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return next-token logits for ``inputs``.

        Args:
            inputs: Token ids passed to the backbone.
            cache: Optional per-layer cache list forwarded to the backbone.

        Returns:
            Logits whose last axis indexes the vocabulary, with the entries in
            ``_dummy_tokenizer_ids`` set to ``-inf``.
        """
        out = self.model(inputs, cache)
        # Output head shares its weights with the input embedding.
        out = self.model.embed_tokens.as_linear(out)
        if self.mup_width_multiplier:
            out = out / self.mup_width_multiplier
        # Mask along the vocabulary (last) axis. Indexing without the
        # ellipsis would address the leading batch axis with vocabulary ids.
        out[..., self._dummy_tokenizer_ids] = -float("inf")
        return out

    @property
    def layers(self):
        """The backbone's list of transformer blocks."""
        return self.model.layers

    def sanitize(self, weights):
        """Return ``weights`` without the precomputed rotary frequency entries."""
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }
