
    )jg                        d dl mZ d dlmZ d Zd Zd Z e            Z e            Z	ej
        d             Zej        d             Zd Zd Zd	 Z e            Z e            Zej
        d
             Zej        d             Zd ZdS )    Nc                      t          j                    t           j        k    ot           j                                        S )N)mxdefault_devicegpumetalis_available     ]/lsinfo/ai/hellotax_ai/base_platform/venv/lib/python3.11/site-packages/mlx_lm/tuner/losses.pycan_run_metalr      s+    "&(DRX-B-B-D-DDr
   c                  t    t                      sd S d} t          j                            dddgdg| d          S )Naz  
    constexpr int M = 4;
    constexpr int block = 1024 * M;
    constexpr int full_blocks = V / block;
    constexpr int extra = V - full_blocks * block;

    threadgroup float shared[32 * 2];

    uint out_idx = threadgroup_position_in_grid.y;
    uint simd_lane_id = thread_index_in_simdgroup;
    uint simd_group_id = simdgroup_index_in_threadgroup;

    logits_q += out_idx * V;
    logits_p += out_idx * V;
    out += out_idx;

    float lse_q_minus_p;
    float lse_p;

    {
        float max_q = -1e30;
        float max_p = -1e30;
        float sum_exp_q = 0;
        float sum_exp_p = 0;

        int offset = thread_index_in_threadgroup * M;
        for (int i = 0; i < full_blocks; i++) {
            // Read and update q and p
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = logits_q[offset + j];
                vals_p[j] = logits_p[offset + j];
            }
            float prev_max_q = max_q;
            float prev_max_p = max_p;
            for (int j=0; j<M; j++) {
                max_q = max(max_q, vals_q[j]);
                max_p = max(max_p, vals_p[j]);
            }
            sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
            sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
            for (int j=0; j<M; j++) {
                sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
                sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
            }

            // Move to the next block
            offset += block;
        }
        if (extra > 0) {
            // Read and update q and p
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j < M; j++) {
                vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
                vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
            }
            float prev_max_q = max_q;
            float prev_max_p = max_p;
            for (int j=0; j<M; j++) {
                max_q = max(max_q, vals_q[j]);
                max_p = max(max_p, vals_p[j]);
            }
            sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
            sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
            for (int j=0; j<M; j++) {
                sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
                sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
            }
        }

        // Share the maxs across the threadgroup
        float prev_max_q = max_q;
        float prev_max_p = max_p;
        max_q = simd_max(max_q);
        max_p = simd_max(max_p);
        if (simd_lane_id == 0) {
            shared[simd_group_id * 2 + 0] = max_q;
            shared[simd_group_id * 2 + 1] = max_p;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        max_q = shared[simd_lane_id * 2 + 0];
        max_p = shared[simd_lane_id * 2 + 1];
        max_q = simd_max(max_q);
        max_p = simd_max(max_p);

        // Share the sum_exp across the threadgroup
        sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
        sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
        sum_exp_q = simd_sum(sum_exp_q);
        sum_exp_p = simd_sum(sum_exp_p);
        if (simd_lane_id == 0) {
            shared[simd_group_id * 2 + 0] = sum_exp_q;
            shared[simd_group_id * 2 + 1] = sum_exp_p;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sum_exp_q = shared[simd_lane_id * 2 + 0];
        sum_exp_p = shared[simd_lane_id * 2 + 1];
        sum_exp_q = simd_sum(sum_exp_q);
        sum_exp_p = simd_sum(sum_exp_p);

        lse_p = max_p + metal::fast::log(sum_exp_p);
        lse_q_minus_p = max_q + metal::fast::log(sum_exp_q) - lse_p;
    }

    threadgroup_barrier(mem_flags::mem_none);

    {
        float kl = 0;

        int offset = thread_index_in_threadgroup * M;
        for (int i = 0; i < full_blocks; i++) {
            // Read and add to the kl
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = logits_q[offset + j];
                vals_p[j] = logits_p[offset + j];
            }

            for (int j=0; j<M; j++) {
                kl += metal::fast::exp(vals_p[j] - lse_p) * (vals_p[j] - vals_q[j] + lse_q_minus_p);
            }

            // Move to the next block
            offset += block;
        }
        if (extra > 0) {
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
                vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
            }

            for (int j=0; j<M; j++) {
                kl += metal::fast::exp(vals_p[j] - lse_p) * (vals_p[j] - vals_q[j] + lse_q_minus_p);
            }
        }

        // Add the kl across the threadgroup
        kl = simd_sum(kl);
        if (simd_lane_id == 0) {
            shared[simd_group_id] = kl;
        }
        threadgroup_barrier(mem_flags::mem_none);
        kl = shared[simd_lane_id];
        kl = simd_sum(kl);

        if (thread_index_in_threadgroup == 0) {
            out[0] = static_cast<T>(kl);
        }
    }
    
