
vllm.model_executor.layers.quantization.turboquant.config

TurboQuant configuration.

TurboQuantConfig dataclass

Configuration for TurboQuant KV-cache quantization.

Applies Hadamard rotation followed by per-coordinate Lloyd-Max scalar quantization for keys, and uniform quantization for values.
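The key path described above can be sketched in pure Python. This is a minimal illustration, not the vLLM kernel: the function names, the 2-bit codebook (approximate Lloyd-Max levels for a unit-variance Gaussian), and the choice to rescale rotated coordinates by sqrt(head_dim) are all assumptions made for the example.

```python
import math

def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of 2."""
    y = list(x)
    n = len(y)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in y]

# Illustrative 2-bit codebook (assumption): approximate Lloyd-Max levels
# for a unit-variance Gaussian source.
CENTROIDS = [-1.510, -0.4528, 0.4528, 1.510]

def quantize_key(key):
    """Return (centroid indices, vector norm) for one key vector."""
    d = len(key)
    norm = math.sqrt(sum(v * v for v in key))
    unit = [v / norm for v in key] if norm > 0 else list(key)
    rotated = fwht(unit)
    # Coordinates of a rotated unit vector have variance about 1/d, so they
    # are rescaled by sqrt(d) before matching against the unit-variance codebook.
    s = math.sqrt(d)
    idx = [
        min(range(len(CENTROIDS)), key=lambda c: abs(r * s - CENTROIDS[c]))
        for r in rotated
    ]
    return idx, norm

def dequantize_key(idx, norm, norm_correction=False):
    """Look up centroids, optionally re-normalize, then undo the rotation."""
    d = len(idx)
    rec = [CENTROIDS[i] / math.sqrt(d) for i in idx]
    if norm_correction:
        rec_norm = math.sqrt(sum(v * v for v in rec))
        rec = [v / rec_norm for v in rec]
    # The orthonormal Hadamard matrix is its own inverse.
    return [v * norm for v in fwht(rec)]
```

With norm_correction=True the reconstructed vector has exactly the stored norm, because the rotation is orthonormal and the centroid vector is rescaled to unit length before it is applied.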

Historical note: the core algorithmic pattern implemented for key quantization (Hadamard rotation followed by deterministic scalar quantization and re-normalization) was originally established in DRIVE (Vargaftik et al., NeurIPS 2021) and EDEN (Vargaftik et al., ICML 2022). This formulation is also mathematically equivalent to the scalar case of the HIGGS quantization method (Malinovskii et al., "Pushing the Limits of Large Language Model Quantization via the Linearity Theorem", NAACL 2025; preprint arXiv:2411.17525), which subsequently generalized these concepts.

A first application of this approach to KV-cache compression is in "Cache Me If You Must: Adaptive Key-Value Quantization for Large Language Models" (Shutova et al., ICML 2025; preprint arXiv:2501.19392). All of these foundational and application references pre-date the TurboQuant paper (Zandieh et al., ICLR 2026).

QJL is intentionally omitted: community consensus (5+ independent groups) found it hurts attention quality by amplifying variance through softmax.

Named presets (use via --kv-cache-dtype):
  • turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL
  • turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
  • turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
  • turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| head_dim | int | Attention head dimension (e.g. 64, 96, 128). | 128 |
| key_quant_bits | int | Bits for key quantization. 8 = FP8 keys (no rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys. | 3 |
| value_quant_bits | int | Bits per value dimension for uniform quantization. 3 = 8 levels, 4 = 16 levels (default). | 4 |
| norm_correction | bool | Re-normalize centroid vectors to unit norm before inverse rotation during dequant. Fixes quantization-induced norm distortion, improving PPL by ~0.8% at 4-bit. | False |

Source code in vllm/model_executor/layers/quantization/turboquant/config.py
@dataclass
class TurboQuantConfig:
    """Configuration for TurboQuant KV-cache quantization.

    Applies Hadamard rotation followed by per-coordinate Lloyd-Max scalar
    quantization for keys, and uniform quantization for values.

    Historical note: the core algorithmic pattern implemented for key
    quantization (Hadamard rotation followed by deterministic scalar
    quantization and re-normalization) was originally established in DRIVE
    (Vargaftik et al., NeurIPS 2021) and EDEN (Vargaftik et al., ICML
    2022). This formulation is also mathematically equivalent to the
    scalar case of the HIGGS quantization method (Malinovskii et al.,
    "Pushing the Limits of Large Language Model Quantization via the
    Linearity Theorem", NAACL 2025; preprint arXiv:2411.17525), which
    subsequently generalized these concepts.

    A first application of this approach to KV-cache compression is in
    "Cache Me If You Must: Adaptive Key-Value Quantization for Large
    Language Models" (Shutova et al., ICML 2025; preprint
    arXiv:2501.19392). All of these foundational and application
    references pre-date the TurboQuant paper (Zandieh et al., ICLR 2026).

    QJL is intentionally omitted: community consensus (5+ independent
    groups) found it hurts attention quality by amplifying variance
    through softmax.

    Named presets (use via --kv-cache-dtype):
        turboquant_k8v4:   FP8 keys + 4-bit values, 2.6x, +1.17% PPL
        turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
        turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
        turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%

    Args:
        head_dim: Attention head dimension (e.g. 64, 96, 128).
        key_quant_bits: Bits for key quantization. 8 = FP8 keys (no
            rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys.
        value_quant_bits: Bits per value dimension for uniform quantization.
            3 = 8 levels, 4 = 16 levels (default).
        norm_correction: Re-normalize centroid vectors to unit norm before
            inverse rotation during dequant. Fixes quantization-induced norm
            distortion, improving PPL by ~0.8% at 4-bit.
    """

    head_dim: int = 128
    key_quant_bits: int = 3  # 3-4 = MSE keys, 8 = FP8 keys
    value_quant_bits: int = 4  # 3-4 = uniform quantized values
    seed: int = 42  # kept for backward compatibility; no longer used internally
    norm_correction: bool = False

    @property
    def key_fp8(self) -> bool:
        """Whether keys are stored as FP8 — no rotation/quantization needed."""
        return self.key_quant_bits == 8

    @property
    def mse_bits(self) -> int:
        """MSE quantizer bit-width (determines centroid count: 2^mse_bits).

        For MSE key modes, equals key_quant_bits.
        For FP8 key mode, falls back to value_quant_bits (centroids are still
        needed for continuation-prefill dequant and decode kernel params).
        """
        if self.key_fp8:
            return self.value_quant_bits
        return self.key_quant_bits

    @property
    def key_mse_bits(self) -> int:
        """MSE bits actually used for key quantization (0 if FP8 keys)."""
        if self.key_fp8:
            return 0
        return self.key_quant_bits

    @property
    def centroid_bits(self) -> int:
        """Bits for centroid generation — always non-zero."""
        return self.mse_bits

    @property
    def n_centroids(self) -> int:
        return 2**self.mse_bits

    @property
    def key_packed_size(self) -> int:
        """Packed bytes for a single KEY vector.

        FP8 mode (key_quant_bits=8):
          head_dim bytes (1 byte per element, no overhead).

        TQ mode:
          - MSE indices: ceil(head_dim * key_mse_bits / 8) bytes
          - vec_norm:     2 bytes (float16)
        """
        if self.key_fp8:
            return self.head_dim  # 1 byte per element
        mse_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8)
        norm_bytes = 2  # vec_norm fp16
        return mse_bytes + norm_bytes

    @property
    def effective_value_quant_bits(self) -> int:
        """Actual bits used for value storage."""
        return self.value_quant_bits

    @property
    def value_packed_size(self) -> int:
        """Packed bytes for a single VALUE vector.

        Uniform quantization: ceil(head_dim * bits / 8) + 4 bytes (scale + zero fp16).
        """
        data_bytes = math.ceil(self.head_dim * self.value_quant_bits / 8)
        return data_bytes + 4  # +2 scale(fp16) +2 zero(fp16)

    @property
    def slot_size(self) -> int:
        """Total packed bytes per head per position (key + value combined).

        Layout: [key_packed | value_packed]
        """
        return self.key_packed_size + self.value_packed_size

    @property
    def slot_size_aligned(self) -> int:
        """Slot size rounded up to next even number.

        An even size is required so that effective_head_size =
        slot_size_aligned // 2 divides exactly (no byte is lost to floor
        division).
        """
        s = self.slot_size
        return s + (s % 2)  # round up to even

    @staticmethod
    def get_boundary_skip_layers(
        model_config: ModelConfig,
        n: int = 2,
    ) -> list[str]:
        """Layer indices to skip TQ compression (boundary protection).

        For hybrid models (attention + Mamba/linear-attention), boundary
        protection is disabled — hybrids typically have only 8-12
        full-attention layers and a hard n=2 on each side would cover
        ~40 % of them.  The dense GSM8K baselines that motivate n=2
        don't apply to hybrids.

        For dense models, skips first N and last N attention layers.
        Empirically required for aggressive presets (k3v4_nc, 3bit_nc)
        — without it GSM8K drops ~30 points on Qwen3-4B.
        """
        if model_config.is_hybrid:
            attn_indices = _get_full_attention_layer_indices(model_config)
            if not attn_indices:
                raise NotImplementedError(
                    "TurboQuant KV cache requires identifiable "
                    "full-attention layers, but none were found in "
                    "the hybrid model config."
                )
            logger.info("TQ hybrid: full-attention layers %s", attn_indices)
            return []

        num_layers = model_config.hf_text_config.num_hidden_layers
        if n <= 0 or num_layers <= 0:
            return []
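        # Clamp n to at most half the layers per side, so the first-n and
        # last-n ranges cannot overlap.
        n = min(n, num_layers // 2)
        first = list(range(n))
        last = list(range(num_layers - n, num_layers))
        # Defensive dedupe: after the clamp above the two ranges are disjoint.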
        indices = sorted(set(first + last))
        return [str(i) for i in indices]

    @staticmethod
    def from_cache_dtype(cache_dtype: str, head_dim: int) -> TurboQuantConfig:
        """Create config from a named preset.

        Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
        """
        if cache_dtype not in TQ_PRESETS:
            valid = ", ".join(TQ_PRESETS.keys())
            raise ValueError(
                f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
                f"Valid presets: {valid}"
            )
        preset = TQ_PRESETS[cache_dtype]
        return TurboQuantConfig(
            head_dim=head_dim,
            key_quant_bits=preset["key_quant_bits"],
            value_quant_bits=preset["value_quant_bits"],
            norm_correction=preset["norm_correction"],
        )

centroid_bits property

centroid_bits: int

Bits for centroid generation — always non-zero.

effective_value_quant_bits property

effective_value_quant_bits: int

Actual bits used for value storage.

key_fp8 property

key_fp8: bool

Whether keys are stored as FP8 — no rotation/quantization needed.

key_mse_bits property

key_mse_bits: int

MSE bits actually used for key quantization (0 if FP8 keys).

key_packed_size property

key_packed_size: int

Packed bytes for a single KEY vector.

FP8 mode (key_quant_bits=8): head_dim bytes (1 byte per element, no overhead).

TQ mode:
  • MSE indices: ceil(head_dim * key_mse_bits / 8) bytes
  • vec_norm: 2 bytes (float16)
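As a worked example of the formula above, the standalone helper below recomputes the packed key size outside the dataclass. The function is a hypothetical mirror of the property written for illustration, and head_dim=128 is simply the documented default.

```python
import math

def key_packed_size(head_dim: int, key_quant_bits: int) -> int:
    """Packed bytes for one key vector, following the documented formula."""
    if key_quant_bits == 8:
        return head_dim  # FP8: 1 byte per element, no norm stored
    mse_bytes = math.ceil(head_dim * key_quant_bits / 8)
    return mse_bytes + 2  # plus a float16 vector norm

for bits in (3, 4, 8):
    print(bits, key_packed_size(128, bits))
# 3 50
# 4 66
# 8 128
```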

mse_bits property

mse_bits: int

MSE quantizer bit-width (determines centroid count: 2^mse_bits).

For MSE key modes, equals key_quant_bits. For FP8 key mode, falls back to value_quant_bits (centroids are still needed for continuation-prefill dequant and decode kernel params).
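The fallback rule can be illustrated with two small free functions. They are hypothetical stand-ins for the mse_bits and key_mse_bits properties, written only to show how the centroid count follows from the configured bit-widths.

```python
def mse_bits(key_quant_bits: int, value_quant_bits: int) -> int:
    """Bit-width that determines the centroid table size."""
    if key_quant_bits == 8:  # FP8 keys: fall back to the value bit-width
        return value_quant_bits
    return key_quant_bits

def key_mse_bits(key_quant_bits: int) -> int:
    """MSE bits actually spent on keys (0 when keys are FP8)."""
    return 0 if key_quant_bits == 8 else key_quant_bits

# FP8 keys with 4-bit values still yield a 16-entry centroid table.
print(mse_bits(8, 4), 2 ** mse_bits(8, 4), key_mse_bits(8))  # 4 16 0
# 3-bit MSE keys with 4-bit values yield an 8-entry table.
print(mse_bits(3, 4), 2 ** mse_bits(3, 4), key_mse_bits(3))  # 3 8 3
```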

slot_size property

slot_size: int

Total packed bytes per head per position (key + value combined).

Layout: [key_packed | value_packed]

slot_size_aligned property

slot_size_aligned: int

Slot size rounded up to next even number.

An even size is required so that effective_head_size = slot_size_aligned // 2 divides exactly (no byte is lost to floor division).
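The sketch below combines the documented key and value size formulas into a slot size and applies the even-rounding step. The helper names are hypothetical, and head_dim=40 is a contrived value chosen only because it produces an odd slot size.

```python
import math

def slot_size(head_dim: int, key_quant_bits: int, value_quant_bits: int) -> int:
    """Key bytes plus value bytes for one head at one position."""
    if key_quant_bits == 8:
        key_bytes = head_dim
    else:
        key_bytes = math.ceil(head_dim * key_quant_bits / 8) + 2
    value_bytes = math.ceil(head_dim * value_quant_bits / 8) + 4
    return key_bytes + value_bytes

def slot_size_aligned(size: int) -> int:
    """Round up to the next even number."""
    return size + (size % 2)

print(slot_size(128, 3, 4), slot_size_aligned(slot_size(128, 3, 4)))  # 118 118
print(slot_size(40, 3, 4), slot_size_aligned(slot_size(40, 3, 4)))    # 41 42
```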

value_packed_size property

value_packed_size: int

Packed bytes for a single VALUE vector.

Uniform quantization: ceil(head_dim * bits / 8) + 4 bytes (scale + zero fp16).
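To make the "+ 4 bytes (scale + zero)" term concrete, here is a minimal sketch of asymmetric uniform quantization with one scale and one zero offset per vector. It assumes the zero term is the vector minimum, which may differ from the convention used by the actual kernels, and all function names are hypothetical.

```python
import math

def value_packed_size(head_dim: int, value_quant_bits: int) -> int:
    """Packed bytes for one value vector, following the documented formula."""
    return math.ceil(head_dim * value_quant_bits / 8) + 4  # fp16 scale + fp16 zero

def quantize_uniform(values, bits):
    """Map each value to an integer code in [0, 2**bits - 1]."""
    levels = 2 ** bits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]
    return codes, scale, lo  # lo plays the role of the zero offset (assumption)

def dequantize_uniform(codes, scale, zero):
    """Invert quantize_uniform up to rounding error."""
    return [c * scale + zero for c in codes]

print(value_packed_size(128, 4))  # 68
print(value_packed_size(128, 3))  # 52
```

The reconstruction error per coordinate is bounded by half the step size (scale / 2), which is why dropping from 4 bits (16 levels) to 3 bits (8 levels) roughly doubles the worst-case error for the same value range.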

from_cache_dtype staticmethod

from_cache_dtype(
    cache_dtype: str, head_dim: int
) -> TurboQuantConfig

Create config from a named preset.

Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
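The lookup can be sketched without importing vLLM. The PRESETS table below is a hypothetical mirror of TQ_PRESETS whose values are inferred from the preset list in the class docstring (in particular, norm_correction=False for turboquant_k8v4 is an assumption), and SketchConfig is a stand-in for TurboQuantConfig.

```python
from dataclasses import dataclass

# Hypothetical mirror of TQ_PRESETS, inferred from the documented preset list.
PRESETS = {
    "turboquant_k8v4": dict(key_quant_bits=8, value_quant_bits=4, norm_correction=False),
    "turboquant_4bit_nc": dict(key_quant_bits=4, value_quant_bits=4, norm_correction=True),
    "turboquant_k3v4_nc": dict(key_quant_bits=3, value_quant_bits=4, norm_correction=True),
    "turboquant_3bit_nc": dict(key_quant_bits=3, value_quant_bits=3, norm_correction=True),
}

@dataclass
class SketchConfig:
    """Stand-in for TurboQuantConfig with the same four documented fields."""
    head_dim: int = 128
    key_quant_bits: int = 3
    value_quant_bits: int = 4
    norm_correction: bool = False

def from_cache_dtype(cache_dtype: str, head_dim: int) -> SketchConfig:
    """Build a config from a preset name, rejecting unknown names."""
    if cache_dtype not in PRESETS:
        valid = ", ".join(PRESETS)
        raise ValueError(
            f"Unknown TurboQuant cache dtype: {cache_dtype!r}. Valid presets: {valid}"
        )
    return SketchConfig(head_dim=head_dim, **PRESETS[cache_dtype])

cfg = from_cache_dtype("turboquant_4bit_nc", head_dim=128)
print(cfg.key_quant_bits, cfg.value_quant_bits, cfg.norm_correction)  # 4 4 True
```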

Source code in vllm/model_executor/layers/quantization/turboquant/config.py
@staticmethod
def from_cache_dtype(cache_dtype: str, head_dim: int) -> TurboQuantConfig:
    """Create config from a named preset.

    Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
    """
    if cache_dtype not in TQ_PRESETS:
        valid = ", ".join(TQ_PRESETS.keys())
        raise ValueError(
            f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
            f"Valid presets: {valid}"
        )
    preset = TQ_PRESETS[cache_dtype]
    return TurboQuantConfig(
        head_dim=head_dim,
        key_quant_bits=preset["key_quant_bits"],
        value_quant_bits=preset["value_quant_bits"],
        norm_correction=preset["norm_correction"],
    )

get_boundary_skip_layers staticmethod

get_boundary_skip_layers(
    model_config: ModelConfig, n: int = 2
) -> list[str]

Layer indices for which TQ compression is skipped (boundary protection).

For hybrid models (attention + Mamba/linear-attention), boundary protection is disabled: hybrids typically have only 8-12 full-attention layers, and a hard n=2 on each side would cover ~40% of them. The dense-model GSM8K baselines that motivate n=2 do not apply to hybrids.

For dense models, skips the first n and last n attention layers. This is empirically required for the aggressive presets (k3v4_nc, 3bit_nc): without it, GSM8K drops ~30 points on Qwen3-4B.
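The dense-model branch reduces to simple index arithmetic, which the standalone function below reproduces. It is a hypothetical helper that takes the layer count directly instead of a ModelConfig, and the layer counts in the usage lines are arbitrary examples.

```python
def boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
    """First-n and last-n layer indices, as strings, for a dense model."""
    if n <= 0 or num_layers <= 0:
        return []
    n = min(n, num_layers // 2)  # at most half the layers per side
    first = range(n)
    last = range(num_layers - n, num_layers)
    return [str(i) for i in sorted(set(first) | set(last))]

print(boundary_skip_layers(36))  # ['0', '1', '34', '35']
print(boundary_skip_layers(3))   # ['0', '2']
print(boundary_skip_layers(1))   # []
```

Note that the clamp limits each side to half the layers, so a very shallow model (for example 4 layers with n=2) ends up with every layer skipped.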

Source code in vllm/model_executor/layers/quantization/turboquant/config.py
@staticmethod
def get_boundary_skip_layers(
    model_config: ModelConfig,
    n: int = 2,
) -> list[str]:
    """Layer indices to skip TQ compression (boundary protection).

    For hybrid models (attention + Mamba/linear-attention), boundary
    protection is disabled — hybrids typically have only 8-12
    full-attention layers and a hard n=2 on each side would cover
    ~40 % of them.  The dense GSM8K baselines that motivate n=2
    don't apply to hybrids.

    For dense models, skips first N and last N attention layers.
    Empirically required for aggressive presets (k3v4_nc, 3bit_nc)
    — without it GSM8K drops ~30 points on Qwen3-4B.
    """
    if model_config.is_hybrid:
        attn_indices = _get_full_attention_layer_indices(model_config)
        if not attn_indices:
            raise NotImplementedError(
                "TurboQuant KV cache requires identifiable "
                "full-attention layers, but none were found in "
                "the hybrid model config."
            )
        logger.info("TQ hybrid: full-attention layers %s", attn_indices)
        return []

    num_layers = model_config.hf_text_config.num_hidden_layers
    if n <= 0 or num_layers <= 0:
        return []
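    # Clamp n to at most half the layers per side, so the first-n and
    # last-n ranges cannot overlap.
    n = min(n, num_layers // 2)
    first = list(range(n))
    last = list(range(num_layers - n, num_layers))
    # Defensive dedupe: after the clamp above the two ranges are disjoint.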
    indices = sorted(set(first + last))
    return [str(i) for i in indices]

_get_full_attention_layer_indices

_get_full_attention_layer_indices(
    model_config: ModelConfig,
) -> list[int]

Global indices of full-attention layers in a hybrid model.

Covers the conventions used across vLLM: layer_types (Qwen3.5/Next), layers_block_type (Jamba/Zamba2), attn_type_list (Minimax).
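The three conventions can be exercised with plain namespace objects in place of real Hugging Face configs. The function below takes the two config objects directly rather than a ModelConfig, and the layer-type strings other than the ones matched by the documented code (for example "linear_attention" and "mamba") are illustrative placeholders.

```python
from types import SimpleNamespace

def full_attention_layer_indices(text_cfg, hf_cfg) -> list[int]:
    """Indices of full-attention layers under three config conventions."""
    layer_types = getattr(text_cfg, "layer_types", None)
    if layer_types is not None:
        return [i for i, t in enumerate(layer_types)
                if t in ("full_attention", "attention")]

    layers_block_type = getattr(text_cfg, "layers_block_type", None)
    if layers_block_type is not None:
        return [i for i, t in enumerate(layers_block_type)
                if t in ("attention", "hybrid")]

    attn_type_list = getattr(hf_cfg, "attn_type_list", None)
    if attn_type_list is not None:
        return [i for i, t in enumerate(attn_type_list) if t == 1]

    return []

empty = SimpleNamespace()
by_layer_types = SimpleNamespace(
    layer_types=["linear_attention"] * 3 + ["full_attention"]
)
by_block_type = SimpleNamespace(layers_block_type=["mamba", "attention", "mamba"])
by_attn_list = SimpleNamespace(attn_type_list=[0, 1, 0, 1])

print(full_attention_layer_indices(by_layer_types, empty))  # [3]
print(full_attention_layer_indices(by_block_type, empty))   # [1]
print(full_attention_layer_indices(empty, by_attn_list))    # [1, 3]
```

The checks run in a fixed order, so a config that defines more than one of these attributes is resolved by the first one present.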

Source code in vllm/model_executor/layers/quantization/turboquant/config.py
def _get_full_attention_layer_indices(model_config: ModelConfig) -> list[int]:
    """Global indices of full-attention layers in a hybrid model.

    Covers the conventions used across vLLM: ``layer_types`` (Qwen3.5/Next),
    ``layers_block_type`` (Jamba/Zamba2), ``attn_type_list`` (Minimax).
    """
    text_cfg = model_config.hf_text_config
    hf_cfg = model_config.hf_config

    layer_types = getattr(text_cfg, "layer_types", None)
    if layer_types is not None:
        return [
            i for i, t in enumerate(layer_types) if t in ("full_attention", "attention")
        ]

    layers_block_type = getattr(text_cfg, "layers_block_type", None)
    if layers_block_type is not None:
        return [
            i for i, t in enumerate(layers_block_type) if t in ("attention", "hybrid")
        ]

    attn_type_list = getattr(hf_cfg, "attn_type_list", None)
    if attn_type_list is not None:
        return [i for i, t in enumerate(attn_type_list) if t == 1]

    return []