Skip to content

vllm.v1.attention.backends.mla.rocm_aiter_mla

AiterMLAHelper

AITER MLA implementation requires num_heads >= 16. If num_heads < 16 and 16 % num_heads == 0, we can pad q to 16 heads; otherwise AITER has to fail.

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla.py
class AiterMLAHelper:
    """
    AITER MLA implementation requires num_heads >= 16. If num_heads < 16 and
    16 % num_heads == 0, we can pad q to 16 heads; otherwise AITER has to fail.
    """

    _AITER_MIN_MLA_HEADS: Final = 16
    _AITER_UNSUPPORTED_HEADS: ClassVar[tuple[int, ...]] = ()

    @staticmethod
    def check_num_heads_validity(num_heads: int):
        assert AiterMLAHelper.is_valid_num_heads(num_heads), (
            f"Aiter MLA requires that num_heads be multiples or divisors of 16, "
            f"but provided {num_heads} number of heads.\n"
            f"Try adjusting tensor_parallel_size value."
        )

    @staticmethod
    def is_valid_num_heads(num_heads: int) -> bool:
        return (
            num_heads % AiterMLAHelper._AITER_MIN_MLA_HEADS == 0
            if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
            else AiterMLAHelper._AITER_MIN_MLA_HEADS % num_heads == 0
        )

    @staticmethod
    def get_actual_mla_num_heads(num_heads: int) -> int:
        return max(num_heads, AiterMLAHelper._AITER_MIN_MLA_HEADS)

    @staticmethod
    def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor:
        return (
            q
            if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
            else q.repeat_interleave(
                AiterMLAHelper._AITER_MIN_MLA_HEADS // num_heads, dim=1
            )
        )

    @staticmethod
    def get_mla_unpadded_o(num_heads: int, o: torch.Tensor) -> torch.Tensor:
        return (
            o
            if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
            else o[:, :: AiterMLAHelper._AITER_MIN_MLA_HEADS // num_heads, :]
        )

AiterMLAImpl

Bases: MLACommonImpl[AiterMLAMetadata]

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla.py
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )
        AiterMLAHelper.check_num_heads_validity(num_heads)

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        from aiter import flash_attn_varlen_func

        self.flash_attn_varlen_func = flash_attn_varlen_func

        # FP8 MLA prefill kernel imports (lazy, only when enabled).
        # Auto-enabled on gfx950 when AITER ships the kernels.
        self._fp8_prefill_enabled = _fp8_mla_prefill_supported()
        if self._fp8_prefill_enabled:
            from aiter import mla_prefill_ps_asm_fwd, mla_reduce_v1

            self._mla_prefill_ps_asm_fwd = mla_prefill_ps_asm_fwd
            self._mla_reduce_v1 = mla_reduce_v1

    def _flash_attn_varlen_diff_headdims(
        self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
    ):
        output = self.flash_attn_varlen_func(  # type: ignore[call-arg]
            q=q,
            k=k,
            v=v,
            softmax_scale=softmax_scale,
            return_lse=return_softmax_lse,
            **kwargs,
        )

        return output

    def _mla_fp8_prefill_attn(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
        out: torch.Tensor,
    ) -> None:
        """Run FP8 MLA prefill via mla_prefill_ps_asm_fwd + mla_reduce_v1.

        Q, K, V are already decompressed (post-kv_b_proj), so K and V have
        ``num_heads`` heads (same as Q) and gqa_ratio=1.  Writes the
        result in-place to ``out``, which is the [total_q, nhead * v_head_dim]
        output buffer supplied by ``forward_mha``; no extra allocation or
        copy is required.
        """
        from vllm.platforms import current_platform
        from vllm.v1.worker.workspace import current_workspace_manager

        fp8_dtype = current_platform.fp8_dtype()
        total_q = q.shape[0]
        nhead = self.num_heads
        v_head_dim = self.v_head_dim
        tile_q = _FP8_PREFILL_TILE_Q

        # The FP8 ASM kernel expects FP8 inputs; the q_scale/k_scale/v_scale
        # parameters select per-tensor dequant scales.  Q/K/V arrive as
        # bf16 from kv_b_proj, so cast here (one_scale=1.0 disables scaling).
        if q.dtype != fp8_dtype:
            q = q.to(fp8_dtype)
        if k.dtype != fp8_dtype:
            k = k.to(fp8_dtype)
        if v.dtype != fp8_dtype:
            v = v.to(fp8_dtype)

        one_scale = torch.ones((), dtype=torch.float32, device=q.device)

        # num_partial_tiles is resolved during metadata build to avoid an
        # in-forward .item() sync that would prevent CUDA Graph capture.
        # forward_mha gates the FP8 path on fp8_prefill_qo_indptr being set,
        # and the builder always sets every fp8_prefill_* field together, so
        # num_partial_tiles is non-None here.
        num_partial_tiles = attn_metadata.fp8_prefill_num_partial_tiles
        assert num_partial_tiles is not None

        # Reuse the caller's output buffer to skip the per-call alloc + copy.
        # The ASM and reduce kernels both write to a [total_q, nhead, v_head_dim]
        # view, which aliases the [total_q, nhead * v_head_dim] storage of out.
        out_3d = out.view(total_q, nhead, v_head_dim)

        # Per-call scratch (logits, attn_lse, final_lse) is served from the
        # workspace manager so allocator churn in the prefill hot path is
        # bounded after warmup, matching the pattern in PR #41002.
        logits, attn_lse, final_lse = current_workspace_manager().get_simultaneous(
            ((num_partial_tiles * tile_q, nhead, v_head_dim), torch.float32),
            ((num_partial_tiles * tile_q, nhead), torch.float32),
            ((total_q, nhead), torch.float32),
        )

        # Phase 1: persistent-scheduling assembly prefill kernel.
        self._mla_prefill_ps_asm_fwd(
            q,
            k,
            v,
            attn_metadata.fp8_prefill_qo_indptr,
            attn_metadata.fp8_prefill_kv_indptr,
            attn_metadata.fp8_prefill_kv_indices,
            attn_metadata.fp8_prefill_work_indptr,
            attn_metadata.fp8_prefill_work_info_set,
            attn_metadata.fp8_prefill_max_q_len,
            self.scale,
            True,  # is_causal
            logits,
            attn_lse,
            out_3d,
            one_scale,
            one_scale,
            one_scale,
        )

        # Phase 2: reduction across KV splits.
        self._mla_reduce_v1(
            logits,
            attn_lse,
            attn_metadata.fp8_prefill_reduce_indptr,
            attn_metadata.fp8_prefill_reduce_final_map,
            attn_metadata.fp8_prefill_reduce_partial_map,
            tile_q,
            out_3d,
            final_lse,
        )

    def forward_mha(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_scale: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        """Dispatch prefill to the FP8 ASM kernel when available.

        Falls back to the parent (``flash_attn_varlen_func``) when FP8
        MLA prefill is disabled, PS metadata is missing, or chunked
        context requires two-pass merge.

        The annotation uses the base ``MLACommonMetadata`` to honour LSP
        with ``MLACommonImpl.forward_mha``; the AITER builder always
        produces ``AiterMLAMetadata`` instances at runtime, so we narrow
        with ``isinstance`` before reading the AITER-specific FP8 fields.
        """
        if (
            not self._fp8_prefill_enabled
            or not isinstance(attn_metadata, AiterMLAMetadata)
            or attn_metadata.fp8_prefill_qo_indptr is None
        ):
            return super().forward_mha(
                q,
                kv_c_normed,
                k_pe,
                kv_c_and_k_pe_cache,
                attn_metadata,
                k_scale,
                output,
            )

        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        has_context = prefill_metadata.chunked_context is not None

        if has_context:
            return super().forward_mha(
                q,
                kv_c_normed,
                k_pe,
                kv_c_and_k_pe_cache,
                attn_metadata,
                k_scale,
                output,
            )

        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k = self._concat_k_nope_k_pe(k_nope, k_pe)

        self._mla_fp8_prefill_attn(q, k, v, attn_metadata, output)

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None
        assert attn_metadata.decode.max_qo_len is not None

        if type(q) is tuple:
            q = torch.cat(q, dim=-1)

        assert isinstance(q, torch.Tensor)
        B = q.shape[0]

        mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q)
        mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
        o = torch.empty(
            B,
            mla_num_heads,
            self.kv_lora_rank,
            dtype=attn_metadata.decode.attn_out_dtype,
            device=q.device,
        )

        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

        # Build kwargs for mla_decode_fwd. Pass persistent metadata only
        # when it was successfully computed (qseqlen=1 decode steps).
        # For multi-token verification steps (spec-dec), the kernel falls
        # back to computing metadata internally.
        mla_kwargs = dict(
            q_scale=layer._q_scale,
            kv_scale=layer._k_scale,
        )
        if attn_metadata.work_meta_data is not None:
            mla_kwargs.update(
                work_meta_data=attn_metadata.work_meta_data,
                work_indptr=attn_metadata.work_indptr,
                work_info_set=attn_metadata.work_info_set,
                reduce_indptr=attn_metadata.reduce_indptr,
                reduce_final_map=attn_metadata.reduce_final_map,
                reduce_partial_map=attn_metadata.reduce_partial_map,
            )

        rocm_aiter_ops.mla_decode_fwd(
            mla_padded_q,
            kv_buffer,
            o,
            self.scale,
            attn_metadata.decode.qo_indptr,
            attn_metadata.decode.max_qo_len,
            attn_metadata.decode.paged_kv_indptr,
            attn_metadata.decode.paged_kv_indices,
            attn_metadata.decode.paged_kv_last_page_len,
            **mla_kwargs,
        )

        return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, o), None