kl_forwardlogits_qlogits_poutTnameinput_namesoutput_namessourceensure_row_contiguousr   r   fastmetal_kernelr   s    r   _make_kl_forward_kernelr      sS    ?? ZFx 7,W"     r
   c                  t    t                      sd S d} t          j                            dg ddg| d          S )Na  
    constexpr int M = 4;
    constexpr int block = 1024 * M;
    constexpr int full_blocks = V / block;
    constexpr int extra = V - full_blocks * block;

    threadgroup float shared[32 * 2];

    uint out_idx = threadgroup_position_in_grid.y;
    uint simd_lane_id = thread_index_in_simdgroup;
    uint simd_group_id = simdgroup_index_in_threadgroup;

    logits_q += out_idx * V;
    logits_p += out_idx * V;
    out += out_idx * V;
    cotan += out_idx;

    float lse_q;
    float lse_p;

    {
        float max_q = -1e30;
        float max_p = -1e30;
        float sum_exp_q = 0;
        float sum_exp_p = 0;

        int offset = thread_index_in_threadgroup * M;
        for (int i = 0; i < full_blocks; i++) {
            // Read and update q and p
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = logits_q[offset + j];
                vals_p[j] = logits_p[offset + j];
            }
            float prev_max_q = max_q;
            float prev_max_p = max_p;
            for (int j=0; j<M; j++) {
                max_q = max(max_q, vals_q[j]);
                max_p = max(max_p, vals_p[j]);
            }
            sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
            sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
            for (int j=0; j<M; j++) {
                sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
                sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
            }

            // Move to the next block
            offset += block;
        }
        if (extra > 0) {
            // Read and update q and p
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j < M; j++) {
                vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
                vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
            }
            float prev_max_q = max_q;
            float prev_max_p = max_p;
            for (int j=0; j<M; j++) {
                max_q = max(max_q, vals_q[j]);
                max_p = max(max_p, vals_p[j]);
            }
            sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
            sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
            for (int j=0; j<M; j++) {
                sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
                sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
            }
        }

        // Share the maxs across the threadgroup
        float prev_max_q = max_q;
        float prev_max_p = max_p;
        max_q = simd_max(max_q);
        max_p = simd_max(max_p);
        if (simd_lane_id == 0) {
            shared[simd_group_id * 2 + 0] = max_q;
            shared[simd_group_id * 2 + 1] = max_p;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        max_q = shared[simd_lane_id * 2 + 0];
        max_p = shared[simd_lane_id * 2 + 1];
        max_q = simd_max(max_q);
        max_p = simd_max(max_p);

        // Share the sum_exp across the threadgroup
        sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
        sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
        sum_exp_q = simd_sum(sum_exp_q);
        sum_exp_p = simd_sum(sum_exp_p);
        if (simd_lane_id == 0) {
            shared[simd_group_id * 2 + 0] = sum_exp_q;
            shared[simd_group_id * 2 + 1] = sum_exp_p;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sum_exp_q = shared[simd_lane_id * 2 + 0];
        sum_exp_p = shared[simd_lane_id * 2 + 1];
        sum_exp_q = simd_sum(sum_exp_q);
        sum_exp_p = simd_sum(sum_exp_p);

        lse_p = max_p + metal::fast::log(sum_exp_p);
        lse_q = max_q + metal::fast::log(sum_exp_q);
    }

    threadgroup_barrier(mem_flags::mem_none);

    {
        float kl = 0;
        float c = cotan[0];

        int offset = thread_index_in_threadgroup * M;
        for (int i = 0; i < full_blocks; i++) {
            // Read and add to the kl
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = logits_q[offset + j];
                vals_p[j] = logits_p[offset + j];
            }

            for (int j=0; j<M; j++) {
                out[offset + j] = static_cast<T>(
                    c * (metal::fast::exp(vals_q[j] - lse_q) - metal::fast::exp(vals_p[j] - lse_p)));
            }

            // Move to the next block
            offset += block;
        }
        if (extra > 0) {
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
                vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
            }

            for (int j=0; j<M; j++) {
                if (offset + j < V) {
                    out[offset + j] = static_cast<T>(
                        c * (metal::fast::exp(vals_q[j] - lse_q) - metal::fast::exp(vals_p[j] - lse_p)));
                }
            }
        }
    }
    kl_backward)r   r   cotanr   Tr   r   r   s    r   _make_kl_backward_kernelr       sT    ?? SFj 7555W"     r
   c           	          | j         | j        d         z  }| j        }t          | |g| j        d d         g|gd|fd| j        d         fgd|dfd          d         S 	NTV      r&   r'   r'   inputsoutput_shapesoutput_dtypestemplategridthreadgroupr   )sizeshapedtype_kl_forward_kernel)r   r   n_outsdts       r   _kl_div_lossr6   X  s    ]hnR00F	B(#~crc*+d)c8>"#567FA    	 	r
   c           	          | \  }}|j         }t          j        |          }t          |||g|j        g|gd|fd|j        d         fgd|j        dfd          d         }||fS )	Nr$   r%   r#   r&   r'   r(   r)   r   )r2   r   
zeros_like_kl_backward_kernelr1   r0   )primals	cotangentoutputr   r   r5   dpdqs           r   r6   r6   g  s     Hh	B	x	 	 B	(I.~&d)c8>"#567INA& 
 
 
 
	B r6Mr
   c           	          t                      rt          | |          S t          j                            | t          j        | dd          z
  |t          j        |dd          z
  dd          S )Nr#   TaxiskeepdimsnonerA   	reduction)r   r6   nnlosseskl_div_lossr   	logsumexp)r   r   s     r   rH   rH   y  su     
Hh///y$$r|H2EEEEr|H2EEEE	 % 
 
 	
r
   c                  v    t                      sd S d} t          j                            dddgddg| d          S )	NaT  
    constexpr int M = 4;
    constexpr int block = 1024 * M;
    constexpr int full_blocks = V / block;
    constexpr int extra = V - full_blocks * block;

    threadgroup float shared[32 * 2];

    uint out_idx = threadgroup_position_in_grid.y;
    uint simd_lane_id = thread_index_in_simdgroup;
    uint simd_group_id = simdgroup_index_in_threadgroup;

    logits_q += out_idx * V;
    logits_p += out_idx * V;
    out += out_idx;
    out_kl_q += out_idx;

    float lse_p;
    float lse_q;

    {
        float max_q = -1e30;
        float max_p = -1e30;
        float sum_exp_q = 0;
        float sum_exp_p = 0;

        int offset = thread_index_in_threadgroup * M;
        for (int i = 0; i < full_blocks; i++) {
            // Read and update q and p
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = logits_q[offset + j];
                vals_p[j] = logits_p[offset + j];
            }
            float prev_max_q = max_q;
            float prev_max_p = max_p;
            for (int j=0; j<M; j++) {
                max_q = max(max_q, vals_q[j]);
                max_p = max(max_p, vals_p[j]);
            }
            sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
            sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
            for (int j=0; j<M; j++) {
                sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
                sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
            }

            // Move to the next block
            offset += block;
        }
        if (extra > 0) {
            // Read and update q and p
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j < M; j++) {
                vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
                vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
            }
            float prev_max_q = max_q;
            float prev_max_p = max_p;
            for (int j=0; j<M; j++) {
                max_q = max(max_q, vals_q[j]);
                max_p = max(max_p, vals_p[j]);
            }
            sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
            sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
            for (int j=0; j<M; j++) {
                sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
                sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
            }
        }

        // Share the maxs across the threadgroup
        float prev_max_q = max_q;
        float prev_max_p = max_p;
        max_q = simd_max(max_q);
        max_p = simd_max(max_p);
        if (simd_lane_id == 0) {
            shared[simd_group_id * 2 + 0] = max_q;
            shared[simd_group_id * 2 + 1] = max_p;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        max_q = shared[simd_lane_id * 2 + 0];
        max_p = shared[simd_lane_id * 2 + 1];
        max_q = simd_max(max_q);
        max_p = simd_max(max_p);

        // Share the sum_exp across the threadgroup
        sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
        sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
        sum_exp_q = simd_sum(sum_exp_q);
        sum_exp_p = simd_sum(sum_exp_p);
        if (simd_lane_id == 0) {
            shared[simd_group_id * 2 + 0] = sum_exp_q;
            shared[simd_group_id * 2 + 1] = sum_exp_p;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sum_exp_q = shared[simd_lane_id * 2 + 0];
        sum_exp_p = shared[simd_lane_id * 2 + 1];
        sum_exp_q = simd_sum(sum_exp_q);
        sum_exp_p = simd_sum(sum_exp_p);

        lse_p = max_p + metal::fast::log(sum_exp_p);
        lse_q = max_q + metal::fast::log(sum_exp_q);
    }

    threadgroup_barrier(mem_flags::mem_none);

    {
        float kl_p = 0;
        float kl_q = 0;
        const float logtwo = metal::fast::log(static_cast<float>(2));

        int offset = thread_index_in_threadgroup * M;
        for (int i = 0; i < full_blocks; i++) {
            // Read and add to the kl_p and kl_q
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = logits_q[offset + j];
                vals_p[j] = logits_p[offset + j];
            }

            for (int j=0; j<M; j++) {
                float logp_j = vals_p[j] - lse_p;
                float logq_j = vals_q[j] - lse_q;
                float p_j = metal::fast::exp(logp_j);
                float q_j = metal::fast::exp(logq_j);
                kl_p += p_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logq_j - logp_j)));
                kl_q += q_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j)));
            }

            // Move to the next block
            offset += block;
        }
        if (extra > 0) {
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
                vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
            }

            for (int j=0; j<M; j++) {
                float logp_j = vals_p[j] - lse_p;
                float logq_j = vals_q[j] - lse_q;
                float p_j = metal::fast::exp(logp_j);
                float q_j = metal::fast::exp(logq_j);
                kl_p += p_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logq_j - logp_j)));
                kl_q += q_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j)));
            }
        }

        // Add the kl_p and kl_q across the threadgroup
        kl_p = simd_sum(kl_p);
        kl_q = simd_sum(kl_q);
        if (simd_lane_id == 0) {
            shared[simd_group_id * 2 + 0] = kl_p;
            shared[simd_group_id * 2 + 1] = kl_q;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        kl_p = shared[simd_lane_id * 2 + 0];
        kl_q = shared[simd_lane_id * 2 + 1];
        kl_p = simd_sum(kl_p);
        kl_q = simd_sum(kl_q);

        if (thread_index_in_threadgroup == 0) {
            out[0] = static_cast<T>(0.5 * kl_p + 0.5 * kl_q);
            out_kl_q[0] = static_cast<T>(kl_q);
        }
    }
    
