#include <flashinfer/attention/default_prefill_params.cuh>
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/batch_pod.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/page.cuh>

#include "batch_pod_config.inc"

using namespace flashinfer;

namespace flashinfer {

// Compile-time flags: true when the substituted prefill (_p) / decode (_d) mask
// mode equals MaskMode::kCustom.
// NOTE(review): neither flag is referenced in the text visible in this template;
// presumably the substituted variant names use them -- confirm before removing.
constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom;
constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom;

// Positional-encoding mode passed to every instantiation below.
// TODO(review): the original author was unsure whether hard-coding kNone here is
// correct -- confirm against the callers of these instantiations.
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

// Explicit instantiations of BatchPODWithKVCacheTensorDispatched: one per prefill
// CTA tile size in {16, 64, 128}. The decode CTA tile size is fixed at 16.
{% for cta_tile_q in [16, 64, 128] %}
template cudaError_t BatchPODWithKVCacheTensorDispatched<
    {{ head_dim_qk }},
    {{ head_dim_vo }},
    POS_ENCODING_MODE,
    {{ use_fp16_qk_reduction }},
    /*CTA_TILE_Q_P=*/{{ cta_tile_q }},
    {{ mask_mode_p }},
    /*CTA_TILE_Q_D=*/16,
    {{ mask_mode_d }},
    {{ variant_name_p }},
    {{ variant_name_d }},
    PrefillParams,
    DecodeParams>(
    PrefillParams prefill_params,
    {{ dtype_o }}* tmp_v_p,
    float* tmp_s_p,
    DecodeParams decode_params,
    {{ dtype_o }}* tmp_v_d,
    float* tmp_s_d,
    bool enable_pdl,
    cudaStream_t stream,
    int* sm_aware_sched);
{% endfor %}

}  // namespace flashinfer
