import torch

from ..generation.continuous_batching import PagedAttentionCache
from ..modeling_flash_attention_utils import lazy_import_paged_flash_attention


@torch.compiler.disable
def paged_attention_forward(
    module: torch.nn.Module,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: torch.Tensor | None,  # Unused in flash
    cache: PagedAttentionCache,
    cu_seq_lens_q: torch.Tensor,
    cu_seq_lens_k: torch.Tensor | dict[str, torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int | dict[str, int],
    block_table: torch.Tensor | None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """Run attention for one layer against a paged key-value cache, updating the cache and computing the output.

    Two code paths are used:
    - `block_table is None`: the new keys/values are written with `cache.update` (driven by the `read_index` /
      `write_index` kwargs), and the k/v it returns are fed to `flash_attn_varlen_func`.
    - `block_table` provided (decode-only batches): `flash_attn_with_kvcache` writes the new keys/values into the
      paged cache in place and computes attention in the same call.
    See the [paged attention guide](https://huggingface.co/docs/transformers/en/paged_attention) for more details.

    Args:
        module: the attention module. This function reads `config._attn_implementation`, `layer_idx`, `scaling`
            and, if present, `sliding_window` (a missing or falsy value means full attention).
        q: (1, nheads, total_q, headdim), where total_q = total number of query tokens in the batch.
        k: (1, nheads_k, total_k, headdim), where total_k = total number of key tokens in the batch.
        v: (1, nheads_k, total_k, headdim), where total_k = total number of key tokens in the batch.
        attention_mask: unused. Masking is done by the flash kernels through `causal=True` and the window size.
        cache: the paged KV cache. This layer's storage is located through `module.layer_idx`.
        cu_seq_lens_q: (batch_size + 1,), integer tensor (cast to torch.int32 here). The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seq_lens_k: (batch_size + 1,), integer tensor (cast to torch.int32 here). The cumulative sequence lengths
           of the sequences in the batch, used to index into kv. May also be a dict keyed by `"full_attention"` /
           `"sliding_attention"`, in which case the entry matching this layer's attention type is used.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch. May also be a dict keyed like `cu_seq_lens_k`.
            Only used on the varlen path.
        block_table: (num_groups, batch_size, max_blocks_per_seq), dtype int32. Block table for paged KV cache.
            If provided, uses flash_attn_with_kvcache for fused attention + cache update. For each request, the block
            table is a vector of size (max_blocks_per_seq,) with indices indicating the physical location of the cache
            to read from and write to. The kernel, using the cache_seqlens for that request, knows how much cache to
            read and dispatches the read using the block table. Same for the write. If a request has fewer than
            max_blocks_per_seq blocks, the block table is padded with -1s to indicate that the block is not allocated.
        **kwargs: `read_index` and `write_index` are required when `block_table` is None (they are forwarded to
            `cache.update`). `s_aux` is optional and is forwarded to the kernel when present.

    Returns:
        `(attn_output, None)`.
        - On the block-table path, `attn_output` is (batch_size, nheads, headdim).
        - On the varlen path, it is the kernel output for the flattened query tokens, (total_q, nheads, headdim)
          per the flash-attn varlen API.
    """
    # Retrieve the flash attention functions matching the configured attention implementation
    flash_attn_varlen_func, flash_attn_with_kvcache = lazy_import_paged_flash_attention(
        module.config._attn_implementation
    )

    # flash-attn takes window_size=(left, right); (-1, -1) means an infinite context window
    sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window - 1, 0)
    layer_type = "full_attention" if sliding_window == (-1, -1) else "sliding_attention"

    # Retrieve the key-side sequence metadata for the current layer type. Each argument is resolved on its own, so
    # a dict for one and a plain value for the other is handled correctly.
    if isinstance(cu_seq_lens_k, dict):
        cu_seq_lens_k = cu_seq_lens_k[layer_type]
    if isinstance(max_seqlen_k, dict):
        max_seqlen_k = max_seqlen_k[layer_type]

    # Only forward s_aux when the caller supplied it; not every kernel accepts it (this is only available in VLLM's FA3)
    extra_kwargs = {"s_aux": kwargs["s_aux"]} if "s_aux" in kwargs else {}

    # If no block table is provided, use flash_attn_varlen_func with read/write indices
    if block_table is None:
        # .update changes the shape of k and v from [1, num_kv_heads, seqlen_kv, head_dim] to [-1, num_kv_heads, head_dim]
        k, v = cache.update(
            key_states=k,
            value_states=v,
            layer_idx=module.layer_idx,
            read_index=kwargs["read_index"],
            write_index=kwargs["write_index"],
        )
        attn_output = flash_attn_varlen_func(
            q.transpose(1, 2).squeeze(0).contiguous(),
            k.contiguous(),
            v.contiguous(),
            cu_seq_lens_q.to(torch.int32),
            # NOTE(review): the .clone() looks deliberate (avoids aliasing the caller's tensor) — confirm before removing
            cu_seq_lens_k.to(torch.int32).clone(),
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale=module.scaling,
            causal=True,  # kind of a must, it automatically aligns the mask for q < k
            window_size=sliding_window,  # -1 means infinite context window
            **extra_kwargs,
        )
        if isinstance(attn_output, tuple):
            attn_output = attn_output[0]

    # Otherwise, use flash_attn_with_kvcache which updates the cache in-place and computes attention
    else:
        # Get layer group index for this layer
        group_idx, layer_idx_in_group = cache.layer_index_to_group_indices[module.layer_idx]
        # KV cache shape: [num_pages, num_kv_heads, head_dim] -> [num_blocks, block_size, num_kv_heads, head_dim]
        k_cache = cache.key_cache[layer_idx_in_group].view(
            -1, cache.block_size, cache.num_key_value_heads, cache.head_dim
        )
        v_cache = cache.value_cache[layer_idx_in_group].view(
            -1, cache.block_size, cache.num_key_value_heads, cache.head_dim
        )
        # Reshape Q, K, V from [1, num_*_heads, batch_size, head_dim] to [batch_size, 1, num_*_heads, head_dim]
        q = q.permute(2, 0, 1, 3).contiguous()
        k = k.permute(2, 0, 1, 3).contiguous()
        v = v.permute(2, 0, 1, 3).contiguous()
        # Compute cache_seqlens from cu_seq_lens_k (current cache length BEFORE adding new tokens).
        # cu_seq_lens_k is cumulative, so seqlens[i] = cu_seq_lens_k[i+1] - cu_seq_lens_k[i] - 1. The "- 1" removes
        # the new token, which assumes exactly one new token per sequence (decode-only batch).
        batch_size = k.size(0)
        cache_seqlens = (cu_seq_lens_k[1 : batch_size + 1] - cu_seq_lens_k[:batch_size] - 1).to(torch.int32)
        # The arg name for the block table is not the same in VLLM's kernel and Tri Dao's kernel, so we need to parse it
        flash_kwargs = {cache.get_block_table_key(flash_attn_with_kvcache): block_table[group_idx], **extra_kwargs}
        # Call flash_attn_with_kvcache - this updates cache in-place and computes attention
        attn_output = flash_attn_with_kvcache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            k=k,
            v=v,
            cache_seqlens=cache_seqlens,
            softmax_scale=module.scaling,
            causal=True,
            window_size=sliding_window,
            **flash_kwargs,
        )
        if isinstance(attn_output, tuple):
            attn_output = attn_output[0]
        # Reshape output from [batch_size, 1, num_heads, head_dim] to [batch_size, num_heads, head_dim]
        attn_output = attn_output.squeeze(1)
    return attn_output, None