_mla_fp8_prefill_attn

_mla_fp8_prefill_attn(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_metadata: AiterMLAMetadata,
    out: Tensor,
) -> None

Run FP8 MLA prefill via mla_prefill_ps_asm_fwd + mla_reduce_v1.

Q, K, V are already decompressed (post-kv_b_proj), so K and V have num_heads heads (same as Q) and gqa_ratio=1. Writes the result in-place to out, which is the [total_q, nhead * v_head_dim] output buffer supplied by forward_mha; no extra allocation or copy is required.

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla.py
def _mla_fp8_prefill_attn(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_metadata: AiterMLAMetadata,
    out: torch.Tensor,
) -> None:
    """Run FP8 MLA prefill via mla_prefill_ps_asm_fwd + mla_reduce_v1.

    Q, K, V are already decompressed (post-kv_b_proj), so K and V have
    ``num_heads`` heads (same as Q) and gqa_ratio=1.  Writes the
    result in-place to ``out``, which is the [total_q, nhead * v_head_dim]
    output buffer supplied by ``forward_mha``; no extra allocation or
    copy is required.
    """
    from vllm.platforms import current_platform
    from vllm.v1.worker.workspace import current_workspace_manager

    fp8_dtype = current_platform.fp8_dtype()
    total_q = q.shape[0]
    nhead = self.num_heads
    v_head_dim = self.v_head_dim
    tile_q = _FP8_PREFILL_TILE_Q

    # The FP8 ASM kernel expects FP8 inputs; the q_scale/k_scale/v_scale
    # parameters select per-tensor dequant scales.  Q/K/V arrive as
    # bf16 from kv_b_proj, so cast here (one_scale=1.0 disables scaling).
    if q.dtype != fp8_dtype:
        q = q.to(fp8_dtype)
    if k.dtype != fp8_dtype:
        k = k.to(fp8_dtype)
    if v.dtype != fp8_dtype:
        v = v.to(fp8_dtype)

    one_scale = torch.ones((), dtype=torch.float32, device=q.device)

    # num_partial_tiles is resolved during metadata build to avoid an
    # in-forward .item() sync that would prevent CUDA Graph capture.
    # forward_mha gates the FP8 path on fp8_prefill_qo_indptr being set,
    # and the builder always sets every fp8_prefill_* field together, so
    # num_partial_tiles is non-None here.
    num_partial_tiles = attn_metadata.fp8_prefill_num_partial_tiles
    assert num_partial_tiles is not None

    # Reuse the caller's output buffer to skip the per-call alloc + copy.
    # The ASM and reduce kernels both write to a [total_q, nhead, v_head_dim]
    # view, which aliases the [total_q, nhead * v_head_dim] storage of out.
    out_3d = out.view(total_q, nhead, v_head_dim)

    # Per-call scratch (logits, attn_lse, final_lse) is served from the
    # workspace manager so allocator churn in the prefill hot path is
    # bounded after warmup, matching the pattern in PR #41002.
    logits, attn_lse, final_lse = current_workspace_manager().get_simultaneous(
        ((num_partial_tiles * tile_q, nhead, v_head_dim), torch.float32),
        ((num_partial_tiles * tile_q, nhead), torch.float32),
        ((total_q, nhead), torch.float32),
    )

    # Phase 1: persistent-scheduling assembly prefill kernel.
    self._mla_prefill_ps_asm_fwd(
        q,
        k,
        v,
        attn_metadata.fp8_prefill_qo_indptr,
        attn_metadata.fp8_prefill_kv_indptr,
        attn_metadata.fp8_prefill_kv_indices,
        attn_metadata.fp8_prefill_work_indptr,
        attn_metadata.fp8_prefill_work_info_set,
        attn_metadata.fp8_prefill_max_q_len,
        self.scale,
        True,  # is_causal
        logits,
        attn_lse,
        out_3d,
        one_scale,
        one_scale,
        one_scale,
    )

    # Phase 2: reduction across KV splits.
    self._mla_reduce_v1(
        logits,
        attn_lse,
        attn_metadata.fp8_prefill_reduce_indptr,
        attn_metadata.fp8_prefill_reduce_final_map,
        attn_metadata.fp8_prefill_reduce_partial_map,
        tile_q,
        out_3d,
        final_lse,
    )

