@MULTIMODAL_REGISTRY.register_processor(
Step3VLMultiModalProcessor,
info=Step3VLProcessingInfo,
dummy_inputs=Step3VLDummyInputsBuilder,
)
class Step3VLForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsEncoderCudaGraph
):
"""Step3-VL vision-language model.

Images are encoded by ``Step3VisionTransformer``, downsampled by two conv
layers, projected to ``config.hidden_size`` by a linear layer, and merged
into the token embeddings of the wrapped ``language_model``. The class also
implements the ``SupportsEncoderCudaGraph`` hooks used to capture and replay
the vision encoder with CUDA graphs.
"""
# Checkpoint keys starting with "model." / "lm_head." are re-rooted under the
# wrapped ``language_model`` submodule when weights are loaded.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.": "language_model.model.",
"lm_head.": "language_model.lm_head.",
}
)
# NOTE(review): presumably advertises support for the "data" encoder-TP mode
# (cf. ``mm_encoder_tp_mode == "data"`` in ``__init__``); confirm.
supports_encoder_tp_data = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<im_patch>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
"""Build the vision tower, the conv/linear projector and the language model.

Args:
    vllm_config: engine config; ``model_config.hf_config`` supplies
        ``vision_config``, ``text_config``, ``image_token_id`` and the
        projector settings used below.
    prefix: parameter-name prefix; ``vision_model`` / ``language_model``
        are appended to it via ``maybe_prefix``.
"""
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.model_config = vllm_config.model_config
self.multimodal_config = multimodal_config
# True when the multimodal encoder TP mode is "data".
# NOTE(review): not read anywhere else in the visible part of this class.
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
# NOTE: This behavior is consistent with the previous OOV handling,
# but does not currently handle the start/stop toks around the
# image features (<patch_start> <patch_end> <im_start> <im_end>)
# See: https://huggingface.co/stepfun-ai/step3/blob/main/processing_step3v.py#L323
#
# If this becomes an issue or we refactor to handle this using the
# processor info in the future, it would probably be best to handle
# those too.
self.configure_mm_token_handling(
self.config.text_config.vocab_size,
[self.config.image_token_id],
)
with self._mark_tower_model(vllm_config, "image"):
# NOTE(review): the second positional argument is None — presumably "no
# quantization config" for the vision tower; confirm against its signature.
self.vision_model = Step3VisionTransformer(
config.vision_config,
None,
prefix=maybe_prefix(prefix, "vision_model"),
)
# First downsampler: kernel 2, stride ``understand_projector_stride``,
# vision hidden_size -> output_hidden_size channels.
self.vit_downsampler = Conv2dLayer(
config.vision_config.hidden_size,
config.vision_config.output_hidden_size,
kernel_size=2,
stride=config.understand_projector_stride,
)
# Second downsampler: kernel 3, stride 2, padding 1; doubles the channels.
self.vit_downsampler2 = Conv2dLayer(
config.vision_config.output_hidden_size,
config.vision_config.output_hidden_size * 2,
kernel_size=3,
stride=2,
padding=1,
)
# Linear projection of the downsampled features to ``config.hidden_size``.
self.vit_large_projector = nn.Linear(
config.vision_config.output_hidden_size * 2,
config.hidden_size,
bias=config.projector_bias,
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
# Pipeline-parallel (SupportsPP) intermediate-tensor allocation is
# delegated to the wrapped language model.
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
@property
def device(self):
    """Device of the first registered parameter.

    ``next`` raises ``StopIteration`` if the module has no parameters.
    """
    first_param = next(self.parameters())
    return first_param.device
@property
def dtype(self):
    """Dtype of the first registered parameter.

    ``next`` raises ``StopIteration`` if the module has no parameters.
    """
    first_param = next(self.parameters())
    return first_param.dtype
@staticmethod
def _compute_spatial_tokens(size, patch_size, stride):
    """Return the projected token count for a square ``size`` x ``size`` input.

    The ViT grid has ``size // patch_size`` cells per side. That grid is then
    reduced twice, matching ``vit_downsampler`` (kernel 2, stride ``stride``,
    no explicit padding) and ``vit_downsampler2`` (kernel 3, stride 2,
    padding 1). The result is the square of the final side length.
    """

    def conv_out(length, kernel, step, pad=0):
        # Conv2d output length for one spatial dimension (dilation 1).
        return (length + 2 * pad - kernel) // step + 1

    side = size // patch_size
    side = int(math.sqrt(side * side))
    after_first = conv_out(side, 2, stride)
    after_second = conv_out(after_first, 3, 2, pad=1)
    return after_second ** 2
def _parse_and_validate_image_input(
    self, **kwargs: object
) -> Step3VLImageInputs | None:
    """Extract the image inputs from ``kwargs``.

    Returns:
        ``None`` when neither ``pixel_values`` nor ``image_embeds`` is given.
        A ``pixel_values`` input when both ``pixel_values`` and
        ``patch_pixel_values`` are present; both are cast to ``self.dtype``.
        Otherwise an ``image_embeds`` input cast to ``self.dtype``.

    NOTE(review): the final AssertionError is reachable. That happens when
    ``pixel_values`` arrives without ``patch_pixel_values`` and no
    ``image_embeds`` is given.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    patch_pixel_values = kwargs.pop("patch_pixel_values", None)
    num_patches = kwargs.pop("num_patches", None)
    image_embeds = kwargs.pop("image_embeds", None)

    has_pixels = pixel_values is not None
    has_embeds = image_embeds is not None
    if not (has_pixels or has_embeds):
        return None

    if has_pixels and patch_pixel_values is not None:
        target_dtype = self.dtype
        return Step3VLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values.to(target_dtype),
            patch_pixel_values=patch_pixel_values.to(target_dtype),
            num_patches=num_patches,
        )

    if has_embeds:
        return Step3VLImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds.to(self.dtype),
        )

    raise AssertionError("This line should be unreachable.")
def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
    """Fold ViT tokens back into a square grid, downsample, and project.

    ``image_features`` is read as ``(batch, tokens, channels)``. The token
    count is expected to be a perfect square, so that the tokens form a
    ``side`` x ``side`` grid with ``side = int(sqrt(tokens))``.

    Returns:
        The ``vit_large_projector`` output over the downsampled tokens,
        laid out as ``(batch, downsampled_tokens, projected_dim)``.
    """
    batch, num_tokens = image_features.shape[:2]
    side = int(sqrt(num_tokens))
    # (B, T, C) -> (B, C, side, side): channels-first layout for the convs.
    grid = image_features.permute(0, 2, 1).view(batch, -1, side, side)
    grid = self.vit_downsampler2(self.vit_downsampler(grid))
    channels = grid.size(1)
    # (B, C', h, w) -> (B, h * w, C'): back to a token sequence.
    tokens = grid.view(batch, channels, -1).permute(0, 2, 1)
    return self.vit_large_projector(tokens)
def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
    """Run the vision tower and drop its first four tokens along dim 1."""
    # NOTE(review): the four leading tokens are presumably non-spatial
    # (class/register-style) tokens. Dropping them leaves a square patch grid
    # for _process_image_features; confirm against Step3VisionTransformer.
    num_prefix_tokens = 4
    hidden = self.vision_model(input_tensor)
    return hidden[:, num_prefix_tokens:]
def _process_image_input(
    self, image_input: "Step3VLImageInputs"
) -> list[torch.Tensor]:
    """Turn parsed image inputs into one flattened embedding tensor per image.

    For ``image_embeds`` inputs, each ``data[i]`` is flattened to
    ``(-1, hidden)``.

    For ``pixel_values`` inputs, the global views and the crop patches are
    encoded separately. The result is re-assembled per image as
    ``[tokens of its k crop patches..., tokens of its global view]``, where
    ``k`` is that image's entry in ``num_patches``. Crop patches are consumed
    from ``patch_pixel_values`` in order.

    Returns:
        A list with one 2-D tensor per image.

    Raises:
        ValueError: if ``num_patches`` requests crop patches but
            ``patch_pixel_values`` is empty.
    """
    if image_input["type"] == "image_embeds":
        embeds = image_input["data"]
        return [
            embeds[i].view(-1, embeds.shape[-1]) for i in range(embeds.shape[0])
        ]

    image_features = self._get_vision_model_output(image_input["pixel_values"])
    patch_image_features = (
        self._get_vision_model_output(image_input["patch_pixel_values"])
        if len(image_input["patch_pixel_values"]) > 0
        else None
    )
    num_patches = image_input["num_patches"]
    image_features = self._process_image_features(image_features)
    patch_image_features = (
        self._process_image_features(patch_image_features)
        if patch_image_features is not None
        else None
    )

    merged_image_features: list[torch.Tensor] = []
    cur_patch_idx = 0
    for i, num_patch in enumerate(num_patches):
        image_tokens = image_features[i].view(-1, image_features.shape[-1])
        if num_patch > 0:
            if patch_image_features is None:
                # Fail with a clear message instead of subscripting None.
                raise ValueError(
                    f"Image {i} expects {int(num_patch)} crop patches but "
                    "patch_pixel_values is empty."
                )
            patch_slice = patch_image_features[
                cur_patch_idx : cur_patch_idx + num_patch
            ]
            merged_image_features.append(
                torch.cat(
                    [patch_slice.view(-1, patch_slice.shape[-1]), image_tokens]
                )
            )
        else:
            merged_image_features.append(image_tokens)
        cur_patch_idx += num_patch
    return merged_image_features
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
    """Encode the image inputs in ``kwargs``, or return ``[]`` if there are none."""
    parsed = self._parse_and_validate_image_input(**kwargs)
    return [] if parsed is None else self._process_image_input(parsed)
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
    """Embed ``input_ids``, merging multimodal embeddings when both are given.

    The parent implementation is called with the multimodal arguments only
    when ``multimodal_embeddings`` and ``is_multimodal`` are both set.
    Otherwise it gets just ``input_ids``, which keeps each overload of the
    parent signature satisfied for the type checker.
    """
    if multimodal_embeddings is not None and is_multimodal is not None:
        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )
    return super().embed_input_ids(input_ids)
def get_encoder_cudagraph_config(self):
    """Describe how the vision encoder is captured as a CUDA graph.

    ``pixel_values`` (global views) is the graph's image input key, and
    ``patch_pixel_values`` (crop patches) is an extra buffer. The output
    width is ``config.hidden_size``.
    """
    from vllm.v1.worker.encoder_cudagraph_defs import (
        EncoderCudaGraphConfig,
    )

    image_modality = "image"
    return EncoderCudaGraphConfig(
        modalities=[image_modality],
        input_key_by_modality={image_modality: "pixel_values"},
        buffer_keys=["patch_pixel_values"],
        out_hidden_size=self.config.hidden_size,
    )
def get_input_modality(
    self,
    mm_kwargs: dict[str, Any],
) -> str:
    """Return the modality of ``mm_kwargs``; this model only takes images."""
    del mm_kwargs  # every input is an image, so nothing needs inspecting
    return "image"
def get_max_frames_per_video(
    self,
) -> int:
    """Return 0: only images are handled, so no video frames are budgeted."""
    no_video_frames = 0
    return no_video_frames
def get_encoder_cudagraph_budget_range(
    self,
    vllm_config: "VllmConfig",
) -> tuple[int, int]:
    """Return the ``(min, max)`` encoder-output token budget for graph capture.

    The minimum is the output of one global view with no crop patches. The
    maximum is capped by both the scheduler's ``max_num_batched_tokens`` and
    the model's ``max_model_len``.
    """
    vision_cfg = self.config.vision_config
    lower = self._compute_spatial_tokens(
        vision_cfg.image_size,
        vision_cfg.patch_size,
        self.config.understand_projector_stride,
    )
    upper = min(
        self.model_config.max_model_len,
        vllm_config.scheduler_config.max_num_batched_tokens,
    )
    return lower, upper
def get_encoder_cudagraph_num_items(
    self,
    mm_kwargs: dict[str, Any],
) -> int:
    """Return how many images ``mm_kwargs`` holds (0 without ``pixel_values``)."""
    if "pixel_values" not in mm_kwargs:
        return 0
    return len(mm_kwargs["pixel_values"])
def get_encoder_cudagraph_per_item_output_tokens(
    self,
    mm_kwargs: dict[str, Any],
) -> list[int]:
    """Return the number of encoder output tokens for each image.

    Each item produces the tokens of its global view, plus ``num_patches[i]``
    times the tokens of one 504x504 crop patch.

    Raises:
        ValueError: if ``mm_kwargs`` has no ``num_patches``.
    """
    num_patches = mm_kwargs.get("num_patches")
    if num_patches is None:
        raise ValueError("mm_kwargs must contain 'num_patches'")
    vision_config = self.config.vision_config
    stride = self.config.understand_projector_stride
    img_output_tokens = self._compute_spatial_tokens(
        vision_config.image_size, vision_config.patch_size, stride
    )
    # 504 is the pixel size of each crop patch in patch_pixel_values.
    patch_output_tokens = self._compute_spatial_tokens(
        504, vision_config.patch_size, stride
    )
    # num_patches is typically a tensor. Without int(), the entries would be
    # 0-d tensors, breaking the declared list[int] contract.
    return [img_output_tokens + int(n) * patch_output_tokens for n in num_patches]
def get_encoder_cudagraph_per_item_input_sizes(
    self,
    mm_kwargs: dict[str, Any],
) -> list[int]:
    """Return the number of ViT patch tokens fed to the encoder for each image.

    Each item contributes ``(image_size // patch_size) ** 2`` tokens for its
    global view. It adds ``(504 // patch_size) ** 2`` tokens for each of its
    ``num_patches[i]`` crop patches.

    Raises:
        ValueError: if ``mm_kwargs`` has no ``num_patches``.
    """
    num_patches = mm_kwargs.get("num_patches")
    if num_patches is None:
        raise ValueError("mm_kwargs must contain 'num_patches'")
    patch_size = self.config.vision_config.patch_size
    image_tokens = (self.config.vision_config.image_size // patch_size) ** 2
    # NOTE: 504 is the hard-coded pixel size of each crop patch fed to the
    # vision model (the dummy patch_pixel_values are 3x504x504). It is
    # presumably the processor's crop size and must track it if that changes.
    crop_tokens = (504 // patch_size) ** 2
    # num_patches is typically a tensor; int() keeps the result a real list[int].
    return [image_tokens + int(n) * crop_tokens for n in num_patches]
def select_encoder_cudagraph_items(
    self,
    mm_kwargs: dict[str, Any],
    indices: list[int],
) -> dict[str, Any]:
    """Return the sub-batch of ``mm_kwargs`` for the items at ``indices``.

    ``patch_pixel_values`` stores every item's crop patches back to back in
    item order. Item ``i`` therefore owns rows
    ``sum(num_patches[:i]) : sum(num_patches[:i + 1])``. Selected rows are
    concatenated in the order given by ``indices``. ``pixel_values`` and
    ``num_patches`` are indexed with the ``indices`` list directly, so they
    must support list indexing (e.g. tensors).
    """
    pixel_values = mm_kwargs["pixel_values"]
    patch_pixel_values = mm_kwargs["patch_pixel_values"]
    num_patches = mm_kwargs["num_patches"]

    # Running totals: item i's patches live in boundaries[i]:boundaries[i + 1].
    running_total = 0
    boundaries = [running_total]
    for count in num_patches:
        running_total = running_total + count
        boundaries.append(running_total)

    if len(indices) == 0:
        # Same containers and trailing shapes, just with zero items.
        return {
            "pixel_values": pixel_values[:0],
            "patch_pixel_values": patch_pixel_values[:0],
            "num_patches": num_patches[:0],
        }

    patch_chunks = [
        patch_pixel_values[boundaries[item] : boundaries[item + 1]]
        for item in indices
    ]
    return {
        "pixel_values": pixel_values[indices],
        "patch_pixel_values": torch.cat(patch_chunks),
        "num_patches": num_patches[indices],
    }
def prepare_encoder_cudagraph_capture_inputs(
    self,
    token_budget: int,
    max_batch_size: int,
    max_frames_per_batch: int,
    device: torch.device,
    dtype: torch.dtype,
):
    """Build random dummy inputs and static buffers for one graph capture.

    The graph is captured for ``max_batch_size`` global views of
    ``image_size`` x ``image_size`` pixels. It also covers as many 504x504
    crop patches as the remaining ``token_budget`` allows, which may be zero.
    ``max_frames_per_batch`` is unused because this model has no video input.
    """
    from vllm.v1.worker.encoder_cudagraph_defs import (
        EncoderCudaGraphCaptureInputs,
    )

    vision_cfg = self.config.vision_config
    stride = self.config.understand_projector_stride
    per_image = self._compute_spatial_tokens(
        vision_cfg.image_size, vision_cfg.patch_size, stride
    )
    per_patch = self._compute_spatial_tokens(504, vision_cfg.patch_size, stride)

    image_side = vision_cfg.image_size
    pixel_values = torch.randn(
        max_batch_size, 3, image_side, image_side, device=device, dtype=dtype
    )

    # Whatever the global views leave of the budget is spent on crop patches:
    #   token_budget ~= max_batch_size * per_image + patch_capacity * per_patch
    leftover = token_budget - max_batch_size * per_image
    patch_capacity = max(0, leftover // per_patch)
    patch_pixel_values = torch.randn(
        patch_capacity, 3, 504, 504, device=device, dtype=dtype
    )

    # num_patches is deliberately not a buffer. The per-item split happens
    # outside the graph in postprocess_encoder_output, using the replayed
    # batch's num_patches.
    return EncoderCudaGraphCaptureInputs(
        mm_kwargs={
            "pixel_values": pixel_values,
            "patch_pixel_values": patch_pixel_values,
        },
        buffers={"patch_pixel_values": patch_pixel_values},
    )
def encoder_cudagraph_forward(
    self,
    mm_kwargs: dict[str, Any],
    buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Graph-capturable encoder pass returning ``[all_img_flat, all_patch_flat]``.

    Only the vision tower and projector run here. The output is the
    flattened features of every global view, followed by those of every crop
    patch. Splitting them back into per-item embeddings is done outside the
    graph by ``postprocess_encoder_output``, using the replayed batch's
    ``num_patches``.
    """
    image_pixels = mm_kwargs["pixel_values"]
    patch_pixels = buffers["patch_pixel_values"]

    def encode(pixels: torch.Tensor) -> torch.Tensor:
        # Vision tower, then conv downsamplers + linear projector.
        return self._process_image_features(self._get_vision_model_output(pixels))

    image_feats = encode(image_pixels)
    flat_images = image_feats.reshape(-1, image_feats.shape[-1])
    if len(patch_pixels) == 0:
        # No crop-patch capacity in this graph: global-view features only.
        return flat_images

    patch_feats = encode(patch_pixels)
    flat_patches = patch_feats.reshape(-1, patch_feats.shape[-1])
    return torch.cat([flat_images, flat_patches], dim=0)
def encoder_eager_forward(
    self,
    mm_kwargs: dict[str, Any],
) -> torch.Tensor:
    """Run the image encoder eagerly (no CUDA graph).

    Returns every item's embedding concatenated along dim 0, in item order,
    exactly as produced by ``_process_image_input``.
    """
    fields = {
        key: mm_kwargs[key]
        for key in ("pixel_values", "patch_pixel_values", "num_patches")
    }
    per_item = self._process_image_input(
        Step3VLImagePixelInputs(type="pixel_values", **fields)
    )
    return torch.cat(per_item, dim=0)
def postprocess_encoder_output(
self,
output: torch.Tensor,
indices: list[int],
per_item_out_tokens: list[int],
dest: dict[int, torch.Tensor] | list[torch.Tensor | None],
clone: bool = False,
batch_mm_kwargs: dict[str, Any] | None = None,
):
"""CPU-side per-item merge after graph replay.
The graph output is ``[all_img_flat, all_patch_flat]``.
This method splits the flat output into image and patch features,
then reassembles per-item embeddings using the *actual* batch
``num_patches`` from ``batch_mm_kwargs`` (not the capture-time values).
"""
num_patches = batch_mm_kwargs["num_patches"]
hidden = output.shape[-1]
bsz = len(indices)
img_out = self._compute_spatial_tokens(
self.config.vision_config.image_size,
self.config.vision_config.patch_size,
self.config.understand_projector_stride,
)
patch_out = self._compute_spatial_tokens(
504,
self.config.vision_config.patch_size,
self.config.understand_projector_stride,
)
# Valid portion: bsz images, actual_total_patches patches
actual_np = [int(np) for np in num_patches]
total_patches = sum(actual_np)
img_tokens = bsz * img_out
patch_tokens = total_patches * patch_out
img_part = output[:img_tokens].reshape(bsz, img_out, hidden)
if total_patches > 0:
patch_part = output[img_tokens : img_tokens + patch_tokens].reshape(
-1, patch_out, hidden
)
else:
patch_part = None
merged: dict[int, torch.Tensor] = {}
cur_patch = 0
for i, idx in enumerate(indices):
np = actual_np[i]
parts: list[torch.Tensor] = []
if patch_part is not None and np > 0:
parts.append(patch_part[cur_patch : cur_patch + np].reshape(-1, hidden))
cur_patch += np
parts.append(img_part[i].reshape(-1, hidden))
merged[idx] = torch.cat(parts, dim=0) if len(parts) > 1 else parts[0]
out = [merged[i] for i in indices]
for i, idx in enumerate(indices):
dest[idx] = out[i]
def prepare_encoder_cudagraph_replay_buffers(
    self,
    mm_kwargs: dict[str, Any],
    max_batch_size: int,
    max_frames_per_batch: int,
):
    """Return the replay-time buffer contents for one graph replay.

    Only ``patch_pixel_values`` is exposed as a buffer. ``num_patches`` is
    consumed after replay by ``postprocess_encoder_output`` instead.
    ``max_batch_size`` and ``max_frames_per_batch`` are unused here.
    """
    from vllm.v1.worker.encoder_cudagraph_defs import (
        EncoderCudaGraphReplayBuffers,
    )

    replay_buffers = {"patch_pixel_values": mm_kwargs["patch_pixel_values"]}
    return EncoderCudaGraphReplayBuffers(buffers=replay_buffers)
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    """Run the wrapped language model.

    No vision processing happens here; image features must already be part
    of ``inputs_embeds``. Extra ``kwargs`` are ignored.
    """
    # When intermediate_tensors are supplied (presumably a later
    # pipeline-parallel stage), inputs_embeds are not forwarded.
    embeds = inputs_embeds if intermediate_tensors is None else None
    return self.language_model(
        input_ids, positions, intermediate_tensors, inputs_embeds=embeds
    )
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load ``(name, tensor)`` checkpoint pairs into this model.

    Checkpoint prefixes are renamed through ``hf_to_vllm_mapper``. The return
    value is whatever ``AutoWeightsLoader.load_weights`` returns (presumably
    the set of loaded parameter names).
    """
    return AutoWeightsLoader(self).load_weights(
        weights, mapper=self.hf_to_vllm_mapper
    )