Skip to content

vllm.model_executor.models.step3_vl

Step3VLForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP, SupportsEncoderCudaGraph

Source code in vllm/model_executor/models/step3_vl.py
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
@MULTIMODAL_REGISTRY.register_processor(
    Step3VLMultiModalProcessor,
    info=Step3VLProcessingInfo,
    dummy_inputs=Step3VLDummyInputsBuilder,
)
class Step3VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsEncoderCudaGraph
):
    """Step3 vision-language model: an image encoder feeding a language model.

    The class is registered with ``MULTIMODAL_REGISTRY`` together with its
    processor, processing-info and dummy-input builder classes. Only the
    image modality is accepted (see ``get_placeholder_str``).
    """

    # Checkpoint weight-name prefixes are rewritten so that the text weights
    # land under the ``language_model`` submodule of this wrapper.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    # NOTE(review): presumably advertises that the vision encoder may run in
    # data-parallel mode instead of tensor-parallel mode -- confirm against
    # the SupportsMultiModal interface.
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<im_patch>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector stack and the language model.

        Args:
            vllm_config: Engine configuration; the HF config and the
                multimodal config are read from ``vllm_config.model_config``.
            prefix: Parameter-name prefix prepended to the submodule names.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.model_config = vllm_config.model_config
        self.multimodal_config = multimodal_config
        # True when the multimodal encoder is configured for the "data" mode.
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        # NOTE: This behavior is consistent with the previous OOV handling,
        # but does not currently handle the start/stop toks around the
        # image features (<patch_start> <patch_end> <im_start> <im_end>)
        # See: https://huggingface.co/stepfun-ai/step3/blob/main/processing_step3v.py#L323
        #
        # If this becomes an issue or we refactor to handle this using the
        # processor info in the future, it would probably be best to handle
        # those too.
        self.configure_mm_token_handling(
            self.config.text_config.vocab_size,
            [self.config.image_token_id],
        )

        # NOTE(review): the context manager presumably tags every module
        # created inside it as part of the "image" tower -- confirm.
        with self._mark_tower_model(vllm_config, "image"):
            # The second positional argument is passed as None here
            # (presumably the quantization config -- TODO confirm).
            self.vision_model = Step3VisionTransformer(
                config.vision_config,
                None,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            # First downsampling conv: kernel 2, stride taken from the config.
            self.vit_downsampler = Conv2dLayer(
                config.vision_config.hidden_size,
                config.vision_config.output_hidden_size,
                kernel_size=2,
                stride=config.understand_projector_stride,
            )
            # Second downsampling conv: kernel 3, stride 2, padding 1; it also
            # doubles the channel count.
            self.vit_downsampler2 = Conv2dLayer(
                config.vision_config.output_hidden_size,
                config.vision_config.output_hidden_size * 2,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            # Projects the downsampled vision channels to ``config.hidden_size``.
            self.vit_large_projector = nn.Linear(
                config.vision_config.output_hidden_size * 2,
                config.hidden_size,
                bias=config.projector_bias,
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Re-export the language model's factory under the same name on this
        # wrapper.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @staticmethod
    def _compute_spatial_tokens(size, patch_size, stride):
        # Compute the number of spatial tokens after two rounds of
        # downsampling with given patch size and stride.
        grid = size // patch_size
        vit_tokens = grid * grid
        spatial = int(math.sqrt(vit_tokens))
        h1 = (spatial - 2) // stride + 1
        h2 = (h1 - 1) // 2 + 1
        return h2 * h2

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Step3VLImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and patch_pixel_values is not None:
            return Step3VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(self.dtype),
                patch_pixel_values=patch_pixel_values.to(self.dtype),
                num_patches=num_patches,
            )

        if image_embeds is not None:
            return Step3VLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds.to(self.dtype),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features

    def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.vision_model(input_tensor)[:, 4:]

    def _process_image_input(
        self, image_input: Step3VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        if image_input["type"] == "image_embeds":
            image_features = image_input["data"]
            return [
                image_features[i].view(-1, image_features.shape[-1])
                for i in range(image_features.shape[0])
            ]

        image_features = self._get_vision_model_output(image_input["pixel_values"])
        patch_image_features = (
            self._get_vision_model_output(image_input["patch_pixel_values"])
            if len(image_input["patch_pixel_values"]) > 0
            else None
        )
        num_patches = image_input["num_patches"]

        image_features = self._process_image_features(image_features)
        patch_image_features = (
            self._process_image_features(patch_image_features)
            if patch_image_features is not None
            else None
        )

        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx : cur_patch_idx + num_patch
                ]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(-1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0]
            )
        return merged_image_features

    def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed token ids, delegating to the parent implementation.

        The multimodal arguments are forwarded only when both are supplied;
        otherwise the text-only parent call is used. The two explicit calls
        exist to satisfy the type checker for each overload.
        """
        if multimodal_embeddings is not None and is_multimodal is not None:
            return super().embed_input_ids(
                input_ids,
                multimodal_embeddings=multimodal_embeddings,
                is_multimodal=is_multimodal,
            )

        return super().embed_input_ids(input_ids)

    def get_encoder_cudagraph_config(self):
        """Describe this model's encoder inputs for CUDA-graph capture.

        Returns:
            An ``EncoderCudaGraphConfig`` for the single ``"image"`` modality
            with ``pixel_values`` as input key, ``patch_pixel_values`` as the
            only buffer key and ``config.hidden_size`` as output width.
        """
        from vllm.v1.worker.encoder_cudagraph_defs import EncoderCudaGraphConfig

        supported_modalities = ["image"]
        input_keys = {"image": "pixel_values"}
        buffer_names = ["patch_pixel_values"]
        return EncoderCudaGraphConfig(
            modalities=supported_modalities,
            input_key_by_modality=input_keys,
            buffer_keys=buffer_names,
            out_hidden_size=self.config.hidden_size,
        )

    def get_input_modality(
        self,
        mm_kwargs: dict[str, Any],
    ) -> str:
        return "image"

    def get_max_frames_per_video(
        self,
    ) -> int:
        return 0

    def get_encoder_cudagraph_budget_range(
        self,
        vllm_config: "VllmConfig",
    ) -> tuple[int, int]:
        # An image without patches
        min_budget = self._compute_spatial_tokens(
            self.config.vision_config.image_size,
            self.config.vision_config.patch_size,
            self.config.understand_projector_stride,
        )
        max_budget = min(
            vllm_config.scheduler_config.max_num_batched_tokens,
            self.model_config.max_model_len,
        )
        return min_budget, max_budget

    def get_encoder_cudagraph_num_items(
        self,
        mm_kwargs: dict[str, Any],
    ) -> int:
        return len(mm_kwargs.get("pixel_values", []))

    def get_encoder_cudagraph_per_item_output_tokens(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        num_patches = mm_kwargs.get("num_patches")
        img_output_tokens = self._compute_spatial_tokens(
            self.config.vision_config.image_size,
            self.config.vision_config.patch_size,
            self.config.understand_projector_stride,
        )
        patch_output_tokens = self._compute_spatial_tokens(
            504,
            self.config.vision_config.patch_size,
            self.config.understand_projector_stride,
        )
        return [
            img_output_tokens + num_patch * patch_output_tokens
            for num_patch in num_patches
        ]

    def get_encoder_cudagraph_per_item_input_sizes(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        img_grid = (
            self.config.vision_config.image_size // self.config.vision_config.patch_size
        )

        # NOTE: 504 is the hard coded size for each patch after processing
        # by the vision model, which is determined by the current architecture
        # of the vision model and may need to be updated if the architecture changes.
        # The number of tokens for each patch is calculated based on this
        # size and the patch size.
        patch_grid = 504 // self.config.vision_config.patch_size
        total_image_pixel = img_grid * img_grid
        total_patch_pixel = patch_grid * patch_grid

        num_patches = mm_kwargs.get("num_patches")
        return [
            total_image_pixel + num_patch * total_patch_pixel
            for num_patch in num_patches
        ]

    def select_encoder_cudagraph_items(
        self,
        mm_kwargs: dict[str, Any],
        indices: list[int],
    ) -> dict[str, Any]:
        pixel_values = mm_kwargs["pixel_values"]
        patch_pixel_values = mm_kwargs["patch_pixel_values"]
        num_patches = mm_kwargs["num_patches"]

        # calcute the accumulated patch counts
        cum_patches = [0]
        for p in num_patches:
            cum_patches.append(cum_patches[-1] + p)

        if len(indices) == 0:
            return {
                "pixel_values": pixel_values[:0],
                "patch_pixel_values": patch_pixel_values[:0],
                "num_patches": num_patches[:0],
            }

        selected_pv = pixel_values[indices]
        selected_np = num_patches[indices]
        selected_ppv = torch.cat(
            [patch_pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices]
        )

        return {
            "pixel_values": selected_pv,
            "patch_pixel_values": selected_ppv,
            "num_patches": selected_np,
        }

    def prepare_encoder_cudagraph_capture_inputs(
        self,
        token_budget: int,
        max_batch_size: int,
        max_frames_per_batch: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Build random dummy inputs used to capture an encoder CUDA graph.

        Args:
            token_budget: Output-token budget of the graph being captured.
            max_batch_size: Number of dummy whole images to create.
            max_frames_per_batch: Unused by this model.
            device: Device on which the dummy tensors are allocated.
            dtype: Dtype of the dummy tensors.

        Returns:
            ``EncoderCudaGraphCaptureInputs`` whose ``mm_kwargs`` hold dummy
            ``pixel_values`` and ``patch_pixel_values`` and whose ``buffers``
            hold the same dummy ``patch_pixel_values`` tensor.
        """
        from vllm.v1.worker.encoder_cudagraph_defs import (
            EncoderCudaGraphCaptureInputs,
        )

        vision_cfg = self.config.vision_config
        stride = self.config.understand_projector_stride
        tokens_per_image = self._compute_spatial_tokens(
            vision_cfg.image_size, vision_cfg.patch_size, stride
        )
        tokens_per_patch = self._compute_spatial_tokens(
            504, vision_cfg.patch_size, stride
        )

        # One whole image per batch slot.
        image_side = self.config.vision_config.image_size
        dummy_images = torch.randn(
            max_batch_size, 3, image_side, image_side, device=device, dtype=dtype
        )

        # The patch capacity is the total across the whole batch, derived from
        # token_budget = max_batch_size * tokens_per_image
        #                + patch_capacity * tokens_per_patch
        leftover_budget = token_budget - max_batch_size * tokens_per_image
        patch_capacity = max(0, leftover_budget // tokens_per_patch)
        dummy_patches = torch.randn(
            patch_capacity, 3, 504, 504, device=device, dtype=dtype
        )

        # num_patches is deliberately absent from the buffers: the per-item
        # merge is done after replay by postprocess_encoder_output using the
        # actual batch's num_patches.
        return EncoderCudaGraphCaptureInputs(
            mm_kwargs={
                "pixel_values": dummy_images,
                "patch_pixel_values": dummy_patches,
            },
            buffers={
                "patch_pixel_values": dummy_patches,
            },
        )

    def encoder_cudagraph_forward(
        self,
        mm_kwargs: dict[str, Any],
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        # Graph captures only the compute (vision model + conv projector).
        # Per-item merge happens CPU-side in finalize_encoder_cudagraph_output
        # using actual num_patches from the batch data.
        pixel_values = mm_kwargs["pixel_values"]
        patch_pixel_values = buffers["patch_pixel_values"]

        image_features = self._process_image_features(
            self._get_vision_model_output(pixel_values)
        )

        has_patches = len(patch_pixel_values) > 0
        if has_patches:
            patch_features = self._process_image_features(
                self._get_vision_model_output(patch_pixel_values)
            )

        # Deterministic single cat: [all_img_flat, all_patch_flat]
        img_flat = image_features.reshape(-1, image_features.shape[-1])
        if has_patches:
            patch_flat = patch_features.reshape(-1, patch_features.shape[-1])
            return torch.cat([img_flat, patch_flat], dim=0)
        return img_flat

    def encoder_eager_forward(
        self,
        mm_kwargs: dict[str, Any],
    ) -> torch.Tensor:
        """Encode images without a CUDA graph.

        Wraps ``mm_kwargs`` into pixel inputs, runs ``_process_image_input``
        and concatenates the per-image embeddings along dim 0.
        """
        whole_images = mm_kwargs["pixel_values"]
        patch_images = mm_kwargs["patch_pixel_values"]
        patch_counts = mm_kwargs["num_patches"]

        pixel_inputs = Step3VLImagePixelInputs(
            type="pixel_values",
            pixel_values=whole_images,
            patch_pixel_values=patch_images,
            num_patches=patch_counts,
        )
        per_item = self._process_image_input(pixel_inputs)
        return torch.cat(per_item, dim=0)

    def postprocess_encoder_output(
        self,
        output: torch.Tensor,
        indices: list[int],
        per_item_out_tokens: list[int],
        dest: dict[int, torch.Tensor] | list[torch.Tensor | None],
        clone: bool = False,
        batch_mm_kwargs: dict[str, Any] | None = None,
    ):
        """CPU-side per-item merge after graph replay.

        The graph output is ``[all_img_flat, all_patch_flat]``.
        This method splits the flat output into image and patch features,
        then reassembles per-item embeddings using the *actual* batch
        ``num_patches`` from ``batch_mm_kwargs`` (not the capture-time values).
        """
        num_patches = batch_mm_kwargs["num_patches"]
        hidden = output.shape[-1]
        bsz = len(indices)

        img_out = self._compute_spatial_tokens(
            self.config.vision_config.image_size,
            self.config.vision_config.patch_size,
            self.config.understand_projector_stride,
        )
        patch_out = self._compute_spatial_tokens(
            504,
            self.config.vision_config.patch_size,
            self.config.understand_projector_stride,
        )

        # Valid portion: bsz images, actual_total_patches patches
        actual_np = [int(np) for np in num_patches]
        total_patches = sum(actual_np)
        img_tokens = bsz * img_out
        patch_tokens = total_patches * patch_out

        img_part = output[:img_tokens].reshape(bsz, img_out, hidden)
        if total_patches > 0:
            patch_part = output[img_tokens : img_tokens + patch_tokens].reshape(
                -1, patch_out, hidden
            )
        else:
            patch_part = None

        merged: dict[int, torch.Tensor] = {}
        cur_patch = 0
        for i, idx in enumerate(indices):
            np = actual_np[i]
            parts: list[torch.Tensor] = []
            if patch_part is not None and np > 0:
                parts.append(patch_part[cur_patch : cur_patch + np].reshape(-1, hidden))
                cur_patch += np
            parts.append(img_part[i].reshape(-1, hidden))
            merged[idx] = torch.cat(parts, dim=0) if len(parts) > 1 else parts[0]

        out = [merged[i] for i in indices]
        for i, idx in enumerate(indices):
            dest[idx] = out[i]

    def prepare_encoder_cudagraph_replay_buffers(
        self,
        mm_kwargs: dict[str, Any],
        max_batch_size: int,
        max_frames_per_batch: int,
    ):
        """Collect the buffer tensors needed to replay the encoder graph.

        Only ``patch_pixel_values`` goes into the buffers; ``num_patches`` is
        consumed after replay by ``postprocess_encoder_output``. The two size
        arguments are not used by this model.
        """
        from vllm.v1.worker.encoder_cudagraph_defs import (
            EncoderCudaGraphReplayBuffers,
        )

        replay_buffers = {"patch_pixel_values": mm_kwargs["patch_pixel_values"]}
        return EncoderCudaGraphReplayBuffers(buffers=replay_buffers)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Weight names are first rewritten with ``hf_to_vllm_mapper``; the
        loader's return value is passed through unchanged.
        """
        weight_loader = AutoWeightsLoader(self)
        loaded = weight_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        return loaded

postprocess_encoder_output

postprocess_encoder_output(
    output: Tensor,
    indices: list[int],
    per_item_out_tokens: list[int],
    dest: dict[int, Tensor] | list[Tensor | None],
    clone: bool = False,
    batch_mm_kwargs: dict[str, Any] | None = None,
)

CPU-side per-item merge after graph replay.

The graph output is [all_img_flat, all_patch_flat]. This method splits the flat output into image and patch features, then reassembles per-item embeddings using the actual batch num_patches from batch_mm_kwargs (not the capture-time values).

Source code in vllm/model_executor/models/step3_vl.py
def postprocess_encoder_output(
    self,
    output: torch.Tensor,
    indices: list[int],
    per_item_out_tokens: list[int],
    dest: dict[int, torch.Tensor] | list[torch.Tensor | None],
    clone: bool = False,
    batch_mm_kwargs: dict[str, Any] | None = None,
):
    """CPU-side per-item merge after graph replay.

    The graph output is ``[all_img_flat, all_patch_flat]``.
    This method splits the flat output into image and patch features,
    then reassembles per-item embeddings using the *actual* batch
    ``num_patches`` from ``batch_mm_kwargs`` (not the capture-time values).
    """
    num_patches = batch_mm_kwargs["num_patches"]
    hidden = output.shape[-1]
    bsz = len(indices)

    img_out = self._compute_spatial_tokens(
        self.config.vision_config.image_size,
        self.config.vision_config.patch_size,
        self.config.understand_projector_stride,
    )
    patch_out = self._compute_spatial_tokens(
        504,
        self.config.vision_config.patch_size,
        self.config.understand_projector_stride,
    )

    # Valid portion: bsz images, actual_total_patches patches
    actual_np = [int(np) for np in num_patches]
    total_patches = sum(actual_np)
    img_tokens = bsz * img_out
    patch_tokens = total_patches * patch_out

    img_part = output[:img_tokens].reshape(bsz, img_out, hidden)
    if total_patches > 0:
        patch_part = output[img_tokens : img_tokens + patch_tokens].reshape(
            -1, patch_out, hidden
        )
    else:
        patch_part = None

    merged: dict[int, torch.Tensor] = {}
    cur_patch = 0
    for i, idx in enumerate(indices):
        np = actual_np[i]
        parts: list[torch.Tensor] = []
        if patch_part is not None and np > 0:
            parts.append(patch_part[cur_patch : cur_patch + np].reshape(-1, hidden))
            cur_patch += np
        parts.append(img_part[i].reshape(-1, hidden))
        merged[idx] = torch.cat(parts, dim=0) if len(parts) > 1 else parts[0]

    out = [merged[i] for i in indices]
    for i, idx in enumerate(indices):
        dest[idx] = out[i]

Step3VLImageEmbeddingInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of images
  • f: Image feature size
  • h: Hidden size (must match the hidden size of language model backbone)
Source code in vllm/model_executor/models/step3_vl.py
class Step3VLImageEmbeddingInputs(TensorSchema):
    """
    Schema for precomputed image embeddings passed in place of pixels.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    # Discriminator distinguishing this schema from the pixel-input schema.
    type: Literal["image_embeds"] = "image_embeds"
    # One (f, h) embedding block per image.
    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]

Step3VLImagePixelInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of images
  • c: Number of channels (3)
  • h: Height
  • w: Width
  • bnp: Batch size * number of images * number of patches
  • hp: Height of patch
  • wp: Width of patch
Source code in vllm/model_executor/models/step3_vl.py
class Step3VLImagePixelInputs(TensorSchema):
    """
    Schema for raw pixel inputs: whole images plus their patch crops.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
        - bnp: Total number of patches over all images in the batch
        - hp: Height of patch
        - wp: Width of patch
    """

    # Discriminator distinguishing this schema from the embedding schema.
    type: Literal["pixel_values"]
    # One whole image per entry.
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
    # Patch crops of all images stacked along dim 0, in image order.
    patch_pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "hp", "wp")]
    # Number of patch crops belonging to each image.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]