js_forwardr   r   r   out_kl_qTr   r   r   s    r   _make_js_forward_kernelrM     sV    ?? lF\ 7,Z("     r
   c                  t    t                      sd S d} t          j                            dg ddg| d          S )Na  
    constexpr int M = 4;
    constexpr int block = 1024 * M;
    constexpr int full_blocks = V / block;
    constexpr int extra = V - full_blocks * block;

    threadgroup float shared[32 * 2];

    uint out_idx = threadgroup_position_in_grid.y;
    uint simd_lane_id = thread_index_in_simdgroup;
    uint simd_group_id = simdgroup_index_in_threadgroup;

    logits_q += out_idx * V;
    logits_p += out_idx * V;
    out_q += out_idx * V;
    cotan += out_idx;
    output_kl_q += out_idx;

    float lse_q;
    float lse_p;

    {
        float max_q = -1e30;
        float max_p = -1e30;
        float sum_exp_q = 0;
        float sum_exp_p = 0;

        int offset = thread_index_in_threadgroup * M;
        for (int i = 0; i < full_blocks; i++) {
            // Read and update q and p
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = logits_q[offset + j];
                vals_p[j] = logits_p[offset + j];
            }
            float prev_max_q = max_q;
            float prev_max_p = max_p;
            for (int j=0; j<M; j++) {
                max_q = max(max_q, vals_q[j]);
                max_p = max(max_p, vals_p[j]);
            }
            sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
            sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
            for (int j=0; j<M; j++) {
                sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
                sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
            }

            // Move to the next block
            offset += block;
        }
        if (extra > 0) {
            // Read and update q and p
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j < M; j++) {
                vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
                vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
            }
            float prev_max_q = max_q;
            float prev_max_p = max_p;
            for (int j=0; j<M; j++) {
                max_q = max(max_q, vals_q[j]);
                max_p = max(max_p, vals_p[j]);
            }
            sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
            sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
            for (int j=0; j<M; j++) {
                sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
                sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
            }
        }

        // Share the maxs across the threadgroup
        float prev_max_q = max_q;
        float prev_max_p = max_p;
        max_q = simd_max(max_q);
        max_p = simd_max(max_p);
        if (simd_lane_id == 0) {
            shared[simd_group_id * 2 + 0] = max_q;
            shared[simd_group_id * 2 + 1] = max_p;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        max_q = shared[simd_lane_id * 2 + 0];
        max_p = shared[simd_lane_id * 2 + 1];
        max_q = simd_max(max_q);
        max_p = simd_max(max_p);

        // Share the sum_exp across the threadgroup
        sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
        sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
        sum_exp_q = simd_sum(sum_exp_q);
        sum_exp_p = simd_sum(sum_exp_p);
        if (simd_lane_id == 0) {
            shared[simd_group_id * 2 + 0] = sum_exp_q;
            shared[simd_group_id * 2 + 1] = sum_exp_p;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sum_exp_q = shared[simd_lane_id * 2 + 0];
        sum_exp_p = shared[simd_lane_id * 2 + 1];
        sum_exp_q = simd_sum(sum_exp_q);
        sum_exp_p = simd_sum(sum_exp_p);

        lse_p = max_p + metal::fast::log(sum_exp_p);
        lse_q = max_q + metal::fast::log(sum_exp_q);
    }

    threadgroup_barrier(mem_flags::mem_none);

    {
        float c = cotan[0];
        const float logtwo = metal::fast::log(static_cast<float>(2));
        float kl_q = output_kl_q[0];

        int offset = thread_index_in_threadgroup * M;
        for (int i = 0; i < full_blocks; i++) {
            // Read and compute vjp for logits_q
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = logits_q[offset + j];
                vals_p[j] = logits_p[offset + j];
            }

            for (int j=0; j<M; j++) {
                float logp_j = vals_p[j] - lse_p;
                float logq_j = vals_q[j] - lse_q;
                float q_j = metal::fast::exp(logq_j);
                out_q[offset + j] = static_cast<T>(
                    c * 0.5 * q_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j)) - kl_q)
                );
            }

            // Move to the next block
            offset += block;
        }
        if (extra > 0) {
            float vals_q[M];
            float vals_p[M];
            for (int j=0; j<M; j++) {
                vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
                vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
            }

            for (int j=0; j<M; j++) {
                if (offset + j < V) {
                    float logp_j = vals_p[j] - lse_p;
                    float logq_j = vals_q[j] - lse_q;
                    float q_j = metal::fast::exp(logq_j);
                    out_q[offset + j] = static_cast<T>(
                        c * 0.5 * q_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j)) - kl_q)
                    );
                }
            }
        }
    }
    js_backward)r   r   r   output_kl_qout_qTr   r   r   s    r   _make_js_backward_kernelrR   ?  sT    ?? ]F~ 7DDDY"     r
   c           	         | j         | j        d         z  }| j        }t          | |g| j        d d         | j        d d         g||gd|fd| j        d         fgd|dfd          }|d         t	          j        |d                   fS r"   )r0   r1   r2   _js_forward_kernelr   stop_gradient)r   r   r4   r5   outputss        r   _js_div_lossrW     s    ]hnR00F	B (#~crc*HN3B3,?@2h)c8>"#567FA   G 1:r'
