Skip to content

vllm.model_executor.layers.quantization.quark.schemes

Modules:

Name Description
quark_nvfp4
quark_scheme
quark_w4a8_mxfp4_fp8

QuarkNVFP4

Bases: QuarkScheme

Quark NVFP4 quantization scheme.

Supports loading NVFP4 checkpoints with the following structure:

- weight: uint8, shape [out_features, in_features // 2] (packed FP4)
- weight_scale: float8_e4m3fn, shape [out_features, in_features // group_size]
- weight_scale_2: bfloat16/float32, scalar (global weight scale)
- input_scale_2: bfloat16/float32, scalar (global input scale)

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_nvfp4.py
class QuarkNVFP4(QuarkScheme):
    """
    Quark scheme for linear layers quantized with NVFP4.

    The checkpoint tensors this scheme loads are:
    - weight: uint8, shape [out_features, in_features // 2] (packed FP4)
    - weight_scale: float8_e4m3fn, shape [out_features, in_features // group_size]
    - weight_scale_2: bfloat16/float32, scalar (global weight scale)
    - input_scale_2: bfloat16/float32, scalar (global input scale)

    The two global scales are held as float32 with one entry per output
    partition until ``process_weights_after_loading`` reduces them.
    """

    def __init__(self):
        """Select the NVFP4 linear kernel and fix the scale group size at 16."""
        self.kernel = init_nvfp4_linear_kernel()
        self.group_size = 16

        # Only the emulation kernel has been tested with this scheme; any
        # other kernel is still used, but the user is told once.
        if isinstance(self.kernel, EmulationNvFp4LinearKernel):
            return

        logger.warning_once(
            "Only EmulationNvFp4LinearKernel NVFP4 dense implementation is "
            "tested with QuarkNVFP4, got kernel=%s. Correctness is not validated.",
            type(self.kernel).__name__,
        )

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability accepted by this scheme."""
        # 75 corresponds to Turing; FP4 requires Turing or newer.
        turing_capability = 75
        return turing_capability

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """
        Register the NVFP4 parameters on ``layer``.

        :param layer: module that receives the parameters and the
            partition-size attributes.
        :param output_partition_sizes: output widths of the logical
            sub-layers stored in this layer.
        :param input_size_per_partition: number of (unpacked) input features.
        :param params_dtype: not used by this scheme.
        :param weight_loader: loader callback attached to every parameter.
        :raises ValueError: if ``input_size_per_partition`` is not a multiple
            of the group size.
        """
        total_out_rows = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out_rows

        groups_per_row, leftover = divmod(input_size_per_partition, self.group_size)
        if leftover:
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must be "
                f"divisible by group size ({self.group_size})"
            )

        num_partitions = len(output_partition_sizes)

        # Two FP4 values share one uint8, hence half as many columns.
        packed_weight = torch.empty(
            total_out_rows, input_size_per_partition // 2, dtype=torch.uint8
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=packed_weight,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # One FP8 E4M3 scale for every group of input features.
        group_scales = torch.empty(
            total_out_rows, groups_per_row, dtype=torch.float8_e4m3fn
        )
        layer.register_parameter(
            "weight_scale",
            GroupQuantScaleParameter(
                data=group_scales,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # Global weight scale: one float32 slot per output partition.
        layer.register_parameter(
            "weight_scale_2",
            PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=weight_loader,
            ),
        )

        # Global input scale: one float32 slot per output partition.
        layer.register_parameter(
            "input_scale_2",
            PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=weight_loader,
            ),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Reduce the per-partition global scales to scalars, derive ``alpha``
        and the inverse input scale, then let the kernel convert the layer.

        ``input_scale_2`` and ``weight_scale_2`` are removed from ``layer``
        and replaced by ``input_global_scale`` and ``weight_global_scale``.
        """
        # The largest per-partition input scale becomes the layer-wide scale.
        reduced_input_scale = layer.input_scale_2.max().to(torch.float32)
        layer.input_global_scale = Parameter(reduced_input_scale, requires_grad=False)
        del layer.input_scale_2

        partition_weight_scales = layer.weight_scale_2.to(torch.float32)

        scales_disagree = torch.unique(partition_weight_scales).numel() != 1
        if scales_disagree:
            logger.warning_once(
                "In NVFP4 linear, the global scale for weight are different"
                " for parallel layers (e.g. q_proj, k_proj, v_proj). This"
                " will likely result in reduced accuracy. Please verify the"
                " model accuracy. Consider using a checkpoint with a shared"
                " global NVFP4 scale for fused layers."
            )

        # As with the input scale, keep only the maximum.
        layer.weight_global_scale = Parameter(
            partition_weight_scales.max(), requires_grad=False
        )
        del layer.weight_scale_2

        combined_scale = layer.input_global_scale * layer.weight_global_scale
        layer.alpha = Parameter(combined_scale, requires_grad=False)

        inverse_input_scale = (1.0 / layer.input_global_scale).to(torch.float32)
        layer.input_global_scale_inv = Parameter(
            inverse_input_scale, requires_grad=False
        )

        # Hand over to the kernel so it can lay the layer out in its own format.
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the forward pass by delegating to the selected NVFP4 kernel."""
        output = self.kernel.apply_weights(layer=layer, x=x, bias=bias)
        return output

QuarkScheme

Bases: ABC

Abstract class used to describe the weight creation and forward pass of different quantization schemes supported by Quark.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
class QuarkScheme(ABC):
    """
    Abstract class used to describe the weight creation and forward pass
    of different quantization schemes supported by Quark.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        raise NotImplementedError

    @abstractmethod
    def create_weights(self, *args, **kwargs):
        """
        Weight creation for the particular scheme. Inputs to this function

        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ):
        """
        Run the forward pass for the particular scheme. This is where
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
        :param x: input to the layer
        :param bias: bias parameter

        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError

apply_weights abstractmethod

apply_weights(
    layer: Module, x: Tensor, bias: Tensor | None
)

Run the forward pass for the particular scheme. This is where scheme-specific dequant/quant steps/kernels should be applied.

:param layer: torch.nn.Module with the registered weights and other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
@abstractmethod
def apply_weights(
    self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
):
    """
    Run the forward pass for the particular scheme. This is where
    scheme-specific dequant/quant steps/kernels should be applied.

    :param layer: torch.nn.Module with the registered weights and
        other parameters relevant to the particular scheme.
    :param x: input to the layer
    :param bias: bias parameter

    """
    raise NotImplementedError

create_weights abstractmethod

create_weights(*args, **kwargs)

Weight creation for the particular scheme. The inputs to this function are scheme-specific and are passed as positional and keyword arguments.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
@abstractmethod
def create_weights(self, *args, **kwargs):
    """
    Create the weights for the particular scheme. The inputs to this
    function are scheme-specific and are accepted as arbitrary
    positional and keyword arguments.
    """
    raise NotImplementedError

get_min_capability abstractmethod classmethod

get_min_capability() -> int

Get minimum device capability.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
    """
    Return the minimum device capability the scheme needs.
    """
    raise NotImplementedError

process_weights_after_loading abstractmethod

process_weights_after_loading(layer: Module)

Called after weight loading is complete for any cleanup that needs to occur.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
    """
    Hook invoked once weight loading has finished, for any cleanup
    that needs to occur.
    """
    raise NotImplementedError

QuarkW4A8_MXFP4_FP8

Bases: QuarkScheme

  • Weights: MXFP4 with E8M0 scales per block of 32
  • Activations: FP8 E4M3 (static per-tensor quantization)

Uses the AITER Triton kernel and falls back to emulation if AITER is not available.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w4a8_mxfp4_fp8.py
class QuarkW4A8_MXFP4_FP8(QuarkScheme):
    """
    Quark W4A8 scheme: MXFP4 weights combined with FP8 activations.

    - Weights: MXFP4 with E8M0 scales per block of 32
    - Activations: FP8 E4M3 (static per-tensor quantization)

    Uses the AITER Triton kernel and falls back to emulation if AITER is not
    available.
    """

    def __init__(
        self,
        weight_quant_spec: dict[str, Any],
        input_quant_spec: dict[str, Any],
    ):
        """
        Configure the scheme from the Quark quantization specs.

        :param weight_quant_spec: weight quantization spec; not read by this
            constructor.
        :param input_quant_spec: activation quantization spec; its
            ``is_dynamic`` and ``qscheme`` entries are read.
        :raises NotImplementedError: if the activation spec requests dynamic
            quantization.
        """
        # None means "use the dtype of the activation" in the AITER path.
        self.out_dtype = None

        self.weight_dtype = "mxfp4"
        self.packed_factor: Fraction = Fraction(2, 1)  # 2 FP4 values per byte
        self.weight_block_size = OCP_MX_BLOCK_SIZE

        # A missing or falsy "is_dynamic" entry selects the static scheme.
        self.is_static_input_scheme = not input_quant_spec.get("is_dynamic")
        self.input_qscheme = input_quant_spec.get("qscheme")  # "per_tensor"

        self.fp8_min, self.fp8_max = get_fp8_min_max()
        self.fp8_dtype = current_platform.fp8_dtype()

        if not self.is_static_input_scheme:
            raise NotImplementedError(
                "Dynamic FP8 activation quantization is not yet supported "
                "for W4A8. The current implementation expects static per-tensor "
                "FP8 scales stored in the checkpoint."
            )

        # The AITER kernel is only enabled on ROCm gfx950 devices.
        kernel_supported_gpu = False
        if current_platform.is_rocm():
            from vllm.platforms.rocm import on_gfx950

            kernel_supported_gpu = on_gfx950()

        self.use_aiter_kernel = (
            is_aiter_found_and_supported()
            and self.is_static_input_scheme
            and kernel_supported_gpu
        )

        if not self.use_aiter_kernel:
            logger.warning_once(
                "[W4A8 MXFP4+FP8] Aiter Triton kernel not found. Using emulation mode."
            )

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability accepted by this scheme (70)."""
        return 70

    def get_packed_dim(self, dim: int) -> int:
        """
        Return the size of ``dim`` once two FP4 values are packed per byte.

        :param dim: unpacked dimension; must be even.
        """
        assert dim % 2 == 0, f"Dimension {dim} must be even for MXFP4 packing"
        return dim // 2

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """
        Register the packed weight, block scales and input scale on ``layer``.

        :param layer: module that receives the parameters and the
            partition-size attributes.
        :param output_partition_sizes: output widths of the logical
            sub-layers stored in this layer.
        :param input_size_per_partition: number of (unpacked) input features.
        :param params_dtype: not used by this scheme.
        :param weight_loader: loader callback attached to every parameter.
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # MXFP4 WEIGHT (packed, 2 values per byte)
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                self.get_packed_dim(input_size_per_partition),
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.packed_factor,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE (E8M0 format, per block of 32), stored as raw uint8
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.weight_block_size,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (FP8 per-tensor static scale), one slot per partition
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(
                    len(output_partition_sizes),
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            # Initialize to avoid NaN
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Re-wrap the loaded tensors as plain non-trainable parameters and
        reduce the per-partition input scales to a single value.
        """
        # Ensuring weights & scales are non-trainable
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.data, requires_grad=False
        )

        if self.is_static_input_scheme:
            input_scale = layer.input_scale.data
            # For fused modules (QKV), take the max scale
            if input_scale.numel() != 1:
                input_scale = input_scale.max()

            # Copy with detach().clone() rather than torch.tensor(tensor):
            # copy-constructing from an existing tensor is discouraged by
            # PyTorch and emits a UserWarning for every processed layer.
            layer.input_scale = torch.nn.Parameter(
                input_scale.detach().clone().to(torch.float32),
                requires_grad=False,
            )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Run the forward pass with the AITER kernel when it was enabled at
        construction time, otherwise with the emulation path.

        :param layer: module holding ``weight``, ``weight_scale`` and
            ``input_scale``.
        :param x: input to the layer
        :param bias: optional bias added to the output
        """
        if self.use_aiter_kernel:
            return self._apply_aiter_kernel(layer, x, bias)
        return self._apply_emulation(layer, x, bias)

    def _apply_aiter_kernel(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Quantize ``x`` to FP8 with the static input scale and call the
        AITER ``gemm_a8wfp4`` op; the bias, if given, is added afterwards.
        """
        # NOTE(review): the first dimension of x is used as the row count,
        # which assumes x is laid out as (M, K) -- confirm against callers.
        M = x.shape[0]
        out_dtype = x.dtype if self.out_dtype is None else self.out_dtype

        input_scale = layer.input_scale
        # Clamp to the FP8 range before the cast to the FP8 dtype.
        x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype)

        # Broadcast per-tensor scale to per-row (M, 1) for Aiter kernel
        x_scales = input_scale.expand(M, 1).to(dtype=torch.float32, device=x.device)

        y = rocm_aiter_ops.gemm_a8wfp4(
            x_fp8, layer.weight, x_scales, layer.weight_scale, out_dtype
        )

        if bias is not None:
            y = y + bias

        return y

    def _apply_emulation(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Emulate the W4A8 matmul: dequantize the MXFP4 weight to ``x.dtype``,
        round-trip ``x`` through FP8 with the static input scale, then run a
        regular ``F.linear``.
        """
        from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
            dequant_mxfp4,
        )

        weight_dq = dequant_mxfp4(
            layer.weight,
            layer.weight_scale,
            x.dtype,
        )

        input_scale = layer.input_scale
        # Quantize to FP8 and immediately dequantize, so the activation
        # carries the FP8 rounding while staying in x.dtype.
        x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype)
        x_dq = (x_fp8.to(x.dtype) * input_scale).to(x.dtype)

        return F.linear(x_dq, weight_dq, bias)