Skip to content

vllm.model_executor.kernels.linear.mxfp4.xpu

XPUMxFp4LinearKernel

Bases: MxFp4LinearKernel

MXFP4 W4A4 GEMM on XPU.

Source code in vllm/model_executor/kernels/linear/mxfp4/xpu.py
class XPUMxFp4LinearKernel(MxFp4LinearKernel):
    """MXFP4 W4A4 linear kernel that dispatches its GEMM to the XPU backend.

    Weights are re-laid-out once after loading; at call time activations are
    quantized with ``quant_mxfp4`` and multiplied via
    ``torch.ops._xpu_C.fp4_gemm``.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run on the current platform.

        Args:
            compute_capability: Accepted for interface compatibility; it is
                not consulted by this implementation.

        Returns:
            ``(True, None)`` when ``current_platform.is_xpu()`` is true,
            otherwise ``(False, reason)``.
        """
        if current_platform.is_xpu():
            return True, None
        return False, "XPUMxFp4 only support on XPU"

    @classmethod
    def can_implement(cls, c: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every layer config; ``c`` is not inspected.

        Returns:
            Always ``(True, None)``.
        """
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-register ``layer.weight`` and ``layer.weight_scale`` for the GEMM.

        Both parameters are replaced in place on ``layer``:

        * ``weight`` is reinterpreted as ``torch.float4_e2m1fn_x2`` and stored
          transposed (as a view, without forcing contiguity).
        * ``weight_scale`` is reinterpreted as ``torch.float8_e8m0fnu``,
          transposed, and materialized contiguously.
        """
        # NOTE(review): the weight stays a non-contiguous transposed view while
        # the scale is made contiguous -- presumably the layout fp4_gemm
        # expects; confirm against the op before changing either.
        packed_fp4 = layer.weight.view(torch.float4_e2m1fn_x2)
        replace_parameter(layer, "weight", packed_fp4.data.t())

        scales_e8m0 = layer.weight_scale.view(torch.float8_e8m0fnu).t().contiguous()
        replace_parameter(layer, "weight_scale", scales_e8m0.data)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` and run the FP4 GEMM against the layer's weights.

        Args:
            layer: Module holding ``weight`` and ``weight_scale``; presumably
                already passed through ``process_weights_after_loading``
                (TODO confirm with callers).
            x: Input activations; its dtype is forwarded to the GEMM as the
                requested output dtype.
            bias: Optional bias forwarded unchanged to the GEMM.

        Returns:
            The result of ``torch.ops._xpu_C.fp4_gemm``.
        """
        # Capture the dtype before quantizing so the op receives the caller's
        # original activation dtype.
        result_dtype = x.dtype
        act_fp4, act_scale = quant_mxfp4(x)
        return torch.ops._xpu_C.fp4_gemm(
            act_fp4,
            layer.weight,
            act_scale,
            layer.weight_scale,
            result_dtype,
            bias,
        )