Skip to content

vllm.v1.attention.ops.triton_unified_attention

_cast_kv_tile

_cast_kv_tile(
    data, Q, tensor_scale, KV_QUANT_MODE: constexpr
)

Cast a loaded KV tile to Q's dtype, dequantizing if needed.

Modes handled inside the core kernel:

  • KV_QUANT_MODE == 0 (NONE), 2 (INT8 per-token-head), and 3 (FP8 per-token-head): plain cast. The per-token-head modes apply their scales separately on S/P inside the loop.
  • KV_QUANT_MODE == 1 (FP8 per-tensor): dequantize using the tensor-wide scale, unless Q is also FP8 and the caller folds the scales into the attention score and output accumulator.
Source code in vllm/v1/attention/ops/triton_unified_attention.py
@triton.jit
def _cast_kv_tile(data, Q, tensor_scale, KV_QUANT_MODE: tl.constexpr):
    """Convert a loaded KV tile to the dtype of ``Q``, dequantizing if required.

    Behaviour by quantization mode:

    - ``KV_QUANT_MODE == 1`` (FP8 per-tensor): multiply by the tensor-wide
      scale loaded from ``tensor_scale`` (in float32) before casting.  The
      one exception is an FP8 ``Q``: the caller then folds the scales into
      the attention score and output accumulator, so only a cast happens.
    - Any other mode -- ``0`` (NONE), ``2`` (INT8 per-token-head), ``3``
      (FP8 per-token-head): plain cast.  The per-token-head modes apply
      their scales separately on S/P inside the loop.
    """
    if KV_QUANT_MODE == 1:
        if Q.dtype.is_fp8():
            # FP8 query: scale handling is left to the caller, cast only.
            return data.to(Q.dtype)
        # Dequantize in float32 with the per-tensor scale, then narrow to
        # the query dtype.
        scale = tl.load(tensor_scale)
        dequantized = data.to(tl.float32) * scale
        return dequantized.to(Q.dtype)
    # Every remaining mode needs nothing more than a dtype conversion here.
    return data.to(Q.dtype)

_get_tile_size

_get_tile_size(
    head_size: int,
    sliding_window: int,
    element_size: int,
    is_prefill: bool,
) -> int

Select tile size with Gemma3-specific optimization.

Source code in vllm/v1/attention/ops/triton_unified_attention.py
def _get_tile_size(
    head_size: int,
    sliding_window: int,
    element_size: int,
    is_prefill: bool,
) -> int:
    """Pick the attention tile size, with a Gemma3-specific override.

    Returns 32 whenever the (head_size, sliding_window) pair is recognised
    as Gemma3 attention, and also for every prefill call.  For the
    remaining (decode) case the result depends on ``element_size``: 16 for
    elements of two bytes or more, 32 otherwise.
    """
    # Gemma3 always gets 32 (its decode default would otherwise be 16);
    # prefill gets 32 for every model.
    if _is_gemma3_attention(head_size, sliding_window) or is_prefill:
        return 32

    # Decode path. Tile size must be at least 32 for fp8
    # (element_size == 1), so only wider element types may use 16.
    if element_size >= 2:
        return 16
    return 32

_is_gemma3_attention

_is_gemma3_attention(
    head_size: int, sliding_window: int
) -> bool

Detect Gemma3 models via unique (head_size, sliding_window) signature.

Gemma3 models are the only ones using sliding_window=1024 with head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use different window sizes (Mistral=4096, Phi-3=2047).

Source code in vllm/v1/attention/ops/triton_unified_attention.py
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
    """Detect Gemma3 models via unique (head_size, sliding_window) signature.

    Gemma3 models are the only ones using sliding_window=1024 with
    head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
    different window sizes (Mistral=4096, Phi-3=2047).
    """
    return sliding_window == 1024 and head_size in (128, 256)

_load_kv_tile_td

_load_kv_tile_td(
    cache_ptr,
    physical_block_idx_scalar,
    kv_head_idx,
    offset_in_block,
    stride_cache_0: int64,
    stride_cache_1: int64,
    stride_cache_2: int64,
    stride_cache_3: constexpr,
    BLOCK_SIZE: constexpr,
    TILE_SIZE: constexpr,
    HEAD_SIZE: constexpr,
    HEAD_SIZE_PADDED: constexpr,
)

Load a KV cache tile via tensor descriptor.

Returns shape (TILE_SIZE, HEAD_SIZE_PADDED). Caller transposes for K. Tensor descriptors zero-pad reads beyond the shape boundary, so HEAD_SIZE_PADDED > HEAD_SIZE is handled correctly.

Source code in vllm/v1/attention/ops/triton_unified_attention.py
@triton.jit
def _load_kv_tile_td(
    cache_ptr,
    physical_block_idx_scalar,
    kv_head_idx,
    offset_in_block,
    stride_cache_0: tl.int64,
    stride_cache_1: tl.int64,
    stride_cache_2: tl.int64,
    stride_cache_3: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    TILE_SIZE: tl.constexpr,
    HEAD_SIZE: tl.constexpr,
    HEAD_SIZE_PADDED: tl.constexpr,
):
    """Read one KV cache tile through a tensor descriptor.

    The descriptor is anchored at the start of the selected physical block
    and KV head, and views it as a ``(BLOCK_SIZE, HEAD_SIZE)`` matrix.  A
    ``(TILE_SIZE, HEAD_SIZE_PADDED)`` tile is loaded starting at row
    ``offset_in_block`` and returned as-is; transposing it for K is the
    caller's job.  Because tensor descriptors zero-pad reads past the
    declared shape, ``HEAD_SIZE_PADDED > HEAD_SIZE`` is handled correctly.
    """
    # Locate the (block, head) slice inside the cache.
    block_offset = physical_block_idx_scalar * stride_cache_0
    head_offset = kv_head_idx * stride_cache_2
    tile_base = cache_ptr + block_offset + head_offset

    # Rows step by stride_cache_1, head elements by stride_cache_3.
    kv_desc = tl.make_tensor_descriptor(
        base=tile_base,
        shape=(BLOCK_SIZE, HEAD_SIZE),
        strides=(stride_cache_1, stride_cache_3),
        block_shape=(TILE_SIZE, HEAD_SIZE_PADDED),
    )
    tile = kv_desc.load([offset_in_block, 0])
    return tile

_load_q_td

_load_q_td(
    query_ptr,
    q_block_local_len,
    query_stride_0: int64,
    query_stride_1: int64,
    cur_batch_in_all_start_index,
    q_block_local_idx,
    kv_head_idx,
    num_queries_per_kv: constexpr,
    BLOCK_Q: constexpr,
    BLOCK_M: constexpr,
    HEAD_SIZE: constexpr,
    HEAD_SIZE_PADDED: constexpr,
)

Load Q via a 2D tensor descriptor.

Caller guarantees (via the wrapper's use_td_qo gate):

  • HEAD_SIZE == HEAD_SIZE_PADDED (head_size is a power of 2),
  • num_queries_per_kv is a power of 2,
  • the num_queries_per_kv heads of the current KV group are contiguous in memory (query_stride_1 == HEAD_SIZE, which is the default vLLM query layout).

Under those preconditions the inner two axes are flattened into one row of size num_queries_per_kv * HEAD_SIZE with stride 1, which avoids the non-power-of-2 block_shape error from the Triton tensor-descriptor validator. Returns (BLOCK_M, HEAD_SIZE_PADDED).

Source code in vllm/v1/attention/ops/triton_unified_attention.py
@triton.jit
def _load_q_td(
    query_ptr,
    q_block_local_len,
    query_stride_0: tl.int64,
    query_stride_1: tl.int64,
    cur_batch_in_all_start_index,
    q_block_local_idx,
    kv_head_idx,
    num_queries_per_kv: tl.constexpr,
    BLOCK_Q: tl.constexpr,
    BLOCK_M: tl.constexpr,
    HEAD_SIZE: tl.constexpr,
    HEAD_SIZE_PADDED: tl.constexpr,
):
    """Fetch a Q block using a 2D tensor descriptor.

    Preconditions, guaranteed by the caller through the wrapper's
    ``use_td_qo`` gate:
      * ``HEAD_SIZE == HEAD_SIZE_PADDED`` (head_size is a power of 2),
      * ``num_queries_per_kv`` is a power of 2,
      * the ``num_queries_per_kv`` heads belonging to the current KV group
        sit contiguously in memory (``query_stride_1 == HEAD_SIZE``, the
        default vLLM query layout).

    With those in place, the head and head-dim axes are collapsed into a
    single unit-stride row of ``num_queries_per_kv * HEAD_SIZE`` elements.
    That sidesteps the non-power-of-2 ``block_shape`` error raised by the
    Triton tensor-descriptor validator.  The loaded tile is reshaped to
    (BLOCK_M, HEAD_SIZE_PADDED) before being returned.
    """
    # First token of this Q block and first query head of this KV group.
    first_token = cur_batch_in_all_start_index + q_block_local_idx * BLOCK_Q
    first_head = kv_head_idx * num_queries_per_kv
    q_base = query_ptr + first_token * query_stride_0 + first_head * query_stride_1

    # One descriptor row per token; all heads of the group flattened into
    # the row.
    q_desc = tl.make_tensor_descriptor(
        base=q_base,
        shape=(q_block_local_len, num_queries_per_kv * HEAD_SIZE),
        strides=(query_stride_0, 1),
        block_shape=(BLOCK_Q, num_queries_per_kv * HEAD_SIZE_PADDED),
    )
    q_rows = q_desc.load([0, 0])
    return q_rows.reshape(BLOCK_M, HEAD_SIZE_PADDED)

_store_output_td

_store_output_td(
    base_ptr,
    acc,
    q_block_local_len,
    stride_token: int64,
    stride_head: int64,
    num_queries_per_kv: constexpr,
    BLOCK_Q: constexpr,
    HEAD_SIZE: constexpr,
    HEAD_SIZE_PADDED: constexpr,
)

Store an output tile via a tensor descriptor.

The 2D and 3D epilogues differ only in base_ptr and the (stride_token, stride_head) pair: 2D writes directly to the flat output buffer, 3D writes to a single per-segment slice of segm_output_ptr. Descriptor shape / block_shape / reshape are the same in both modes, so share one helper.

Source code in vllm/v1/attention/ops/triton_unified_attention.py
@triton.jit
def _store_output_td(
    base_ptr,
    acc,
    q_block_local_len,
    stride_token: tl.int64,
    stride_head: tl.int64,
    num_queries_per_kv: tl.constexpr,
    BLOCK_Q: tl.constexpr,
    HEAD_SIZE: tl.constexpr,
    HEAD_SIZE_PADDED: tl.constexpr,
):
    """Write an output tile through a tensor descriptor.

    One helper serves both the 2D and the 3D epilogue: they differ only
    in ``base_ptr`` and in the ``(stride_token, stride_head)`` pair.  The
    2D path targets the flat output buffer directly, while the 3D path
    targets a single per-segment slice of ``segm_output_ptr``.  Descriptor
    shape, block_shape and the reshape are identical in both modes.
    """
    # Match the element type of the destination buffer before storing.
    out_tile = acc.to(base_ptr.dtype.element_ty)

    # Destination viewed as (token, head-within-group, head-dim).
    out_desc = tl.make_tensor_descriptor(
        base=base_ptr,
        shape=(q_block_local_len, num_queries_per_kv, HEAD_SIZE),
        strides=(stride_token, stride_head, 1),
        block_shape=(BLOCK_Q, num_queries_per_kv, HEAD_SIZE_PADDED),
    )
    out_tile_3d = out_tile.reshape(BLOCK_Q, num_queries_per_kv, HEAD_SIZE_PADDED)
    out_desc.store([0, 0, 0], out_tile_3d)