#include <flashinfer/attention/default_prefill_params.cuh>
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/pod.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/page.cuh>


#include "pod_config.inc"

using namespace flashinfer;

namespace flashinfer {
// Compile-time flags: each is true when the corresponding rendered mask mode
// equals MaskMode::kCustom. Nothing else in this template reads them directly;
// presumably they are referenced by the text substituted for the variant
// names below -- TODO confirm against the generator before removing them.
// The _p / _d suffixes appear to correspond to the prefill / decode halves
// (cf. the PrefillParams / DecodeParams arguments of the instantiation).
constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom;
constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom;

// NOTE(review): the positional-encoding mode is pinned to kNone for every
// rendered instantiation. The original author flagged this declaration as
// uncertain; verify that no caller expects another PosEncodingMode here.
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

// Explicit instantiation of PODWithKVCacheTensorDispatched for the
// configuration rendered into this file, so the specialization is compiled in
// this translation unit. The meaning of each template argument (including the
// literal 16) is fixed by the primary template's declaration, which is not
// visible here -- presumably in flashinfer/attention/pod.cuh.
template cudaError_t PODWithKVCacheTensorDispatched<
    {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE,
    {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, 16,
    {{ mask_mode_d }}, {{ variant_name_p }},
    {{ variant_name_d }}, PrefillParams, DecodeParams>(
            PrefillParams prefill_params, {{ dtype_o }}* tmp,
            DecodeParams decode_params, {{ dtype_o }}* tmp_v,
            float *tmp_s, bool enable_pdl, cudaStream_t stream);
}  // namespace flashinfer
