from collections.abc import Sequence
from typing import Annotated

from numpy.typing import NDArray

import mlx.core


from mlx.core import array, Dtype, Device, Stream, scalar
from typing import Sequence, Optional, Union

def rms_norm(x: array, weight: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array:
    """
    Root Mean Square normalization (RMS norm).

    The normalization is with respect to the last axis of the input ``x``.

    Args:
        x (array): Input array.
        weight (array, optional): A multiplicative weight to scale the result by.
          The ``weight`` should be one-dimensional with the same size
          as the last axis of ``x``. If set to ``None`` then no scaling happens.
        eps (float): A small additive constant for numerical stability.

    Returns:
        array: The output array.
    """

def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array:
    """
    Layer normalization.

    The normalization is with respect to the last axis of the input ``x``.

    Args:
        x (array): Input array.
        weight (array, optional): A multiplicative weight to scale the result by.
          The ``weight`` should be one-dimensional with the same size
          as the last axis of ``x``. If set to ``None`` then no scaling happens.
        bias (array, optional): An additive offset to be added to the result.
          The ``bias`` should be one-dimensional with the same size
          as the last axis of ``x``. If set to ``None`` then no translation happens.
        eps (float): A small additive constant for numerical stability.

    Returns:
        array: The output array.
    """

def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: Union[int, array], freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array:
    """
    Apply rotary positional encoding to the input.

    The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:
      * ``B`` is the batch size.
      * ``T`` is the sequence length.
      * ``D`` is the feature dimension.

    Args:
        a (array): The input array.
        dims (int): The feature dimensions to be rotated. If the input feature
          is larger than dims then the rest is left unchanged.
        traditional (bool): If set to ``True`` choose the traditional
          implementation which rotates consecutive dimensions.
        base (float, optional): The base used to compute angular frequency for
          each dimension in the positional encodings. Exactly one of ``base`` and
          ``freqs`` must be ``None``.
        scale (float): The scale used to scale the positions.
        offset (int or array): The position offset to start at. If an
          :obj:`array` is given it can be a scalar or vector of ``B``
          offsets for each example in the batch.
        freqs (array, optional): Optional frequencies to use with RoPE.
          If set, the ``base`` parameter must be ``None``. Default: ``None``.

    Returns:
        array: The output array.
    """

def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float,  mask: Union[None, str, array] = None, sinks: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array:
    """
    A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.

    Supports:

    * `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
    * `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
    * `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_

    .. note::

      * The softmax operation is performed in ``float32`` regardless of
        the input precision.
      * For Grouped Query Attention and Multi-Query Attention, the ``k``
        and ``v`` inputs should not be pre-tiled to match ``q``.

    In the following the dimensions are given by:

    * ``B``: The batch size.
    * ``N_q``: The number of query heads.
    * ``N_kv``: The number of key and value heads.
    * ``T_q``: The number of queries per example.
    * ``T_kv``: The number of keys and values per example.
    * ``D``: The per-head dimension.

    Args:
        q (array): Queries with shape ``[B, N_q, T_q, D]``.
        k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
        v (array): Values with shape ``[B, N_kv, T_kv, D]``.
        scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``).
        mask (str or array, optional): The mask to apply to the
           query-key scores. The mask can be an array or a string indicating
           the mask type. The only supported string type is ``"causal"``. If
           the mask is an array it can be a boolean or additive mask. The mask
           can have at most 4 dimensions and must be broadcast-compatible with
           the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its
           type must promote to the promoted type of ``q``, ``k``, and ``v``.
           The ``"causal"`` mask uses lower-right alignment where the
           last query aligns with the last key.
        sinks (array, optional): An optional array of attention sinks.
           Default: ``None``.

    Returns:
        array: The output array.

    Example:

      .. code-block:: python

        B = 2
        N_q = N_kv = 32
        T_q = T_kv = 1000
        D = 128

        q = mx.random.normal(shape=(B, N_q, T_q, D))
        k = mx.random.normal(shape=(B, N_kv, T_kv, D))
        v = mx.random.normal(shape=(B, N_kv, T_kv, D))
        scale = D ** -0.5
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")
    """

def metal_kernel(name: str, input_names: Sequence[str], output_names: Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) -> object:
    r"""
    A jit-compiled custom Metal kernel defined from a source string.

    Full documentation: :ref:`custom_metal_kernels`.

    Args:
      name (str): Name for the kernel.
      input_names (List[str]): The parameter names of the inputs in the
         function signature.
      output_names (List[str]): The parameter names of the outputs in the
         function signature.
      source (str): Source code. This is the body of a function in Metal,
         the function signature will be automatically generated.
      header (str): Header source code to include before the main function.
         Useful for helper functions or includes that should live outside of
         the main function body.
      ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
         before the kernel runs. Default: ``True``.
      atomic_outputs (bool): Whether to use atomic outputs in the function signature
         e.g. ``device atomic<float>``. Default: ``False``.

    Returns:
      Callable ``metal_kernel``.

    Example:

      .. code-block:: python

        def exp_elementwise(a: mx.array):
            source = \'\''
                uint elem = thread_position_in_grid.x;
                T tmp = inp[elem];
                out[elem] = metal::exp(tmp);
            \'\''

            kernel = mx.fast.metal_kernel(
                name="myexp",
                input_names=["inp"],
                output_names=["out"],
                source=source
            )
            outputs = kernel(
                inputs=[a],
                template=[("T", mx.float32)],
                grid=(a.size, 1, 1),
                threadgroup=(256, 1, 1),
                output_shapes=[a.shape],
                output_dtypes=[a.dtype],
                verbose=True,
            )
            return outputs[0]

        a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
        b = exp_elementwise(a)
        assert mx.allclose(b, mx.exp(a))
    """

def cuda_kernel(name: str, input_names: Sequence[str], output_names: Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, shared_memory: int = 0) -> object:
    r"""
    A jit-compiled custom CUDA kernel defined from a source string.

    This is the CUDA equivalent of :ref:`custom_metal_kernels`.

    Args:
      name (str): Name for the kernel.
      input_names (List[str]): The parameter names of the inputs in the
         function signature.
      output_names (List[str]): The parameter names of the outputs in the
         function signature.
      source (str): Source code. This is the body of a function in CUDA,
         the function signature will be automatically generated.
      header (str): Header source code to include before the main function.
         Useful for helper functions or includes that should live outside of
         the main function body.
      ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
         before the kernel runs. Default: ``True``.
      shared_memory (int): The dynamic shared memory to request for the
        kernel. A value of 0 means no dynamic shared memory. Default: ``0``.

    Returns:
      Callable ``cuda_kernel``.

    Example:

      .. code-block:: python

        def exp_elementwise(a: mx.array):
            source = \'\''
                auto elem = cooperative_groups::this_grid().thread_rank();
                T tmp = inp[elem];
                out[elem] = exp(tmp);
            \'\''

            kernel = mx.fast.cuda_kernel(
                name="myexp",
                input_names=["inp"],
                output_names=["out"],
                source=source
            )
            outputs = kernel(
                inputs=[a],
                template=[("T", mx.float32)],
                grid=(a.size, 1, 1),
                threadgroup=(256, 1, 1),
                output_shapes=[a.shape],
                output_dtypes=[a.dtype],
                verbose=True,
            )
            return outputs[0]

        a = mx.random.normal(shape=(16, 16)).astype(mx.float16)
        b = exp_elementwise(a)
        assert mx.allclose(b, mx.exp(a))
    """

def precompiled_cuda_kernel(*, name: str, compiled_source: bytes, inputs: Sequence[bool | int | float | mlx.core.array | Annotated[NDArray, dict(order='C', device='cpu', writable=False)] | complex | mlx.core.ArrayLike], output_shapes: Sequence[tuple[int, ...]], output_dtypes: Sequence[mlx.core.Dtype], scalars: Sequence[object], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], shared_memory: int = 0, init_value: float | None = None, ensure_row_contiguous: bool = False, stream: mlx.core.Stream | mlx.core.Device | None = None) -> list[mlx.core.array]:
    """
    Run a precompiled CUDA kernel defined from PTX or cubin.

    This op is still experimental and various parts of the API may change.

    Args:
      name (str): Name for the kernel
      compiled_source (bytes): The precompiled kernel in raw bytes.
      inputs (List[array]): The inputs passed to the CUDA kernel.
      output_shapes (List[Sequence[int]]): The list of shapes for each output.
      output_dtypes (List[Dtype]): The list of data types for each output.
      scalars (List[Union[bool, int, float]]): A list of scalar arguments to
        pass to the kernel.
      grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
        For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
      threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
      shared_memory (int): The dynamic shared memory to request for the
        kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
      init_value (float, optional): Optional value to use to initialize all of the output arrays.
          By default, output arrays are uninitialized. Default: ``None``.
      ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
         before the kernel runs. Default: ``False``.
      stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
    """
