Skip to content

vllm.model_executor.kernels.mhc.aiter

mhc_pre_aiter

mhc_pre_aiter(
    residual: Tensor,
    fn: Tensor,
    hc_scale: Tensor,
    hc_base: Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[Tensor, Tensor, Tensor]

Forward pass for mHC pre block.

Parameters:

Name Type Description Default
residual Tensor

shape (..., hc_mult, hidden_size), dtype torch.bfloat16

required
fn Tensor

shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32

required
hc_scale Tensor

shape (3,), dtype torch.float32

required
hc_base Tensor

shape (hc_mult3,), dtype torch.float32

required
rms_eps float

RMS normalization epsilon

required
hc_pre_eps float

pre-mix epsilon

required
hc_sinkhorn_eps float

sinkhorn epsilon

required
hc_post_mult_value float

post-mix multiplier value

required
sinkhorn_repeat int

number of sinkhorn iterations

required
n_splits int

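split-k factor. Accepted by this wrapper but not currently forwarded to the underlying rocm_aiter_ops.mhc_pre call.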

1

Returns:

Name Type Description
post_mix Tensor

shape (..., hc_mult), dtype torch.float32

comb_mix Tensor

shape (..., hc_mult, hc_mult), dtype torch.float32

layer_input Tensor

shape (..., hidden_size), dtype torch.bfloat16

Source code in vllm/model_executor/kernels/mhc/aiter.py
def mhc_pre_aiter(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Forward pass for mHC pre block.

    Thin wrapper that checks the hidden-size constraint and then forwards
    the arguments to ``rocm_aiter_ops.mhc_pre``.

    Args:
        residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16
        fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32
        hc_scale: shape (3,), dtype torch.float32
        hc_base: shape (hc_mult3,), dtype torch.float32
        rms_eps: RMS normalization epsilon
        hc_pre_eps: pre-mix epsilon
        hc_sinkhorn_eps: sinkhorn epsilon
        hc_post_mult_value: post-mix multiplier value
        sinkhorn_repeat: number of sinkhorn iterations
        n_splits: split-k factor. Accepted but not currently forwarded to
            ``rocm_aiter_ops.mhc_pre``, so it has no effect in this wrapper.

    Returns:
        post_mix: shape (..., hc_mult), dtype torch.float32
        comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32
        layer_input: shape (..., hidden_size), dtype torch.bfloat16

    Raises:
        AssertionError: if the last dimension of ``residual`` (hidden_size)
            is not a multiple of 256.
    """
    hidden_size = residual.shape[-1]
    # Raised explicitly instead of via a bare ``assert`` so the check is not
    # stripped under ``python -O``. AssertionError is kept as the exception
    # type so existing callers see the same failure mode.
    if hidden_size % 256 != 0:
        raise AssertionError(
            "mhc_pre_aiter requires hidden_size to be a multiple of 256, "
            f"got {hidden_size}"
        )

    # Imported at call time rather than at module import time.
    # NOTE(review): presumably so this module can be imported on builds
    # without AITER support -- confirm.
    from vllm._aiter_ops import rocm_aiter_ops

    # NOTE: n_splits is intentionally absent from this call; see docstring.
    return rocm_aiter_ops.mhc_pre(
        residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
    )