forward_mha

forward_mha(
    q: Tensor,
    kv_c_normed: Tensor,
    k_pe: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: MLACommonMetadata,
    k_scale: Tensor,
    output: Tensor,
) -> None

Dispatch prefill to the FP8 ASM kernel when available.

Falls back to the parent (flash_attn_varlen_func) when FP8 MLA prefill is disabled, PS metadata is missing, or chunked context requires two-pass merge.

The annotation uses the base MLACommonMetadata to honour LSP with MLACommonImpl.forward_mha; the AITER builder always produces AiterMLAMetadata instances at runtime, so we narrow with isinstance before reading the AITER-specific FP8 fields.

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla.py
def forward_mha(
    self,
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: MLACommonMetadata,
    k_scale: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Dispatch prefill to the FP8 ASM kernel when available.

    Falls back to the parent (``flash_attn_varlen_func``) when FP8
    MLA prefill is disabled, PS metadata is missing, or chunked
    context requires two-pass merge.

    The annotation uses the base ``MLACommonMetadata`` to honour LSP
    with ``MLACommonImpl.forward_mha``; the AITER builder always
    produces ``AiterMLAMetadata`` instances at runtime, so we narrow
    with ``isinstance`` before reading the AITER-specific FP8 fields.
    """
    if (
        not self._fp8_prefill_enabled
        or not isinstance(attn_metadata, AiterMLAMetadata)
        or attn_metadata.fp8_prefill_qo_indptr is None
    ):
        return super().forward_mha(
            q,
            kv_c_normed,
            k_pe,
            kv_c_and_k_pe_cache,
            attn_metadata,
            k_scale,
            output,
        )

    assert attn_metadata.prefill is not None
    prefill_metadata = attn_metadata.prefill
    has_context = prefill_metadata.chunked_context is not None

    if has_context:
        return super().forward_mha(
            q,
            kv_c_normed,
            k_pe,
            kv_c_and_k_pe_cache,
            attn_metadata,
            k_scale,
            output,
        )

    kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
        -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
    )
    k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    k = self._concat_k_nope_k_pe(k_nope, k_pe)

    self._mla_fp8_prefill_attn(q, k, v, attn_metadata, output)

AiterMLAMetadataBuilder

Bases: MLACommonMetadataBuilder[AiterMLAMetadata]

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    # TODO(luka, lucas): audit this as part of:
    #  https://github.com/vllm-project/vllm/issues/22945
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
        )

        self.compilation_config = vllm_config.compilation_config
        self.decode_attn_out_dtype = vllm_config.model_config.dtype

        # Store the kernel block size from the spec. When kernel_block_size=1
        # (no spec-dec), behavior is identical to the original. When > 1
        # (e.g. 16 with Eagle3), we expand block-level indices into per-token
        # flat indices since the aiter kernel always uses page_size=1 internally.
        self.kernel_block_size = kv_cache_spec.block_size

        # In the flat view (.view(-1,1,1,H)), each token is its own page,
        # so max_num_pages_per_req = max_model_len regardless of
        # kernel_block_size.
        max_num_pages_per_req = vllm_config.model_config.max_model_len
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        max_num_pages = max_num_reqs * max_num_pages_per_req

        # Preparing persistent buffers
        # TODO: we can disambiguate between decode and mixed-prefill decode here
        # so we can only use the persistent buffer if a cudagraph is actually
        # being used.

        # paged_kv_last_page_len is always 1s (the aiter kernel always sees
        # page_size=1 after .view(-1,1,1,H) flattening), so we create it
        # once and reuse slices in both eager and cudagraph modes.
        self.paged_kv_last_page_len = torch.ones(
            max_num_reqs, dtype=torch.int32, device=device
        )

        # Persistent buffer for paged_kv_indices to avoid blocking boolean mask
        # indexing (block_table_tensor[mask]) which has data-dependent output size.
        self.paged_kv_indices = torch.zeros(
            max_num_pages, dtype=torch.int32, device=device
        )

        from aiter import dtypes, get_mla_metadata_info_v1

        # For num_attention_heads < 16 (e.g. kimi-k2.5 head=8 with TP8),
        # make sure get_mla_metadata_info_v1 / get_mla_metadata_v1 are consistent
        # with the actual tensor shape passed to mla_decode_fwd.
        self._num_attention_heads = max(16, self.num_heads)
        q_dtype = self.decode_attn_out_dtype
        kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
        if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            kv_cache_dtype_str = "fp8"
        else:
            kv_cache_dtype_str = "bf16"
        kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16)
        (
            (work_meta_data_size, work_meta_data_type),
            (work_indptr_size, work_indptr_type),
            (work_info_set_size, work_info_set_type),
            (reduce_indptr_size, reduce_indptr_type),
            (reduce_final_map_size, reduce_final_map_type),
            (reduce_partial_map_size, reduce_partial_map_type),
        ) = get_mla_metadata_info_v1(
            max_num_reqs,
            1,
            self._num_attention_heads,
            q_dtype,
            kv_dtype,
            is_sparse=False,
            fast_mode=True,
        )
        self._mla_work_meta_data = torch.empty(
            work_meta_data_size, dtype=work_meta_data_type, device=device
        )
        self._mla_work_indptr = torch.empty(
            work_indptr_size, dtype=work_indptr_type, device=device
        )
        self._mla_work_info_set = torch.empty(
            work_info_set_size, dtype=work_info_set_type, device=device
        )
        self._mla_reduce_indptr = torch.empty(
            reduce_indptr_size, dtype=reduce_indptr_type, device=device
        )
        self._mla_reduce_final_map = torch.empty(
            reduce_final_map_size, dtype=reduce_final_map_type, device=device
        )
        self._mla_reduce_partial_map = torch.empty(
            reduce_partial_map_size,
            dtype=reduce_partial_map_type,
            device=device,
        )

        self._fp8_prefill_enabled = _fp8_mla_prefill_supported()
        if self._fp8_prefill_enabled:
            max_prefill_qlen = min(
                vllm_config.model_config.max_model_len,
                vllm_config.scheduler_config.max_num_batched_tokens,
            )
            self._init_fp8_prefill_ps_buffers(max_num_reqs, max_prefill_qlen, device)

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.paged_kv_indptr = torch.zeros(
                max_num_reqs + 1, dtype=torch.int32, device=device
            )

            self.qo_indptr = torch.zeros(
                max_num_reqs + 1, dtype=torch.int32, device=device
            )

    def _init_fp8_prefill_ps_buffers(
        self,
        max_num_reqs: int,
        max_prefill_qlen: int,
        device: torch.device,
    ) -> None:
        """Pre-allocate persistent buffers for FP8 MLA prefill PS metadata.

        Uses ``get_ps_metadata_info_v1`` with max values so the buffers are
        large enough for any batch.  ``get_ps_metadata_v1`` fills them
        per-batch in ``build()``.

        Args:
            max_num_reqs: Maximum number of concurrent requests.
            max_prefill_qlen: Maximum Q-length for a single request in one
                prefill batch.  Should be ``min(max_model_len,
                max_num_batched_tokens)`` — the chunked-prefill scheduler
                never emits more than ``max_num_batched_tokens`` new tokens
                per batch.
            device: Target device for the buffers.
        """
        from aiter import get_ps_metadata_info_v1

        # After kv_b_proj decompression, K has num_heads heads (same as Q).
        # So gqa_ratio=1 and num_head_k=num_heads for the PS kernel.
        num_head_k = self.num_heads
        # gqa_ratio = 1
        # qlen_granularity = _FP8_PREFILL_TILE_Q // max(gqa_ratio, 1)
        qlen_granularity = _FP8_PREFILL_TILE_Q

        (
            (work_metadata_size, work_metadata_dtype),
            (work_indptr_size, work_indptr_dtype),
            (work_info_size, work_info_dtype),
            (reduce_indptr_size, reduce_indptr_dtype),
            (reduce_final_map_size, reduce_final_map_dtype),
            (reduce_partial_map_size, reduce_partial_map_dtype),
        ) = get_ps_metadata_info_v1(
            batch_size=max_num_reqs,
            num_head_k=num_head_k,
            max_qlen=max_prefill_qlen,
            qlen_granularity=qlen_granularity,
        )

        self.fp8_ps_work_metadata = torch.empty(
            work_metadata_size, dtype=work_metadata_dtype, device=device
        )
        self.fp8_ps_work_indptr = torch.empty(
            work_indptr_size, dtype=work_indptr_dtype, device=device
        )
        self.fp8_ps_work_info = torch.empty(
            *work_info_size, dtype=work_info_dtype, device=device
        )
        self.fp8_ps_reduce_indptr = torch.empty(
            reduce_indptr_size, dtype=reduce_indptr_dtype, device=device
        )
        self.fp8_ps_reduce_final_map = torch.empty(
            *reduce_final_map_size, dtype=reduce_final_map_dtype, device=device
        )
        self.fp8_ps_reduce_partial_map = torch.empty(
            reduce_partial_map_size,
            dtype=reduce_partial_map_dtype,
            device=device,
        )

        logger.info(
            "FP8 MLA prefill PS buffers allocated "
            "(max_batch=%d, max_qlen=%d, num_head_k=%d)",
            max_num_reqs,
            max_prefill_qlen,
            num_head_k,
        )

    def _build_fp8_prefill_ps_metadata(
        self,
        metadata: AiterMLAMetadata,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> None:
        """Build per-batch FP8 MLA prefill PS metadata and attach to *metadata*.

        Called from ``build()`` when prefill tokens are present and
        FP8 MLA prefill is enabled (auto-detected via
        ``_fp8_mla_prefill_supported()``).
        """
        from aiter import get_ps_metadata_v1

        prefill = metadata.prefill
        # Caller (build()) only invokes this when prefill tokens exist, so
        # metadata.prefill is guaranteed non-None.  Assert to narrow for mypy.
        assert prefill is not None
        qo_indptr = prefill.query_start_loc
        kv_indptr = qo_indptr  # new tokens: KV length == Q length

        # Reuse the existing CPU view of query_start_loc instead of forcing a
        # device->host copy.  Prefill batches sit at the tail of the request
        # list, so we slice from num_decodes onwards and rebase to zero, the
        # same transform the parent build applies on device tensors.
        num_decodes = metadata.num_decodes
        qsl_cpu = common_attn_metadata.query_start_loc_cpu
        qo_indptr_cpu = (qsl_cpu[num_decodes:] - qsl_cpu[num_decodes]).to(torch.int32)
        kv_indptr_cpu = qo_indptr_cpu.clone()
        seq_lens_cpu = (qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]).to(torch.int32)

        num_head_k = self.num_heads
        # gqa_ratio = 1
        # qhead_granularity = max(gqa_ratio, 1)
        # qlen_granularity = _FP8_PREFILL_TILE_Q // qhead_granularity
        gqa_ratio = 1
        qhead_granularity = 1
        qlen_granularity = _FP8_PREFILL_TILE_Q
        kvlen_granularity = 128
        block_size = 1  # non-paged: each "page" is one token

        get_ps_metadata_v1(
            qo_indptr_cpu,
            kv_indptr_cpu,
            seq_lens_cpu,
            gqa_ratio,
            num_head_k,
            self.fp8_ps_work_metadata,
            self.fp8_ps_work_indptr,
            self.fp8_ps_work_info,
            self.fp8_ps_reduce_indptr,
            self.fp8_ps_reduce_final_map,
            self.fp8_ps_reduce_partial_map,
            qhead_granularity=qhead_granularity,
            qlen_granularity=qlen_granularity,
            kvlen_granularity=kvlen_granularity,
            block_size=block_size,
            is_causal=True,
        )

        total_prefill_tokens = int(qo_indptr_cpu[-1].item())
        kv_indices = torch.arange(
            total_prefill_tokens, device=qo_indptr.device, dtype=torch.int32
        )

        # The actual number of active partial tiles for this batch is the
        # final value of reduce_indptr.  Resolving it here (during metadata
        # build) keeps it off the per-layer forward path where a sync would
        # break CUDA Graph capture.  Using the device-side reduce_indptr is
        # acceptable since build is allowed to incur an occasional sync.
        num_partial_tiles = int(self.fp8_ps_reduce_indptr[-1].item())

        # Attach PS metadata to the metadata object so forward_mha can read it.
        metadata.fp8_prefill_qo_indptr = qo_indptr
        metadata.fp8_prefill_kv_indptr = kv_indptr
        metadata.fp8_prefill_kv_indices = kv_indices
        metadata.fp8_prefill_work_indptr = self.fp8_ps_work_indptr
        metadata.fp8_prefill_work_info_set = self.fp8_ps_work_info
        metadata.fp8_prefill_reduce_indptr = self.fp8_ps_reduce_indptr
        metadata.fp8_prefill_reduce_final_map = self.fp8_ps_reduce_final_map
        metadata.fp8_prefill_reduce_partial_map = self.fp8_ps_reduce_partial_map
        metadata.fp8_prefill_max_q_len = prefill.max_query_len
        metadata.fp8_prefill_num_partial_tiles = num_partial_tiles

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_device: torch.Tensor,
        max_seq_len: int,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> AiterMLADecodeMetadata:
        device = self.device
        num_reqs = seq_lens_device.size(0)

        # The aiter kernel always operates with page_size=1 (the wrapper
        # flattens kv_buffer). last_page_len is always 1.
        paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]

        # indptr: cumsum of seq_lens (one page per token in the flat view)
        paged_kv_indptr = torch.cat(
            [
                torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
                seq_lens_device.cumsum(dim=0, dtype=torch.int32),
            ]
        )
        qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        max_qo_len = qo_len.max().item()

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.paged_kv_indices.fill_(-1)

        # Expand block_table entries into per-token flat indices.
        # When kernel_block_size=1, this degrades to a direct copy (identical
        # to the original _copy_page_indices_kernel).
        # When kernel_block_size=K>1, block_table entry b covering K tokens
        # gets expanded to flat indices b*K, b*K+1, ..., b*K+(K-1).
        _expand_page_indices_kernel[(num_reqs,)](
            self.paged_kv_indices,
            block_table_tensor,
            block_table_tensor.stride(0),
            paged_kv_indptr,
            seq_lens_device,
            KERNEL_BLOCK_SIZE=self.kernel_block_size,
            BLOCK_SIZE=1024,
        )
        paged_kv_indices = self.paged_kv_indices

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.paged_kv_indptr[: 1 + num_reqs].copy_(
                paged_kv_indptr, non_blocking=True
            )
            self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
            paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]

            # paged_kv_last_page_len already uses the pre-initialized buffer slice
            # (set above), so no copy needed - buffer is always 1s.

            self.qo_indptr[: 1 + num_reqs].copy_(
                query_start_loc_device, non_blocking=True
            )
            self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
            qo_indptr = self.qo_indptr[: 1 + num_reqs]

        else:
            qo_indptr = torch.arange(
                0, num_reqs + 1, step=1, dtype=torch.int32, device=device
            )

        # The aiter MLA ASM kernel only supports qseqlen=1 (single-token
        # decode). With speculative decoding, the verification step has
        # qseqlen > 1 (e.g. 8 for spec7). get_mla_metadata_v1 calls
        # get_heuristic_kernel_mla which fails for qseqlen > 1.
        # We track whether persistent metadata was successfully computed
        # so forward_mqa can skip passing it (falling back to the kernel
        # computing its own metadata internally, like v0.18.0).
        has_persistent_metadata = False
        if max_qo_len == 1:
            from aiter import get_mla_metadata_v1

            get_mla_metadata_v1(
                qo_indptr,
                paged_kv_indptr,
                paged_kv_last_page_len,
                self._num_attention_heads,
                1,
                True,
                self._mla_work_meta_data,
                self._mla_work_info_set,
                self._mla_work_indptr,
                self._mla_reduce_indptr,
                self._mla_reduce_final_map,
                self._mla_reduce_partial_map,
                page_size=1,
                kv_granularity=16,
                max_seqlen_qo=max_qo_len,
                uni_seqlen_qo=max_qo_len,
                fast_mode=True,
            )
            has_persistent_metadata = True

        attn_metadata = AiterMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            qo_indptr=qo_indptr,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
            max_qo_len=max_qo_len,
            attn_out_dtype=self.decode_attn_out_dtype,
            has_persistent_metadata=has_persistent_metadata,
        )

        return attn_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AiterMLAMetadata:
        attn_metadata = super().build(
            common_prefix_len, common_attn_metadata, fast_build
        )
        if (
            attn_metadata.decode is not None
            and attn_metadata.decode.has_persistent_metadata
        ):
            attn_metadata.work_meta_data = self._mla_work_meta_data
            attn_metadata.work_indptr = self._mla_work_indptr
            attn_metadata.work_info_set = self._mla_work_info_set
            attn_metadata.reduce_indptr = self._mla_reduce_indptr
            attn_metadata.reduce_final_map = self._mla_reduce_final_map
            attn_metadata.reduce_partial_map = self._mla_reduce_partial_map
        if self._fp8_prefill_enabled and attn_metadata.prefill is not None:
            self._build_fp8_prefill_ps_metadata(attn_metadata, common_attn_metadata)
        return attn_metadata

_build_fp8_prefill_ps_metadata

_build_fp8_prefill_ps_metadata(
    metadata: AiterMLAMetadata,
    common_attn_metadata: CommonAttentionMetadata,
) -> None

Build per-batch FP8 MLA prefill PS metadata and attach to metadata.

Called from build() when prefill tokens are present and FP8 MLA prefill is enabled (auto-detected via _fp8_mla_prefill_supported()).

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla.py
def _build_fp8_prefill_ps_metadata(
    self,
    metadata: AiterMLAMetadata,
    common_attn_metadata: CommonAttentionMetadata,
) -> None:
    """Build per-batch FP8 MLA prefill PS metadata and attach to *metadata*.

    Called from ``build()`` when prefill tokens are present and
    FP8 MLA prefill is enabled (auto-detected via
    ``_fp8_mla_prefill_supported()``).
    """
    from aiter import get_ps_metadata_v1

    prefill = metadata.prefill
    # Caller (build()) only invokes this when prefill tokens exist, so
    # metadata.prefill is guaranteed non-None.  Assert to narrow for mypy.
    assert prefill is not None
    qo_indptr = prefill.query_start_loc
    kv_indptr = qo_indptr  # new tokens: KV length == Q length

    # Reuse the existing CPU view of query_start_loc instead of forcing a
    # device->host copy.  Prefill batches sit at the tail of the request
    # list, so we slice from num_decodes onwards and rebase to zero, the
    # same transform the parent build applies on device tensors.
    num_decodes = metadata.num_decodes
    qsl_cpu = common_attn_metadata.query_start_loc_cpu
    qo_indptr_cpu = (qsl_cpu[num_decodes:] - qsl_cpu[num_decodes]).to(torch.int32)
    kv_indptr_cpu = qo_indptr_cpu.clone()
    seq_lens_cpu = (qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]).to(torch.int32)

    num_head_k = self.num_heads
    # gqa_ratio = 1
    # qhead_granularity = max(gqa_ratio, 1)
    # qlen_granularity = _FP8_PREFILL_TILE_Q // qhead_granularity
    gqa_ratio = 1
    qhead_granularity = 1
    qlen_granularity = _FP8_PREFILL_TILE_Q
    kvlen_granularity = 128
    block_size = 1  # non-paged: each "page" is one token

    get_ps_metadata_v1(
        qo_indptr_cpu,
        kv_indptr_cpu,
        seq_lens_cpu,
        gqa_ratio,
        num_head_k,
        self.fp8_ps_work_metadata,
        self.fp8_ps_work_indptr,
        self.fp8_ps_work_info,
        self.fp8_ps_reduce_indptr,
        self.fp8_ps_reduce_final_map,
        self.fp8_ps_reduce_partial_map,
        qhead_granularity=qhead_granularity,
        qlen_granularity=qlen_granularity,
        kvlen_granularity=kvlen_granularity,
        block_size=block_size,
        is_causal=True,
    )

    total_prefill_tokens = int(qo_indptr_cpu[-1].item())
    kv_indices = torch.arange(
        total_prefill_tokens, device=qo_indptr.device, dtype=torch.int32
    )

    # The actual number of active partial tiles for this batch is the
    # final value of reduce_indptr.  Resolving it here (during metadata
    # build) keeps it off the per-layer forward path where a sync would
    # break CUDA Graph capture.  Using the device-side reduce_indptr is
    # acceptable since build is allowed to incur an occasional sync.
    num_partial_tiles = int(self.fp8_ps_reduce_indptr[-1].item())

    # Attach PS metadata to the metadata object so forward_mha can read it.
    metadata.fp8_prefill_qo_indptr = qo_indptr
    metadata.fp8_prefill_kv_indptr = kv_indptr
    metadata.fp8_prefill_kv_indices = kv_indices
    metadata.fp8_prefill_work_indptr = self.fp8_ps_work_indptr
    metadata.fp8_prefill_work_info_set = self.fp8_ps_work_info
    metadata.fp8_prefill_reduce_indptr = self.fp8_ps_reduce_indptr
    metadata.fp8_prefill_reduce_final_map = self.fp8_ps_reduce_final_map
    metadata.fp8_prefill_reduce_partial_map = self.fp8_ps_reduce_partial_map
    metadata.fp8_prefill_max_q_len = prefill.max_query_len
    metadata.fp8_prefill_num_partial_tiles = num_partial_tiles

_init_fp8_prefill_ps_buffers

_init_fp8_prefill_ps_buffers(
    max_num_reqs: int, max_prefill_qlen: int, device: device
) -> None

Pre-allocate persistent buffers for FP8 MLA prefill PS metadata.

Uses get_ps_metadata_info_v1 with max values so the buffers are large enough for any batch. get_ps_metadata_v1 fills them per-batch in build().

Parameters:

Name Type Description Default
max_num_reqs int

Maximum number of concurrent requests.

required
max_prefill_qlen int

Maximum Q-length for a single request in one prefill batch. Should be min(max_model_len, max_num_batched_tokens) — the chunked-prefill scheduler never emits more than max_num_batched_tokens new tokens per batch.

required
device device

Target device for the buffers.

required
Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla.py
def _init_fp8_prefill_ps_buffers(
    self,
    max_num_reqs: int,
    max_prefill_qlen: int,
    device: torch.device,
) -> None:
    """Pre-allocate persistent buffers for FP8 MLA prefill PS metadata.

    Uses ``get_ps_metadata_info_v1`` with max values so the buffers are
    large enough for any batch.  ``get_ps_metadata_v1`` fills them
    per-batch in ``build()``.

    Args:
        max_num_reqs: Maximum number of concurrent requests.
        max_prefill_qlen: Maximum Q-length for a single request in one
            prefill batch.  Should be ``min(max_model_len,
            max_num_batched_tokens)`` — the chunked-prefill scheduler
            never emits more than ``max_num_batched_tokens`` new tokens
            per batch.
        device: Target device for the buffers.
    """
    from aiter import get_ps_metadata_info_v1

    # After kv_b_proj decompression, K has num_heads heads (same as Q).
    # So gqa_ratio=1 and num_head_k=num_heads for the PS kernel.
    num_head_k = self.num_heads
    # gqa_ratio = 1
    # qlen_granularity = _FP8_PREFILL_TILE_Q // max(gqa_ratio, 1)
    qlen_granularity = _FP8_PREFILL_TILE_Q

    (
        (work_metadata_size, work_metadata_dtype),
        (work_indptr_size, work_indptr_dtype),
        (work_info_size, work_info_dtype),
        (reduce_indptr_size, reduce_indptr_dtype),
        (reduce_final_map_size, reduce_final_map_dtype),
        (reduce_partial_map_size, reduce_partial_map_dtype),
    ) = get_ps_metadata_info_v1(
        batch_size=max_num_reqs,
        num_head_k=num_head_k,
        max_qlen=max_prefill_qlen,
        qlen_granularity=qlen_granularity,
    )

    self.fp8_ps_work_metadata = torch.empty(
        work_metadata_size, dtype=work_metadata_dtype, device=device
    )
    self.fp8_ps_work_indptr = torch.empty(
        work_indptr_size, dtype=work_indptr_dtype, device=device
    )
    self.fp8_ps_work_info = torch.empty(
        *work_info_size, dtype=work_info_dtype, device=device
    )
    self.fp8_ps_reduce_indptr = torch.empty(
        reduce_indptr_size, dtype=reduce_indptr_dtype, device=device
    )
    self.fp8_ps_reduce_final_map = torch.empty(
        *reduce_final_map_size, dtype=reduce_final_map_dtype, device=device
    )
    self.fp8_ps_reduce_partial_map = torch.empty(
        reduce_partial_map_size,
        dtype=reduce_partial_map_dtype,
        device=device,
    )

    logger.info(
        "FP8 MLA prefill PS buffers allocated "
        "(max_batch=%d, max_qlen=%d, num_head_k=%d)",
        max_num_reqs,
        max_prefill_qlen,
        num_head_k,
    )

_expand_page_indices_kernel

_expand_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_tokens,
    seq_lens,
    KERNEL_BLOCK_SIZE: constexpr,
    BLOCK_SIZE: constexpr,
)

Expand block table entries into per-token flat page indices.

The aiter MLA kernel always operates with page_size=1 internally (kv_buffer is flattened via .view(-1, 1, 1, H)). This kernel converts block-level indices from the block table into individual token positions in the flattened KV buffer.

When KERNEL_BLOCK_SIZE=1: block_idx=t, offset=0, flat=block_id (equivalent to a direct copy -- no regression from the original kernel).

When KERNEL_BLOCK_SIZE=K: block table entry b (covering K tokens) is expanded to flat indices bK, bK+1, ..., b*K+(K-1).

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@triton.jit
def _expand_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_tokens,
    seq_lens,
    KERNEL_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Expand block table entries into per-token flat page indices.

    The aiter MLA kernel always operates with page_size=1 internally
    (kv_buffer is flattened via .view(-1, 1, 1, H)). This kernel converts
    block-level indices from the block table into individual token positions
    in the flattened KV buffer.

    When KERNEL_BLOCK_SIZE=1: block_idx=t, offset=0, flat=block_id
    (equivalent to a direct copy -- no regression from the original kernel).

    When KERNEL_BLOCK_SIZE=K: block table entry b (covering K tokens)
    is expanded to flat indices b*K, b*K+1, ..., b*K+(K-1).
    """
    req_idx = tl.program_id(0)
    row_ptr = block_table + req_idx * block_table_stride
    start_idx = tl.load(cu_num_tokens + req_idx)
    num_tokens = tl.load(seq_lens + req_idx)

    offset = tl.arange(0, BLOCK_SIZE)
    for i in tl.range(0, num_tokens, BLOCK_SIZE):
        token_offsets = i + offset
        mask = token_offsets < num_tokens

        # Which block in the block table does this token belong to?
        block_idx = token_offsets // KERNEL_BLOCK_SIZE
        # Offset within that block
        offset_in_block = token_offsets % KERNEL_BLOCK_SIZE

        # Load the block ID from the block table
        block_ids = tl.load(row_ptr + block_idx, mask=mask)

        # Compute flat index in the flattened kv_buffer
        flat_indices = block_ids * KERNEL_BLOCK_SIZE + offset_in_block

        tl.store(
            page_indices + start_idx + token_offsets,
            flat_indices,
            mask=mask,
        )

_fp8_mla_prefill_supported cached

_fp8_mla_prefill_supported() -> bool

Auto-detect FP8 MLA prefill via mla_prefill_ps_asm_fwd + mla_reduce_v1.

Requires gfx950 plus an AITER build that exports both kernels. When either is missing we silently fall back to flash_attn_varlen_func.

Source code in vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@functools.lru_cache(maxsize=1)
def _fp8_mla_prefill_supported() -> bool:
    """Auto-detect FP8 MLA prefill via mla_prefill_ps_asm_fwd + mla_reduce_v1.

    Requires gfx950 plus an AITER build that exports both kernels.  When
    either is missing we silently fall back to ``flash_attn_varlen_func``.
    """
    try:
        from vllm.platforms.rocm import on_gfx950
    except Exception:  # noqa: BLE001
        return False
    if not on_gfx950():
        return False
    try:
        from aiter import mla_prefill_ps_asm_fwd, mla_reduce_v1  # noqa: F401
    except Exception:  # noqa: BLE001
        return False
    return True