Skip to content

vllm.compilation.passes.fusion.matcher_utils

MatcherCustomOp

Bases: ABC

Source code in vllm/compilation/passes/fusion/matcher_utils.py
class MatcherCustomOp(ABC):
    """Base class for matchers that target either a custom op or its native form.

    On construction the model dtype and device are read from the current vLLM
    config (each is ``None`` when the corresponding sub-config is missing), and
    ``forward`` is bound to ``forward_custom`` when ``enabled`` is true or to
    ``forward_native`` otherwise. Calling the instance dispatches to ``forward``.
    """

    def __init__(self, enabled: bool) -> None:
        vllm_config = get_current_vllm_config()
        model_cfg = vllm_config.model_config
        device_cfg = vllm_config.device_config
        self.model_dtype = model_cfg.dtype if model_cfg else None
        self.device = device_cfg.device if device_cfg else None

        self.enabled = enabled
        # Pick the implementation once here rather than branching on every call.
        if enabled:
            self.forward = self.forward_custom
        else:
            self.forward = self.forward_native

    @abstractmethod
    def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
        """Implementation expressed with the custom op."""

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        """Implementation expressed without the custom op."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Dispatch to whichever implementation ``forward`` is bound to."""
        impl = self.forward
        return impl(*args, **kwargs)

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Allocate an uninitialized tensor in the model dtype on ``self.device``."""
        return torch.empty(*args, **kwargs, dtype=self.model_dtype, device=self.device)

    def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Allocate an uninitialized ``torch.int64`` tensor on ``self.device``."""
        return torch.empty(*args, **kwargs, dtype=torch.int64, device=self.device)

    def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Allocate an uninitialized ``torch.float32`` tensor on ``self.device``."""
        return torch.empty(*args, **kwargs, dtype=torch.float32, device=self.device)

    def inputs(self) -> list[torch.Tensor]:
        """Return example input tensors for the pattern; subclasses override this."""
        raise NotImplementedError()

inputs

inputs() -> list[Tensor]

Return example input tensors for the pattern. Subclasses must override this; the base implementation raises NotImplementedError.

Source code in vllm/compilation/passes/fusion/matcher_utils.py
def inputs(self) -> list[torch.Tensor]:
    """Return example input tensors for the pattern.

    This base version has no inputs to offer; concrete matchers override it.
    """
    raise NotImplementedError()

MatcherRMSNormGated

Bases: MatcherCustomOp

Matches RMSNormGated. By default it targets norm_before_gate=True and group_size=None; both are configurable through the constructor.

Source code in vllm/compilation/passes/fusion/matcher_utils.py
class MatcherRMSNormGated(MatcherCustomOp):
    """Matches RMSNormGated (gated RMS normalization).

    By default this targets the ``norm_before_gate=True`` / ``group_size=None``
    configuration. Both are constructor parameters and are forwarded unchanged
    to the custom kernel (``forward_custom``) and to the native reference
    implementation (``forward_native``).
    """

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        norm_before_gate: bool = True,
        group_size: int | None = None,
    ) -> None:
        """
        Args:
            epsilon: Passed as ``eps`` to ``rmsnorm_fn`` and as the epsilon
                argument of ``RMSNormGated.forward_static``.
            enabled: If true, ``forward`` uses ``forward_custom``; otherwise
                ``forward_native``. ``None`` defers to ``RMSNormGated.enabled()``.
            norm_before_gate: Forwarded as ``norm_before_gate`` to both paths.
            group_size: Forwarded as ``group_size`` to both paths.
        """
        # Fall back to RMSNormGated's own enablement decision when unspecified.
        if enabled is None:
            enabled = RMSNormGated.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon
        self.norm_before_gate = norm_before_gate
        self.group_size = group_size

    def inputs(self) -> list[torch.Tensor]:
        """Return placeholder ``[x, z, weight]`` tensors for the pattern.

        ``x`` and ``z`` have shape (5, 16) and ``weight`` has shape (16,), all
        allocated via ``self.empty`` (model dtype, configured device).
        """
        # NOTE(review): the hidden size 16 is fixed here; when group_size is not
        # None it presumably has to divide 16 — confirm against rmsnorm_fn.
        x = self.empty(5, 16)
        z = self.empty(5, 16)
        weight = self.empty(16)
        return [x, z, weight]

    def forward_custom(
        self,
        x: torch.Tensor,
        z: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Custom-kernel form: ``rmsnorm_fn`` with gate ``z`` and no bias."""
        # Deferred import kept from the original; presumably avoids loading the
        # fla ops module at import time — confirm before hoisting.
        from vllm.model_executor.layers.fla.ops.layernorm_guard import (
            rmsnorm_fn,
        )

        return rmsnorm_fn(
            x,
            weight,
            bias=None,
            z=z,
            eps=self.epsilon,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )

    def forward_native(
        self,
        x: torch.Tensor,
        z: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Native form: ``RMSNormGated.forward_static`` using the model dtype."""
        return RMSNormGated.forward_static(
            x,
            z,
            weight,
            self.epsilon,
            self.model_dtype,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )