Skip to content

vllm.model_executor.kernels.mhc.triton

_hc_head_triton

_hc_head_triton(
    hs_flat: Tensor,
    fn: Tensor,
    hc_scale: Tensor,
    hc_base: Tensor,
    out: Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> None

Fill pre-allocated out (T, H) in-place with the hc_head result.

Source code in vllm/model_executor/kernels/mhc/triton.py
def _hc_head_triton(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> None:
    """Write the hc_head result into the caller-allocated `out` (T, H) tensor.

    Thin dispatcher around ``hc_head_reduce_triton_kernel``. Nothing is
    returned; the result is delivered through ``out``. When ``hs_flat`` has
    zero rows the call is a no-op and ``out`` is left untouched.

    ``hidden_size`` and ``hc_mult`` are accepted as part of the signature but
    are not forwarded to the reduction kernel.
    """
    num_rows = hs_flat.shape[0]
    # Only dispatch when there is at least one row to process.
    if num_rows != 0:
        hc_head_reduce_triton_kernel(
            hs_flat, fn, hc_scale, hc_base, out, rms_eps, hc_eps
        )

_rmsnorm_nw_kernel

_rmsnorm_nw_kernel(
    x_ptr, out_ptr, stride_row, D, eps, RBLOCK: constexpr
)

Weight-free RMSNorm Triton kernel: out = x * rsqrt(mean(x², -1) + eps).

Source code in vllm/model_executor/kernels/mhc/triton.py
@triton.jit
def _rmsnorm_nw_kernel(
    x_ptr,
    out_ptr,
    stride_row,
    D,
    eps,
    RBLOCK: tl.constexpr,
):
    """Weight-free RMSNorm Triton kernel: out = x * rsqrt(mean(x², -1) + eps).

    Each program instance normalizes one row, selected by ``program_id(0)``.

    Args:
        x_ptr: Input pointer. Row ``r`` starts at ``x_ptr + r * stride_row``;
            elements within a row are read at consecutive offsets, so the
            column stride must be 1.
        out_ptr: Output pointer. Row ``r`` is written at ``out_ptr + r * D``,
            i.e. the output rows must be densely packed.
        stride_row: Element stride between consecutive input rows.
        D: Number of valid elements per row.
        eps: Value added to the mean of squares before the reciprocal sqrt.
        RBLOCK: Compile-time block width; columns ``>= D`` are masked off, so
            it must be at least ``D`` for the whole row to be covered.
    """
    # program_id is 32-bit; widen it so the row offsets below cannot wrap
    # around once num_rows * D (or num_rows * stride_row) reaches 2**31.
    row = tl.program_id(0).to(tl.int64)
    cols = tl.arange(0, RBLOCK)
    mask = cols < D

    # Masked lanes load 0.0 so they contribute nothing to the sum of squares.
    # The reduction is carried out in float32 regardless of the input dtype.
    x = tl.load(
        x_ptr + row * stride_row + cols,
        mask=mask,
        other=0.0,
        eviction_policy="evict_first",
    ).to(tl.float32)

    # Divide by the true row length D, not the padded block width RBLOCK.
    var = tl.sum(x * x, 0) / D
    rstd = tl.rsqrt(var + eps)

    # Cast back to the output element type before storing.
    out = (x * rstd).to(out_ptr.dtype.element_ty)
    tl.store(out_ptr + row * D + cols, out, mask=mask, eviction_policy="evict_first")

rmsnorm_nw

rmsnorm_nw(x: Tensor, eps: float) -> Tensor

Weight-free RMSNorm over the last dimension.

Treats x as [num_rows, D] where num_rows = product(shape[:-1]). Returns a contiguous tensor with the same shape and dtype as x.

Source code in vllm/model_executor/kernels/mhc/triton.py
def rmsnorm_nw(x: Tensor, eps: float) -> Tensor:
    """Weight-free RMSNorm over the last dimension.

    Treats *x* as ``[num_rows, D]`` where ``num_rows = product(shape[:-1])``
    and computes ``x * rsqrt(mean(x², -1) + eps)`` for every row.

    Args:
        x: Input tensor with at least one dimension. It may be
            non-contiguous; a copy is made only when the last dimension does
            not have unit stride after flattening.
        eps: Value added to the mean of squares before the reciprocal sqrt.

    Returns:
        A newly allocated contiguous tensor with the same shape and dtype as
        *x*. For an input with zero elements an empty tensor of that shape is
        returned without launching the kernel.
    """
    orig_shape = x.shape
    D = orig_shape[-1]

    # Nothing to normalize. Returning here also avoids the ambiguous
    # ``reshape(-1, 0)`` when D == 0 and a zero-sized kernel launch.
    if x.numel() == 0:
        return torch.empty(orig_shape, dtype=x.dtype, device=x.device)

    x_2d = x.reshape(-1, D)
    # The kernel reads each row at consecutive offsets, so the column stride
    # must be 1; reshape() alone does not guarantee that (e.g. transposed x).
    if x_2d.stride(-1) != 1:
        x_2d = x_2d.contiguous()
    num_rows = x_2d.shape[0]

    # Allocate an explicitly contiguous output: the kernel writes row r at
    # offset r * D, whereas empty_like() may inherit non-contiguous strides
    # from a dense but permuted input.
    out = torch.empty((num_rows, D), dtype=x.dtype, device=x.device)
    RBLOCK = triton.next_power_of_2(D)

    # One program per row; RBLOCK covers the whole row in a single block.
    _rmsnorm_nw_kernel[(num_rows,)](
        x_2d,
        out,
        x_2d.stride(0),
        D,
        eps,
        RBLOCK=RBLOCK,
        num_warps=1 if RBLOCK <= 512 else (4 if RBLOCK <= 4096 else 8),
    )
    return out.view(orig_shape)