vllm.model_executor.kernels.mhc.aiter
mhc_pre_aiter
mhc_pre_aiter(
    residual: Tensor,
    fn: Tensor,
    hc_scale: Tensor,
    hc_base: Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[Tensor, Tensor, Tensor]
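Forward pass for the mHC (manifold-constrained hyper-connections) pre block: mixes the `hc_mult` residual streams into a single `hidden_size` layer input and returns the post-mix and combine-mix coefficients alongside it.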
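Assuming this kernel follows the mHC formulation, the per-token computation can be summarized as below, with $n$ = `hc_mult`, $C$ = `hidden_size`, $\Phi$ = `fn`, $(\alpha_{\mathrm{pre}}, \alpha_{\mathrm{post}}, \alpha_{\mathrm{res}})$ = `hc_scale`, $b$ = `hc_base` split in the same pre/post/res order as the rows of `fn`, $\epsilon_{\mathrm{rms}}$ = `rms_eps`, $\epsilon_{\mathrm{pre}}$ = `hc_pre_eps`, $m_{\mathrm{post}}$ = `hc_post_mult_value`, $T$ = `sinkhorn_repeat` and $\epsilon_{\mathrm{sk}}$ = `hc_sinkhorn_eps`. The row ordering, the placement of the epsilons, and the absence of a separate RMSNorm weight are assumptions, not read from the kernel source.

$$
\begin{aligned}
\tilde{x} &= \frac{\operatorname{vec}(x)}{\sqrt{\frac{1}{nC}\,\lVert \operatorname{vec}(x) \rVert_2^2 + \epsilon_{\mathrm{rms}}}} \in \mathbb{R}^{nC}, \qquad x \in \mathbb{R}^{n \times C} \\
\begin{bmatrix} h_{\mathrm{pre}} \\ h_{\mathrm{post}} \\ h_{\mathrm{res}} \end{bmatrix} &= \Phi\,\tilde{x} \in \mathbb{R}^{2n + n^2} \\
w_{\mathrm{pre}} &= \sigma\!\left(\alpha_{\mathrm{pre}}\, h_{\mathrm{pre}} + b_{\mathrm{pre}}\right) + \epsilon_{\mathrm{pre}} \in \mathbb{R}^{n} \\
\texttt{layer\_input} &= \sum_{i=1}^{n} (w_{\mathrm{pre}})_i \, x_i \in \mathbb{R}^{C} \\
\texttt{post\_mix} &= m_{\mathrm{post}}\, \sigma\!\left(\alpha_{\mathrm{post}}\, h_{\mathrm{post}} + b_{\mathrm{post}}\right) \in \mathbb{R}^{n} \\
\texttt{comb\_mix} &= \operatorname{Sinkhorn}_{T,\,\epsilon_{\mathrm{sk}}}\!\left(\operatorname{mat}\!\left(\alpha_{\mathrm{res}}\, h_{\mathrm{res}} + b_{\mathrm{res}}\right)\right) \in \mathbb{R}^{n \times n}
\end{aligned}
$$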
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `residual` | `Tensor` | shape `(..., hc_mult, hidden_size)`, dtype `torch.bfloat16` | required |
| `fn` | `Tensor` | shape `(hc_mult3, hc_mult * hidden_size)`, dtype `torch.float32` | required |
| `hc_scale` | `Tensor` | shape `(3,)`, dtype `torch.float32` | required |
| `hc_base` | `Tensor` | shape `(hc_mult3,)`, dtype `torch.float32` | required |
| `rms_eps` | `float` | RMS normalization epsilon | required |
| `hc_pre_eps` | `float` | pre-mix epsilon | required |
| `hc_sinkhorn_eps` | `float` | Sinkhorn epsilon | required |
| `hc_post_mult_value` | `float` | post-mix multiplier value | required |
| `sinkhorn_repeat` | `int` | number of Sinkhorn iterations | required |
| `n_splits` | `int` | split-K factor: number of chunks the `hc_mult * hidden_size` reduction dimension is split into | `1` |
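The `sinkhorn_repeat` and `hc_sinkhorn_eps` parameters control a Sinkhorn-Knopp normalization that pushes the `hc_mult × hc_mult` combine matrix toward a doubly stochastic one (every row and column summing to 1). The plain-Python sketch below illustrates that idea only; the function name `sinkhorn_sketch`, the initial row-wise softmax, the column-then-row order, and adding `eps` to each denominator are assumptions, not the AITER kernel's exact numerics.

```python
import math


def sinkhorn_sketch(logits: list[list[float]], repeat: int, eps: float) -> list[list[float]]:
    """Hypothetical Sinkhorn-Knopp normalization of an n x n matrix of logits.

    ASSUMPTIONS (not the AITER kernel's exact numerics): a row-wise softmax
    first makes every entry positive, then each of the `repeat` iterations
    normalizes columns and then rows, adding `eps` to every denominator.
    """
    n = len(logits)
    # Row-wise softmax (max-subtracted for numerical stability).
    mat = []
    for row in logits:
        peak = max(row)
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        mat.append([e / total for e in exps])
    for _ in range(repeat):
        # Column normalization: each column sums to ~1.
        col_sums = [sum(mat[i][j] for i in range(n)) for j in range(n)]
        mat = [[mat[i][j] / (col_sums[j] + eps) for j in range(n)] for i in range(n)]
        # Row normalization: each row sums to ~1.
        row_sums = [sum(row) for row in mat]
        mat = [[v / (row_sums[i] + eps) for v in row] for i, row in enumerate(mat)]
    return mat
```

With `repeat=0` the sketch reduces to a row-wise softmax; larger values spend more iterations to get closer to a doubly stochastic matrix.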
Returns:
| Name | Type | Description |
|---|---|---|
| `post_mix` | `Tensor` | post-mix weights; shape `(..., hc_mult)`, dtype `torch.float32` |
| `comb_mix` | `Tensor` | combine-mix matrix; shape `(..., hc_mult, hc_mult)`, dtype `torch.float32` |
| `layer_input` | `Tensor` | mixed input for the layer; shape `(..., hidden_size)`, dtype `torch.bfloat16` |
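The shape contract in the parameter and return tables can be checked before calling the kernel. The hypothetical helper below derives every expected shape from `residual.shape`. It assumes `hc_mult3 == hc_mult * (2 + hc_mult)` (`hc_mult` pre-mix, `hc_mult` post-mix and `hc_mult * hc_mult` combine-mix coefficients), which is consistent with the returned `post_mix` and `comb_mix` sizes but is not stated in these docs.

```python
from typing import Sequence


def mhc_pre_expected_shapes(residual_shape: Sequence[int]) -> dict[str, tuple[int, ...]]:
    """Hypothetical helper: expected shapes for mhc_pre_aiter inputs and outputs.

    ASSUMPTION: hc_mult3 == hc_mult * (2 + hc_mult), i.e. hc_mult pre-mix,
    hc_mult post-mix and hc_mult * hc_mult combine-mix coefficients.
    """
    if len(residual_shape) < 2:
        raise ValueError("residual must have shape (..., hc_mult, hidden_size)")
    # Leading dims are arbitrary batch/token dims shared by every output.
    *lead, hc_mult, hidden_size = residual_shape
    lead_dims = tuple(lead)
    hc_mult3 = hc_mult * (2 + hc_mult)
    return {
        # Inputs (float32 per the parameter table).
        "fn": (hc_mult3, hc_mult * hidden_size),
        "hc_scale": (3,),
        "hc_base": (hc_mult3,),
        # Outputs: post_mix/comb_mix are float32, layer_input is bfloat16.
        "post_mix": lead_dims + (hc_mult,),
        "comb_mix": lead_dims + (hc_mult, hc_mult),
        "layer_input": lead_dims + (hidden_size,),
    }
```

For example, a residual of shape `(8, 4, 4096)` gives `hc_mult3 = 4 * (2 + 4) = 24` under this assumption, so `fn` would be `(24, 16384)`.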