Step3VisionAttention

Bases: Module

Multi-headed attention from 'Attention Is All You Need' paper

Source code in vllm/model_executor/models/step3_vl.py
class Step3VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the fused QKV projection, output projection and attention.

        Args:
            config: Vision config providing ``hidden_size`` and
                ``num_attention_heads``.
            quant_config: Optional quantization config for both projections.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads
        self.scale = self.head_dim**-0.5

        # In data-parallel mode the heads are not sharded, so the effective
        # shard count is 1; otherwise it is the tensor-parallel world size.
        vit_data_parallel = is_vit_use_data_parallel()
        if vit_data_parallel:
            shard_count = 1
        else:
            shard_count = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % shard_count == 0
        self.num_heads = self.total_num_heads // shard_count
        self.q_size = self.num_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=vit_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=vit_data_parallel,
        )

        # Unified multimodal-encoder attention; it picks its own backend.
        self.attn = MMEncoderAttention(
            self.num_heads,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        # The unpacking fails for anything but a 3-D input.
        _batch, _seq_len, _ = hidden_states.size()

        # Fused projection, then an even three-way split of the last dim.
        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)

        context = self.attn(query, key, value)

        projected, _ = self.out_proj(context)
        return projected

forward

forward(hidden_states: Tensor)

Input shape: Batch x Time x Channel

Source code in vllm/model_executor/models/step3_vl.py
def forward(
    self,
    hidden_states: torch.Tensor,
):
    """Input shape: Batch x Time x Channel"""
    # The unpacking fails for anything but a 3-D input.
    _batch, _seq_len, _ = hidden_states.size()

    # Fused projection, then an even three-way split of the last dim.
    fused_qkv, _ = self.qkv_proj(hidden_states)
    query, key, value = fused_qkv.chunk(chunks=3, dim=-1)

    context = self.attn(query, key, value)

    projected, _ = self.out_proj(context)
    return projected