3333r
   c           	          | \  }}|\  }}|\  }}|j         }t          j        |          }	t          ||||g|j        g|gd|fd|j        d         fgd|j        dfd          }
|
|	fS )Nr$   r%   r#   r&   r'   r(   r)   )r2   r   r8   _js_backward_kernelr1   r0   )r:   
cotangentsrV   r   r   r   _kl_qr5   r=   r>   s              r   rW   rW     s     HhHE1GAt	B	x	 	 B	(E40~&d)c8>"#567EJ" 
 
 
B r6Mr
   c                    t                      rt          | |          d         S |t          j        |dd          z
  }| t          j        | dd          z
  }|t          j        dt          j        ||z
            z             z   t          j        d                              | j                  z
  }t          j	        
                    ||dd          }t          j	        
                    ||dd          }d	||z   z  S )
Nr   r#   Tr@   r'      rC   rD   g      ?)r   rW   r   rI   logexpastyper2   rF   rG   rH   )r   r   
logprobs_p
logprobs_q
logprobs_mkl_pr\   s          r   js_div_lossrf     s     #Hh//22XB N N NN
XB N N NN
fQ
Z 788899:fQiix~../ 	
 y$$Z"PV$WWy$$Z"PV$WWdTk""r
   )mlx.corecorer   mlx.nnrF   r   r   r    r3   r9   custom_functionr6   vjprH   rM   rR   rT   rY   rW   rf   r	   r
   r   <module>rl      sm              E E Ee e eP^ ^ ^B -,.. ..00  	 	 	   "	
 	
 	
w w wth h hV -,.. ..00  4 4 4   $# # # # #r
   