Skip to content

vllm.v1.attention.backends.mla.prefill.tokenspeed_mla

TokenSpeed CuTe DSL backend for MLA prefill.

TokenspeedMLAPrefillBackend

Bases: MLAPrefillBackend

TokenSpeed CuTe DSL backend for MLA prefill.

Source code in vllm/v1/attention/backends/mla/prefill/tokenspeed_mla.py
class TokenspeedMLAPrefillBackend(MLAPrefillBackend):
    """TokenSpeed CuTe DSL backend for MLA prefill."""

    requires_r1_mla_dimensions = True

    @staticmethod
    def get_name() -> str:
        return "TOKENSPEED_MLA"

    @classmethod
    def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
        return device_capability.major == 10

    _INSTALL_HINT = (
        "tokenspeed_mla package is not installed. "
        "Install it with: `uv pip install tokenspeed-mla`"
    )

    @classmethod
    def is_available(cls) -> bool:
        try:
            from tokenspeed_mla import (
                tokenspeed_mla_prefill,  # noqa: F401
            )

            return True
        except ImportError:
            return False

    @classmethod
    def validate_configuration(
        cls,
        device_capability,
        selector_config,
    ) -> list[str]:
        # Replace the generic "required dependencies not available" message
        # from the base class with a specific install hint so users know
        # exactly which package to install when they explicitly select this
        # backend without having tokenspeed_mla installed.
        reasons = super().validate_configuration(device_capability, selector_config)
        return [
            cls._INSTALL_HINT if r == "required dependencies not available" else r
            for r in reasons
        ]

    def __init__(
        self,
        num_heads: int,
        scale: float,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        vllm_config: "VllmConfig",
    ) -> None:
        super().__init__(
            num_heads=num_heads,
            scale=scale,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            vllm_config=vllm_config,
        )

        # Pre-JIT BF16 and FP8 prefill kernels. Idempotent — also called from
        # TokenspeedMLAImpl.__init__; second call is a no-op.
        from tokenspeed_mla import warmup_compile_prefill

        for q_dtype in (torch.bfloat16, torch.float8_e4m3fn):
            warmup_compile_prefill(
                q_dtype=q_dtype,
                d_qk=qk_nope_head_dim + qk_rope_head_dim,
                d_v=v_head_dim,
                enable_pdl=False,
            )

    def prepare_metadata(
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
    ) -> None:
        super().prepare_metadata(prefill_metadata)
        # Kernel signature requires `seq_lens` but the implementation never reads
        # it (per-batch lengths are derived from `cum_seq_lens` diffs); compute
        # for parity with trtllm_ragged. cuda-graph padding in
        # `query_start_loc` is saturated to `total_num_tokens`
        # (gpu_model_runner.py:1905), so trailing diffs are 0 and padded batches
        # are kernel no-ops — same reason trtllm passes the padded length as
        # batch_size directly.
        self._query_seq_lens = (
            prefill_metadata.query_start_loc[1:] - prefill_metadata.query_start_loc[:-1]
        )

    def run_prefill_new_tokens(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        from tokenspeed_mla import tokenspeed_mla_prefill

        # `v` arrives as the second half of `kv_nope.split(...)` in
        # mla_attention.forward_mha — a non-contiguous view of `kv_nope` along
        # dim=-1. The kernel does `v.reshape(1, total_kv, h_k, 1, d_v)` which
        # would silently copy on a non-contiguous tensor; force contiguity here
        # so the copy (if any) happens once outside the kernel call.
        v = v.contiguous()

        ret = tokenspeed_mla_prefill(
            query=q,
            key=k,
            value=v,
            seq_lens=self._query_seq_lens,
            cum_seq_lens=self._prefill_metadata.query_start_loc,
            max_seq_len=self._prefill_metadata.max_query_len,
            batch_size=self._query_seq_lens.shape[0],
            softmax_scale=self.scale,
            is_causal=True,
            return_lse=return_softmax_lse,
            enable_pdl=False,
        )

        if isinstance(ret, tuple):
            # Convert from (q_len, num_heads) to (num_heads, q_len)
            return ret[0], ret[1].transpose(0, 1).contiguous()
        return ret

    def run_prefill_context_chunk(
        self,
        chunk_idx: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        from tokenspeed_mla import tokenspeed_mla_prefill

        assert self._prefill_metadata.chunked_context is not None
        chunked = self._prefill_metadata.chunked_context

        # See note in run_prefill_new_tokens — `v` is a split-view of `kv_nope`
        # in `_compute_prefill_context` and arrives non-contiguous.
        v = v.contiguous()

        attn_out, lse = tokenspeed_mla_prefill(
            query=q,
            key=k,
            value=v,
            seq_lens=chunked.seq_lens[chunk_idx],
            cum_seq_lens=chunked.cu_seq_lens[chunk_idx],
            max_seq_len=chunked.max_seq_lens[chunk_idx],
            batch_size=chunked.seq_lens[chunk_idx].shape[0],
            softmax_scale=self.scale,
            is_causal=False,
            return_lse=True,
            cum_seq_lens_q=self._prefill_metadata.query_start_loc,
            max_seq_len_q=self._prefill_metadata.max_query_len,
            enable_pdl=False,
        )

        # Convert from (q_len, num_heads) to (num_heads, q_len)
        return attn_out, lse.transpose(0, 1).contiguous()