class MatcherRMSNormGated(MatcherCustomOp):
    """Matcher for RMSNormGated in both its custom-op and native forms.

    ``norm_before_gate`` and ``group_size`` are stored on the instance and
    forwarded unchanged to both implementations. Their defaults are
    ``norm_before_gate=True`` and ``group_size=None``.
    """

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        norm_before_gate: bool = True,
        group_size: int | None = None,
    ) -> None:
        """Store the normalization settings and resolve the enabled flag.

        Args:
            epsilon: Epsilon value forwarded to both implementations.
            enabled: Explicit value handed to the base-class constructor.
                When ``None``, ``RMSNormGated.enabled()`` is used instead.
            norm_before_gate: Forwarded as ``norm_before_gate``.
            group_size: Forwarded as ``group_size``.
        """
        # Only consult RMSNormGated when the caller did not decide explicitly.
        resolved = RMSNormGated.enabled() if enabled is None else enabled
        super().__init__(resolved)

        self.epsilon = epsilon
        self.norm_before_gate = norm_before_gate
        self.group_size = group_size

    def inputs(self) -> list[torch.Tensor]:
        """Return example tensors in the order ``[x, z, weight]``.

        ``x`` and ``z`` are each built with ``self.empty(5, 16)`` and
        ``weight`` with ``self.empty(16)``.
        """
        rows, width = 5, 16
        return [
            self.empty(rows, width),  # x
            self.empty(rows, width),  # z
            self.empty(width),  # weight
        ]

    def forward_custom(
        self,
        x: torch.Tensor,
        z: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Run the custom path by calling ``rmsnorm_fn`` with ``bias=None``.

        Args:
            x: First positional argument to ``rmsnorm_fn``.
            z: Passed to ``rmsnorm_fn`` as ``z``.
            weight: Second positional argument to ``rmsnorm_fn``.

        Returns:
            Whatever ``rmsnorm_fn`` returns for these arguments.
        """
        # Imported at call time rather than at module import time.
        from vllm.model_executor.layers.fla.ops.layernorm_guard import (
            rmsnorm_fn,
        )

        eps = self.epsilon
        group_size = self.group_size
        norm_before_gate = self.norm_before_gate
        return rmsnorm_fn(
            x,
            weight,
            bias=None,
            z=z,
            eps=eps,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
        )

    def forward_native(
        self,
        x: torch.Tensor,
        z: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Run the native path by calling ``RMSNormGated.forward_static``.

        Args:
            x: First positional argument to ``forward_static``.
            z: Second positional argument to ``forward_static``.
            weight: Third positional argument to ``forward_static``.

        Returns:
            Whatever ``RMSNormGated.forward_static`` returns when also given
            this matcher's epsilon, ``self.model_dtype`` and the stored
            ``group_size`` / ``norm_before_gate`` settings.
        """
        eps = self.epsilon
        dtype = self.model_dtype
        return RMSNormGated.forward_static(
            x,
            z,
            weight,
            eps,
            dtype,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )