Skip to content

vllm.lora.model_manager

LRUCacheLoRAModelManager

Bases: LoRAModelManager

A model manager that manages multiple LoRAs with LRU cache.

Source code in vllm/lora/model_manager.py
class LRUCacheLoRAModelManager(LoRAModelManager):
    """LoRA model manager whose registry and GPU slots are LRU caches.

    The base class keeps registered and active adapters in plain dicts; this
    subclass swaps both for ``LoRALRUCache`` instances. The registry is bounded
    by ``capacity`` (``max_cpu_loras``) and the active set by ``lora_slots``
    (``max_loras``). Adapters can additionally be pinned in both caches.
    """

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        vllm_config: VllmConfig | None = None,
    ):
        super().__init__(
            model,
            max_num_seqs,
            max_num_batched_tokens,
            vocab_size,
            lora_config,
            device,
            vllm_config,
        )
        # Each cache is given a callback: ``deactivate_adapter`` for the
        # registry and ``_deactivate_adapter`` (which only frees the GPU slot
        # in ``lora_index_to_id``) for the active set. Presumably these run
        # when an entry is evicted — see LoRALRUCache.
        registry_cache: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter
        )
        self._registered_adapters: LoRALRUCache = registry_cache
        gpu_slot_cache: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter
        )
        self._active_adapters: LoRALRUCache = gpu_slot_cache

    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a snapshot dict (id -> LoRAModel) of all registered adapters."""
        snapshot = dict(self._registered_adapters.cache)
        return snapshot

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Register ``lora`` if it is not already known.

        Returns:
            True when the adapter was newly registered; False when it was
            already present, in which case only its LRU position is refreshed.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id in self._registered_adapters:
            # Already registered: mark it as most recently used.
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Load ``lora_id`` into a GPU slot, evicting the LRU one if full.

        Returns whatever the base-class ``activate_adapter`` returns. The
        adapter is always touched afterwards so it becomes the most recently
        used active adapter.
        """
        needs_slot = lora_id not in self._active_adapters
        if needs_slot and len(self._active_adapters) >= self.lora_slots:
            # All slots taken: drop the least recently used active adapter.
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Evict the least recently used registered adapter.

        Returns:
            False if no adapter is registered, True otherwise.
        """
        if len(self._registered_adapters) == 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin ``lora_id`` in both the registry and the GPU-slot cache."""
        for pin_step in (self._pin_lora_in_cpu_cache, self._pin_lora_in_gpu_cache):
            pin_step(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        # The cache signals an unknown id with ValueError; re-raise it with
        # a clearer message while keeping the original as the cause.
        registry = self._registered_adapters
        try:
            registry.pin(lora_id)
        except ValueError as exc:
            message = f"Pinning failed. LoRA {lora_id} is not registered."
            raise ValueError(message) from exc

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        # The adapter must occupy a GPU slot before it can be pinned there.
        already_active = lora_id in self._active_adapters
        if not already_active:
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)

add_adapter

add_adapter(lora: LoRAModel) -> bool

Add a LoRAModel to the manager.

Source code in vllm/lora/model_manager.py
def add_adapter(self, lora: LoRAModel) -> bool:
    """Register ``lora`` if it is not already known.

    Returns:
        True when the adapter was newly registered; False when it was already
        present, in which case only its LRU position is refreshed.
    """
    logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
    if lora.id in self._registered_adapters:
        # Already registered: mark it as most recently used.
        self._registered_adapters.touch(lora.id)
        return False
    self._add_adapter(lora)
    return True

list_adapters

list_adapters() -> dict[int, LoRAModel]

List all registered LoRAModels.

Source code in vllm/lora/model_manager.py
def list_adapters(self) -> dict[int, LoRAModel]:
    """Return a shallow copy (id -> LoRAModel) of the registered adapters."""
    registry_cache = self._registered_adapters.cache
    return dict(registry_cache)

pin_adapter

pin_adapter(lora_id: int) -> bool

Pin a LoRAModel in the manager cache.

Source code in vllm/lora/model_manager.py
def pin_adapter(self, lora_id: int) -> bool:
    """Pin ``lora_id`` in the CPU cache first, then in the GPU cache."""
    for pin_step in (self._pin_lora_in_cpu_cache, self._pin_lora_in_gpu_cache):
        pin_step(lora_id)
    return True

LoRAModelManager

A manager that manages multiple LoRA-fine-tuned models.

Source code in vllm/lora/model_manager.py
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        vllm_config: VllmConfig | None = None,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Creates the Punica wrapper(s) via ``_init_punica_wrapper``, wraps the
        model's LoRA-targetable submodules via ``_create_lora_modules`` and
        finally stores this manager on ``model.lora_manager``.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: device passed to every Punica wrapper that is created.
            vllm_config: optional engine configuration; it is only consulted
                when setting up tower/connector LoRA for multimodal models.

        Raises:
            AssertionError: if the model exposes no supported LoRA modules or
                ``max_cpu_loras < max_loras`` (when assertions are enabled).
        """
        self.model: SupportsLoRA = model
        self.supported_lora_modules = get_supported_lora_modules(self.model)
        assert self.supported_lora_modules, (
            f"No supported LoRA modules found in {self.model.__class__.__name__}."
        )

        # Adapter id -> LoRAModel for every registered adapter.
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        # `capacity` (max_cpu_loras) and `lora_slots` (max_loras) both read
        # self.lora_config, so this check must follow the assignment above.
        assert self.capacity >= self.lora_slots
        # Rounded up to a multiple of 8. Note that the Punica wrapper(s)
        # created below receive the caller's un-rounded value.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # GPU slot index -> id of the adapter loaded there (None = free slot).
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size

        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        # Module name -> LoRA wrapper; filled in by register_module().
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # NOTE(review): presumably the most recently applied LoRAMapping; it
        # is only initialized here — confirm against its users.
        self._last_mapping: LoRAMapping | None = None
        is_moe = is_moe_model(self.model)
        # Whether the underlying model class declares 3D fused MoE weights.
        self._model_is_3d_moe = is_moe and self.model.is_3d_moe_weight
        # When the engine is started with enable_mixed_moe_lora_format=True
        # we force the universal 2D wrapper (FusedMoEWithLoRA) regardless of
        # the model's 3D flag, so 2D and 3D adapters can coexist.
        self._enable_mixed_moe_lora_format = (
            is_moe and lora_config.enable_mixed_moe_lora_format
        )
        self._is_3d_moe_model = (
            self._model_is_3d_moe and not self._enable_mixed_moe_lora_format
        )
        self.packed_modules_mapping = process_packed_modules_mapping(
            self.model, force_2d_moe=self._enable_mixed_moe_lora_format
        )
        self._is_non_gated_moe = is_moe and self.model.is_non_gated_moe
        # Punica wrappers must exist first: _create_lora_modules looks one up
        # (via _get_punica_wrapper) for every module it wraps.
        self._init_punica_wrapper(max_num_batched_tokens, vllm_config)
        self._create_lora_modules()

        # Back-reference so the model can reach its LoRA manager.
        self.model.lora_manager = self

    def _init_punica_wrapper(
        self, max_num_batched_tokens: int, vllm_config: VllmConfig | None
    ) -> None:
        """Create the Punica wrapper(s) and fill ``punica_wrapper_mapping``.

        Multimodal models that expose ``get_mm_mapping`` are handled by
        ``_maybe_init_mm``. Every other model gets one wrapper registered
        under ``DEFAULT_LANGUAGE_WRAPPER_KEY``.

        Args:
            max_num_batched_tokens: token budget of the language-model wrapper.
            vllm_config: engine config forwarded to ``_maybe_init_mm``. It may
                be ``None``, which is the constructor's default.
        """
        # Used to indicate whether the model is a multimodal model. Models
        # that only support LoRA for text modules (e.g. ChatGLM) lack
        # `get_mm_mapping` and are treated as text-only here.
        self.supports_mm: bool = supports_multimodal(self.model) and hasattr(
            self.model, "get_mm_mapping"
        )
        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {}
        if self.supports_mm:
            self._maybe_init_mm(vllm_config, max_num_batched_tokens)
            return

        llm_punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            lora_config=self.lora_config,
        )
        self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = llm_punica_wrapper

    def _maybe_init_mm(
        self,
        vllm_config: VllmConfig | None,
        max_num_batched_tokens: int,
    ) -> None:
        """Create Punica wrappers for a multimodal model.

        A wrapper for the single language-model prefix is always registered.
        Tower wrappers are added only when all of the following hold:
        ``lora_config.enable_tower_connector_lora`` is set, the model
        implements ``get_num_mm_encoder_tokens``, a ``vllm_config`` is
        available, and the model is not configured as language-model-only.
        Connector wrappers additionally require
        ``get_num_mm_connector_tokens``.

        Sets ``self.mm_mapping`` and ``self.supports_tower_connector_lora``.

        Args:
            vllm_config: engine config used to size the tower/connector
                wrappers. When ``None``, tower/connector LoRA is disabled
                instead of crashing on attribute access.
            max_num_batched_tokens: token budget of the language-model wrapper.
        """
        mm_registry = MULTIMODAL_REGISTRY

        self.supports_tower_connector_lora = False
        self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()

        # Only one language model can be included in the model.
        assert len(self.mm_mapping.language_model) == 1

        # Language model punica wrapper
        llm_punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            lora_config=self.lora_config,
        )

        lm_prefix = self.mm_mapping.language_model[0]
        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper

        # First, determine if the model supports tower connector LoRA.
        self.supports_tower_connector_lora = self.supports_mm and hasattr(
            self.model, "get_num_mm_encoder_tokens"
        )

        # Then, handle the case where the feature is disabled in the config.
        if not self.lora_config.enable_tower_connector_lora:
            if self.supports_tower_connector_lora:
                logger.info(
                    "%s supports adding LoRA to the tower modules. If needed, "
                    "please set `enable_tower_connector_lora=True`.",
                    self.model.__class__.__name__,
                )
            self.supports_tower_connector_lora = False
            return

        # After this point, the feature is enabled in the config.
        # Now check if it's supported by the model.
        if not self.supports_tower_connector_lora:
            # Enabled but not supported: log warning and return.
            logger.warning(
                "LoRA with tower connector is enabled, but the model %s "
                "does not support it. This will be ignored.",
                self.model.__class__.__name__,
            )
            return

        # vllm_config defaults to None in the constructor, but everything
        # below dereferences it; without it the tower/connector token
        # budgets cannot be computed.
        if vllm_config is None:
            logger.warning(
                "Disabling `enable_tower_connector_lora` because no VllmConfig "
                "was provided to size the tower/connector LoRA wrappers."
            )
            self.supports_tower_connector_lora = False
            return

        # Check if initialize the language model only.
        if (
            vllm_config.model_config.multimodal_config
            and vllm_config.model_config.multimodal_config.language_model_only
        ):
            logger.warning(
                "Disabling `enable_tower_connector_lora` because the multimodal "
                "model is configured to initialize the language model only."
            )
            self.supports_tower_connector_lora = False
            return

        logger.warning(
            "LoRA for the tower and connector of multimodal models is "
            "experimental and may contain bugs. Please report any related issues on "
            "GitHub if you encounter them."
        )

        mm_budget = MultiModalBudget(vllm_config, mm_registry)
        limit_per_prompt = max(mm_budget.mm_max_items_per_prompt.values())
        num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
            mm_budget.get_encoder_budget()
        )

        # Tower wrappers: one shared wrapper for every tower prefix, sized
        # for up to `limit_per_prompt` items per sequence.
        tower_punica_wrapper = get_punica_wrapper(
            num_encoder_tokens,
            max_batches=self.max_num_seqs * limit_per_prompt,
            device=self.device,
            lora_config=self.lora_config,
        )
        for prefix in self.mm_mapping.tower_model:
            self.punica_wrapper_mapping[prefix] = tower_punica_wrapper

        # Use wrapper for connector if present.
        if not self.mm_mapping.connector:
            return
        if not hasattr(self.model, "get_num_mm_connector_tokens"):
            logger.warning_once(
                "Connector LoRA support disabled: model does not implement "
                "get_num_mm_connector_tokens(). This method is required to "
                "determine the connector's token budget for LoRA operations."
            )
            return

        connector_tokens = self.model.get_num_mm_connector_tokens(num_encoder_tokens)
        connector_punica_wrapper = get_punica_wrapper(
            connector_tokens,
            max_batches=self.max_num_seqs * limit_per_prompt,
            device=self.device,
            lora_config=self.lora_config,
        )
        for prefix in self.mm_mapping.connector:
            self.punica_wrapper_mapping[prefix] = connector_punica_wrapper

    def __len__(self) -> int:
        """Number of adapters currently registered with this manager."""
        registered = self._registered_adapters
        return len(registered)

    @property
    def capacity(self) -> int:
        """Maximum number of registered adapters (``max_cpu_loras``)."""
        config = self.lora_config
        return config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Number of GPU slots available to active adapters (``max_loras``)."""
        config = self.lora_config
        return config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Generic alias of ``lora_slots``."""
        slots = self.lora_slots
        return slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Places the registered adapter ``lora_id`` in the first free slot of
        ``lora_index_to_id`` and loads its weights into every wrapped module.

        Returns:
            True if the adapter was activated, False if it was already active.

        Raises:
            ValueError: if every slot is occupied.
            KeyError: if ``lora_id`` has not been registered. No bookkeeping is
                modified in that case.
        """
        if lora_id in self._active_adapters:
            return False
        if None not in self.lora_index_to_id:
            raise ValueError("No free lora slots")
        index = self.lora_index_to_id.index(None)
        # Resolve the adapter before touching any bookkeeping, so that an
        # unregistered id cannot be left recorded in _active_adapters without
        # a slot (which would make later activations return False).
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if not module_lora:
                # This adapter has no weights for the module; reset the slot
                # (presumably so a previous occupant's weights are not reused).
                module.reset_lora(index)
                logger.debug(
                    "No LoRA weights found for module %s, skipping.", module_name
                )
                continue

            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )
            logger.debug("Successfully loaded LoRA weights for module %s.", module_name)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the first GPU slot holding ``lora_id``; no-op if none does."""
        slots = self.lora_index_to_id
        if lora_id in slots:
            slots[slots.index(lora_id)] = None

    def _add_adapter(self, lora: LoRAModel):
        """Prepare ``lora`` and store it in the registry under its id.

        Preparation is delegated to ``_create_merged_loras_inplace``, which
        is defined elsewhere and runs before registration.
        """
        self._create_merged_loras_inplace(lora)
        adapter_id = lora.id
        self._registered_adapters[adapter_id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache.

        Unsupported by this plain-dict manager; always raises
        NotImplementedError (``LRUCacheLoRAModelManager`` implements it).
        """
        message = (
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning"
        )
        raise NotImplementedError(message)  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Send ``mapping`` to the Punica wrapper that serves its target.

        Tower and connector mappings are routed to their own wrapper only when
        the model is multimodal and tower/connector LoRA is enabled. Every
        other case uses the language-model wrapper, or the default wrapper for
        text-only models.
        """
        if self.supports_mm and self.supports_tower_connector_lora:
            mm_keys = self.mm_mapping
            if mapping.type == LoRAMappingType.TOWER and mm_keys.tower_model:
                target_prefix = mm_keys.tower_model[0]
            elif mapping.type == LoRAMappingType.CONNECTOR and mm_keys.connector:
                target_prefix = mm_keys.connector[0]
            else:
                target_prefix = mm_keys.language_model[0]
        elif self.supports_mm:
            target_prefix = self.mm_mapping.language_model[0]
        else:
            target_prefix = DEFAULT_LANGUAGE_WRAPPER_KEY

        wrapper = self._get_punica_wrapper(target_prefix)
        assert wrapper is not None

        # NOTE(review): the slot count is lora_slots + 1; the extra entry
        # presumably represents "no LoRA" — confirm against update_metadata.
        wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager and free every GPU slot."""
        # Order kept as-is: subclasses replace these containers with caches
        # whose removal callbacks may read lora_index_to_id.
        self._registered_adapters.clear()
        self.lora_index_to_id = [None for _ in range(self.lora_slots)]
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Wrap every targeted submodule of ``self.model`` with a LoRA layer.

        For each module matched by ``_match_target_modules`` that has a Punica
        wrapper, ``replace_submodule`` swaps it for the layer built by
        ``from_layer``, and the result is passed to ``register_module``. A
        module whose name contains ``lm_head`` also causes its sibling
        ``logits_processor`` to be replaced by ``from_layer_logits_processor``.
        Every registered wrapper is bound to its Punica wrapper through
        ``set_mapping``.

        Raises:
            ValueError: if a matched module cannot be wrapped by any LoRA
                layer and ``lora_config.target_modules`` was set explicitly.
        """

        def _parent_module(module_name: str) -> str:
            # module name is a dot separated name.
            # for example:
            #  - given an input 'x.y.z' return 'x.y'
            #  - given an input 'x' return ''
            return module_name.rpartition(".")[0]

        # id() of an original (unwrapped) module -> the LoRA wrapper created
        # for it. A module reachable under several paths is wrapped only once.
        wrapped_by_id: dict[int, BaseLayerWithLoRA] = {}

        # remove_duplicate=False also yields aliased paths to the same module,
        # so those aliases can be rewired below.
        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue

            punica_wrapper = self._get_punica_wrapper(module_name)
            if punica_wrapper is None:
                logger.warning(
                    "Regarding %s, no matching PunicaWrapper "
                    "is found; %s will be ignored.",
                    self.model.__class__.__name__,
                    module_name,
                )
                continue

            # TODO: Remove this restriction
            # peft error when generating LoRA adapter with "gate" module:
            # "Target module NemotronHTopkRouter() is not supported."
            # Working LoRA adapter was created using peft with:
            # LoraConfig(target_modules="all-linear", ...)
            if self._is_non_gated_moe and module_name.endswith("mixer.gate"):
                logger.debug_once(
                    "LoRA is not supported for non-gated MoE gate module."
                    " %s will be ignored.",
                    module_name,
                )
                continue

            existing_wrapper = wrapped_by_id.get(id(module))
            if existing_wrapper is not None and "lm_head" not in module_name:
                # Same underlying module was already wrapped under another
                # path (e.g. a MoE gate held both directly on the block and
                # inside the MoE runner). The adapter targets the canonical
                # path (`mlp.gate`); rewire the alias attribute
                # (`runner.gate`) to the SAME wrapper so the forward path
                # through the alias still applies LoRA, but do NOT add a
                # second entry to self.modules — otherwise `activate_adapter`
                # would call `reset_lora` on the alias and wipe the weights
                # just set under the canonical name, because the alias can't
                # load LoRA weights due to name mismatch.
                parent = self.model.get_submodule(_parent_module(module_name))
                # Point the alias attribute at the existing (shared) wrapper.
                setattr(parent, module_name.rpartition(".")[-1], existing_wrapper)
                continue

            # Last dotted component of the name (a str, despite the variable
            # name); used as the key into packed_modules_mapping.
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            if isinstance(module, FusedMoE):
                # packed_moduled_lst is used here to just determine whether to
                # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
                # difference between these two LoRA layers is whether the
                # LoRA weights of w1 and w3 have already been fused on disk.

                packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
            new_module = replace_submodule(
                self.model,
                module_name,
                from_layer(
                    module,
                    self.lora_slots,
                    self.lora_config,
                    packed_moduled_lst,
                    self.model.config,
                ),
            )
            if isinstance(new_module, BaseLayerWithLoRA):
                wrapped_by_id[id(module)] = new_module

            # (yard1): TODO make this more robust
            # For lm_head, also replace the sibling `logits_processor` with a
            # LoRA-aware one. `new_module` is rebound to that replacement, so
            # it (not the lm_head wrapper) is what gets registered below
            # under the lm_head name.
            if "lm_head" in module_name:
                logits_processor_module_name = "logits_processor"
                parent_module = _parent_module(module_name)
                if parent_module:
                    logits_processor_module_name = (
                        f"{parent_module}.{logits_processor_module_name}"
                    )

                logits_processor_module = self.model.get_submodule(
                    logits_processor_module_name
                )

                new_module = replace_submodule(
                    self.model,
                    logits_processor_module_name,
                    from_layer_logits_processor(
                        logits_processor_module,
                        module,
                        self.lora_slots,
                        self.lora_config,
                        self.model.config,
                    ),
                )

            # Some matched modules can be unsupported by LoRA wrappers
            # (e.g. subclasses with specialized forward behavior).
            if not isinstance(new_module, BaseLayerWithLoRA):
                error_msg = (
                    "LoRA target module "
                    f"{module_name} ({type(module).__name__}) matched the "
                    "deployment configuration but could not be wrapped by any "
                    "LoRA layer implementation."
                )
                # Explicit target_modules from the user -> hard error;
                # otherwise just warn and skip.
                if self.lora_config.target_modules is not None:
                    raise ValueError(
                        f"{error_msg} target_modules="
                        f"{sorted(self.lora_config.target_modules)}"
                    )
                logger.warning_once("%s It will be ignored.", error_msg)
                continue
            self.register_module(module_name, new_module)

            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Store LoRA wrapper ``module`` in ``self.modules`` under ``module_name``.

        Raises:
            AssertionError: (when assertions are enabled) if ``module`` is not
                a ``BaseLayerWithLoRA``.
        """
        is_lora_layer = isinstance(module, BaseLayerWithLoRA)
        assert is_lora_layer, (
            f"Module {module_name} must be a BaseLayerWithLoRA instance, "
            f"got {type(module)}"
        )
        self.modules[module_name] = module

    @staticmethod
    def _pad_lora_pairs_to_triplets(
        loras: list[LoRALayerWeights | None],
    ) -> list[LoRALayerWeights | None]:
        """Pad LoRA weight pairs to triplets for non-gated MoE.

        Input ``[w1_0, w2_0, w1_1, w2_1, ...]`` becomes
        ``[w1_0, w2_0, None, w1_1, w2_1, None, ...]``, matching the
        (w1, w2, w3) layout expected by ``pack_moe``.
        """
        assert len(loras) % 2 == 0, "Expected pairs of LoRA weights for non-gated MoE."
        return [
            entry
            for start in range(0, len(loras), 2)
            for entry in (*loras[start : start + 2], None)
        ]

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup.

        Visits every module of ``self.model`` that is a targeted
        ``BaseLayerWithLoRA`` with a Punica wrapper. For each one it builds
        dummy weights whose dimensions are read from the module's stacked
        LoRA buffers (or its base layer for embeddings). All dummy weights
        are created on ``"cpu"``.

        Args:
            lora_id: id given to the returned LoRAModel.
            rank: LoRA rank of the dummy weights. For 3D fused-MoE layers it
                is multiplied by the expert count.
            embedding_modules: mapping whose keys are last-component module
                names treated as embeddings. It must not be None whenever a
                visited module is not in ``self.packed_modules``.

        Returns:
            A LoRAModel whose ``loras`` dict maps module names to dummy
            (possibly packed) weights.
        """
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._get_punica_wrapper(module_name) is None
            ):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                # Unpacked module: a single dummy weight (two for 3D MoE).
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    # Special-case lm_head: wrapped by LogitsProcessorWithLoRA.
                    # LoRA input dim is hidden_size, output dim is vocab size.
                    # LogitsProcessorWithLoRA handles extra vocab size directly.
                    if parts[-1] == "lm_head":
                        input_dim = module.lora_a_stacked[0].shape[-1]
                        output_dim = module.lora_b_stacked[0].shape[-2]
                    else:
                        # Embedding layer: dims come from the base layer,
                        # falling back to its weight shape.
                        input_dim = (
                            module.base_layer.org_vocab_size
                            if hasattr(module.base_layer, "org_vocab_size")
                            else module.base_layer.weight.shape[1]
                        )
                        output_dim = (
                            module.base_layer.embedding_dim
                            if hasattr(module.base_layer, "embedding_dim")
                            else module.base_layer.weight.shape[0]
                        )
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                    # Case for 3D moe model
                    # w2 is stored under the module name itself, w13 under
                    # "<module_name>.base_layer".
                    # w2
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w2_input_size,
                        module.w2_output_size,
                        rank * module.w2_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w2_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                    # w13
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        rank
                        * module.w13_lora_a_stacked[0].shape[1],  # rank*num_experts
                        module.w13_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name + ".base_layer"] = lora
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
            else:
                # Packed module: one dummy sub-LoRA per slice, then packed
                # into a single entry under the module name.
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
                n_slices = getattr(module, "n_slices", len(replacements))
                if module.__class__.__name__ == "FusedMoEWithLoRA":
                    replacements = replacements[
                        : len(module.lora_a_stacked) // self.lora_slots
                    ]
                subloras: list[LoRALayerWeights | None] = []
                # HACK: overrides replacements for qkvz = qkv + z case.
                # Any better methods to handle this case?
                if n_slices != len(replacements):
                    replacements = [f"slice_{i}" for i in range(n_slices)]
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    subloras.append(lora)
                if module.__class__.__name__ == "FusedMoEWithLoRA":
                    # For non-gated MoE, pad subloras to 3 elements per expert
                    # to match pack_moe expectations (w1, w2, None for w3)
                    if self._is_non_gated_moe and len(subloras) > 0:
                        subloras = self._pad_lora_pairs_to_triplets(subloras)
                    lora = PackedLoRALayerWeights.pack_moe(
                        subloras, module_name, is_non_gated_moe=self._is_non_gated_moe
                    )
                else:
                    lora = PackedLoRALayerWeights.pack(subloras)
                model.loras[module_name] = lora
        return model

    def get_dummy_lora_warmup_rank(self, default_rank: int) -> int:
        """Pick the rank used for dummy (warmup) LoRAs.

        Warmup LoRAs use a deliberately small rank to keep memory low. When
        ``lora_config.fully_sharded_loras`` is enabled, every wrapped module
        flagged ``fully_sharded`` needs the rank to be divisible by its
        ``tp_size`` (fully sharded MoE wrappers shard W13 along the rank
        axis), so ``default_rank`` is rounded up to the least common
        multiple of those sizes.

        Args:
            default_rank: The preferred dummy rank.

        Returns:
            ``default_rank`` when it already satisfies every constraint,
            otherwise the smallest larger multiple that does.

        Raises:
            ValueError: If the rounded-up rank exceeds
                ``lora_config.max_lora_rank``.
        """
        config = self.lora_config
        if not config.fully_sharded_loras:
            return default_rank

        sharded_tp_sizes = [
            module.tp_size
            for module in self.modules.values()
            if getattr(module, "fully_sharded", False)
        ]
        # math.lcm(1) == 1, so "no sharded modules" means "no constraint".
        multiple = math.lcm(1, *sharded_tp_sizes)
        if multiple == 1 or default_rank % multiple == 0:
            return default_rank

        # Integer ceiling division: round up to the next multiple.
        rounded = -(-default_rank // multiple) * multiple
        if rounded <= config.max_lora_rank:
            return rounded
        raise ValueError(
            "Unable to choose a dummy LoRA warmup rank compatible with "
            "fully sharded MoE modules: "
            f"default_rank={default_rank}, "
            f"required_multiple={multiple}, "
            f"max_lora_rank={config.max_lora_rank}"
        )

    def _match_target_modules(self, module_name: str) -> bool:
        """Decide whether LoRA should be applied to ``module_name``.

        Two filters are applied in order. First, the module must be one of
        the LoRA-capable modules discovered on the model
        (``self.supported_lora_modules``). Second, it must pass the
        deployment-time restriction in ``LoRAConfig.target_modules``,
        resolved against ``self.packed_modules_mapping``.

        Args:
            module_name: Full dot-separated module name (e.g.,
                "model.layers.0.self_attn.o_proj")

        Returns:
            True if LoRA should be applied to this module, False otherwise.
        """
        supported = is_supported_lora_module(
            module_name, self.supported_lora_modules
        )
        # The target-module filter is only consulted for supported modules.
        return bool(supported) and is_in_target_modules(
            module_name,
            self.lora_config.target_modules,
            self.packed_modules_mapping,
        )

    def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
        """
        Determine whether this module supports LoRA and which wrapper to use.
        """
        # For language model (early return)
        if not self.supports_mm:
            return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY]

        # For multimodal model
        # NOTE Sort by prefix length (descending) to match the longest prefix first
        # e.g., 'visual.merger' should match 'visual.merger' instead of 'visual.'
        for prefix in sorted(self.punica_wrapper_mapping.keys(), key=len, reverse=True):
            if module_name.startswith(prefix):
                return self.punica_wrapper_mapping[prefix]

        return None

    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Post-process a freshly loaded ``LoRAModel`` in place.

        1. Packing: consider every packed module in ``self.packed_modules``
           that has at least one sub-module present in ``lora_model``. Its
           sub-module weights are combined into a single entry keyed by the
           packed module name, and the merged sub-module entries are
           removed. Names ending in ``.experts`` use
           ``PackedLoRALayerWeights.pack_moe``; all others use
           ``PackedLoRALayerWeights.pack``. Missing sub-modules occupy
           ``None`` slots.
        2. ``optimize()`` is called on every entry.
        3. MoE re-layout for wrapped MoE modules:
           - ``FusedMoE3DWithLoRA`` entries are stacked.
           - ``FusedMoEWithLoRA`` entries get 3D->2D conversion (mixed-format
             mode with a 3D adapter) or EP slicing otherwise.
        4. If the weights live on CPU and pinned memory is available, every
           tensor is copied into pinned memory.

        Args:
            lora_model: The adapter to rewrite; mutated in place.
        """
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: list[LoRALayerWeights | None] = []
            replaced_module: set[str] = set()
            for r in new_module_names:
                lora = self._get_lora_layer_weights(lora_model, r)
                if lora:
                    replacement_loras.append(lora)
                    replaced_module.add(r)
                else:
                    # Absent sub-modules keep a None slot so the packed
                    # layout stays aligned with `new_module_names`.
                    replacement_loras.append(None)
            if not replaced_module:
                continue
            # HACK Temporary solution for the pool model.
            if self.is_pooling_model and not lora_model.check_lora_name(module_name):
                replaced_module_name = module_name.removeprefix("model.")
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            if module_name.endswith(".experts"):
                if self._is_non_gated_moe and len(replacement_loras) > 0:
                    replacement_loras = self._pad_lora_pairs_to_triplets(
                        replacement_loras
                    )
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
                    replacement_loras,
                    module_name,
                    is_non_gated_moe=self._is_non_gated_moe,
                )
            else:
                lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                    replacement_loras
                )
            # Remove the modules that have been replaced. For pooling models
            # `_get_lora_layer_weights` may have resolved a sub-module under
            # its "model."-less name, so fall back to that key when the full
            # name is not present.
            for module in replaced_module:
                if (
                    lora_model.loras.pop(module, None) is None
                    and self.is_pooling_model
                ):
                    lora_model.loras.pop(module.removeprefix("model."), None)

        for lora in lora_model.loras.values():
            lora.optimize()

        for module_name, module in self.modules.items():
            if isinstance(module, FusedMoE3DWithLoRA):
                self._stack_moe_lora_weights(lora_model, module, module_name)
            elif isinstance(module, FusedMoEWithLoRA):
                # When mixed mode is enabled the universal 2D wrapper has to
                # absorb both 2D and 3D-format adapters. 3D-format adapters
                # need to be split into per-(w1, w2, w3) tensors before the
                # 2D set_lora can copy them into the stacked buffers.
                if self._enable_mixed_moe_lora_format and getattr(
                    lora_model, "is_3d_lora_weight", False
                ):
                    self._convert_3d_to_2d_moe_lora(lora_model, module, module_name)
                else:
                    self._slice_moe_lora_ep(lora_model, module, module_name)

        if not lora_model.loras:
            # Nothing left to inspect or pin.
            return
        first_lora: LoRALayerWeights = next(iter(lora_model.loras.values()))
        assert first_lora.lora_a is not None
        if isinstance(first_lora.lora_a, list):
            # Packed entries hold one tensor per slice (None for absent
            # slices). The device must come from a real element tensor, not
            # from the list itself or a None slot.
            first_tensor = next(
                (tensor for tensor in first_lora.lora_a if tensor is not None),
                None,
            )
            lora_device = None if first_tensor is None else first_tensor.device
        else:
            lora_device = first_lora.lora_a.device
        # Execute pin_memory after LoRA weight merging, mainly because:
        # 1. Some MoE models have a large number of LoRA weights. If we
        # perform # pin_memory immediately after loading weights, the
        # overhead is significant.
        # 2. The weight packing above (e.g., pack_moe) may invalidate the
        # pin_memory allocation, so we execute it after packing.

        pin_memory = str(lora_device) == "cpu" and is_pin_memory_available()
        if pin_memory:
            for lora in lora_model.loras.values():
                if isinstance(lora.lora_a, list):
                    for index in range(len(lora.lora_a)):
                        if lora.lora_a[index] is None:
                            continue
                        lora.lora_a[index] = lora.lora_a[index].pin_memory()
                        lora.lora_b[index] = lora.lora_b[index].pin_memory()
                else:
                    lora.lora_a = lora.lora_a.pin_memory()
                    lora.lora_b = lora.lora_b.pin_memory()

    def _stack_moe_lora_weights(
        self, lora_model: LoRAModel, module: FusedMoE3DWithLoRA, module_name: str
    ):
        """Split flat MoE LoRA tensors into per-projection lists in place.

        Operates on two entries:
        - the entry for ``module_name``, treated as down_proj;
        - the ``module_name + ".base_layer"`` entry, treated as gate_up_proj.

        ``lora_a`` / ``lora_b`` of the ``module_name`` entry are rewritten in
        place. Nothing happens when that entry is missing or its ``lora_a``
        is not a single tensor (e.g. already a list).

        - If ``self._is_3d_moe_model``: both projections are reshaped per
          expert and sliced to this EP rank's expert range. The entry ends
          up with ``lora_a = [gate_up_a, down_a]`` and
          ``lora_b = [gate_up_b, down_b]``.
        - Otherwise (fallback): the tensors are chunked into ``num_experts``
          pieces and flattened into a per-expert ``[gate, down, up]``
          sequence.

        This method does not remove the ``.base_layer`` entry from
        ``lora_model.loras``.
        """
        module_lora = self._get_lora_layer_weights(lora_model, module_name)

        # Note (gnovack) - If MOE lora weights are not split into
        # num_experts chunks, we split them here
        if module_lora and torch.is_tensor(module_lora.lora_a):
            # Handle PEFT file format where experts.base_layer is the
            # gate_up_proj and experts is the down_proj
            gate_up_proj_lora = self._get_lora_layer_weights(
                lora_model, module_name + ".base_layer"
            )
            # Same object as module_lora; mutations below apply to both names.
            down_proj_lora = module_lora
            # FIXME Edge case where LoRA is not added to gate_up_proj
            # or down_proj
            assert gate_up_proj_lora is not None
            assert down_proj_lora is not None
            if self._is_3d_moe_model:
                local_num_experts = module.w13_lora_a_stacked[0].shape[1]
                # The checkpoint holds weights for all global experts, but
                # each EP rank owns only local_num_experts. Reshape against
                # the adapter's actual expert count, then slice this rank's
                # owned expert range before it gets copied into the local
                # stacked buffer. For non-EP (local == global) this is a
                # no-op slice.
                global_num_experts = module.base_layer.global_num_experts
                ep_rank = module.base_layer.ep_rank
                expert_start = ep_rank * local_num_experts
                expert_end = expert_start + local_num_experts

                # (num_experts,rank,input_size)
                gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape(
                    global_num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
                )[expert_start:expert_end].contiguous()
                down_proj_lora.lora_a = down_proj_lora.lora_a.reshape(
                    global_num_experts, -1, down_proj_lora.lora_a.shape[-1]
                )[expert_start:expert_end].contiguous()

                # (output_size,rank,num_experts)
                # Expert axis is last here, so the EP slice uses `...`.
                gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape(
                    gate_up_proj_lora.lora_b.shape[0], -1, global_num_experts
                )[..., expert_start:expert_end]
                down_proj_lora.lora_b = down_proj_lora.lora_b.reshape(
                    down_proj_lora.lora_b.shape[0], -1, global_num_experts
                )[..., expert_start:expert_end]

                # (num_experts,output_size,rank)
                gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute(
                    2, 0, 1
                ).contiguous()
                down_proj_lora.lora_b = down_proj_lora.lora_b.permute(
                    2, 0, 1
                ).contiguous()

                # Replace the flat tensors with [gate_up, down] lists.
                module_lora.lora_a = [
                    gate_up_proj_lora.lora_a,
                    down_proj_lora.lora_a,
                ]
                module_lora.lora_b = [
                    gate_up_proj_lora.lora_b,
                    down_proj_lora.lora_b,
                ]
            else:
                # Some 3D MoE models haven't added the `is_3d_moe_weight`
                # attribute yet, so fallback here
                # NOTE(review): assumes the down_proj lora_a has
                # rank * num_experts rows — confirm for models on this path.
                num_experts = module_lora.lora_a.shape[0] // module_lora.rank

                # gate and up reuse identical A chunks: both come from
                # chunking the same gate_up lora_a tensor.
                gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
                up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

                # gate_up lora_b rows alternate: even rows -> gate,
                # odd rows -> up.
                gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
                    num_experts, dim=-1
                )
                up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
                    num_experts, dim=-1
                )

                down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
                down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

                # Flatten into a per-expert [gate, down, up] sequence.
                lora_a = []
                lora_b = []
                for i in range(num_experts):
                    lora_a.append(gate_proj_a[i])
                    lora_a.append(down_proj_a[i])
                    lora_a.append(up_proj_a[i])

                    lora_b.append(gate_proj_b[i])
                    lora_b.append(down_proj_b[i])
                    lora_b.append(up_proj_b[i])

                module_lora.lora_a = lora_a
                module_lora.lora_b = lora_b

    def _convert_3d_to_2d_moe_lora(
        self,
        lora_model: LoRAModel,
        module: FusedMoEWithLoRA,
        module_name: str,
    ) -> None:
        """Convert a 3D-format MoE LoRA checkpoint into the 2D pack layout
        that `FusedMoEWithLoRA.set_lora` expects.

        On disk the 3D PEFT layout stores two flat tensor pairs per layer:
          - `{module}.base_layer.lora_{A,B}`: gate_up_proj, with shapes
                `(rank * num_experts, hidden)` / `(intermediate * 2,
                rank * num_experts)`
          - `{module}.lora_{A,B}`: down_proj, with shapes
                `(rank * num_experts, intermediate)` / `(hidden,
                rank * num_experts)`
        The 2D wrapper expects three stacked per-expert tensors,
        `[w1, w2, w3]`, with `(num_experts, rank, in)` for lora_a and
        `(num_experts, out, rank)` for lora_b. In the 3D layout w1
        (gate_proj) and w3 (up_proj) share the rank-r intermediate
        representation, so both halves use the same lora_a tensor.

        Only invoked when `enable_mixed_moe_lora_format=True` and the
        source LoRARequest declares `is_3d_lora_weight=True`.
        """
        gate_up_proj_lora = self._get_lora_layer_weights(
            lora_model, module_name + ".base_layer"
        )
        down_proj_lora = self._get_lora_layer_weights(lora_model, module_name)
        if gate_up_proj_lora is None or down_proj_lora is None:
            # Either the adapter omits the experts entirely or the file
            # layout differs from what this path supports; leave the entry
            # untouched so set_lora can raise a clear error if needed.
            return

        local_num_experts = module.base_layer.local_num_experts
        global_num_experts = module.base_layer.global_num_experts
        ep_rank = module.base_layer.ep_rank
        expert_start = ep_rank * local_num_experts
        expert_end = expert_start + local_num_experts

        # Reshape and EP-slice into per-expert 3D tensors. This mirrors
        # `_stack_moe_lora_weights`; for non-EP runs the slice is a no-op.
        gate_up_a = gate_up_proj_lora.lora_a.reshape(
            global_num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
        )[expert_start:expert_end].contiguous()
        gate_up_b = (
            gate_up_proj_lora.lora_b.reshape(
                gate_up_proj_lora.lora_b.shape[0], -1, global_num_experts
            )[..., expert_start:expert_end]
            .permute(2, 0, 1)
            .contiguous()
        )
        down_a = down_proj_lora.lora_a.reshape(
            global_num_experts, -1, down_proj_lora.lora_a.shape[-1]
        )[expert_start:expert_end].contiguous()
        down_b = (
            down_proj_lora.lora_b.reshape(
                down_proj_lora.lora_b.shape[0], -1, global_num_experts
            )[..., expert_start:expert_end]
            .permute(2, 0, 1)
            .contiguous()
        )

        # Split the fused gate_up_proj output dim into separate w1 / w3
        # halves. GPT-OSS interleaves them along the output dim, all other
        # 3D MoE checkpoints we know about concatenate them.
        intermediate_x2 = gate_up_b.shape[1]
        if intermediate_x2 % 2 != 0:
            raise ValueError(
                "Expected gate_up_proj LoRA-B output dim to be 2 * intermediate, "
                f"got {intermediate_x2}."
            )
        intermediate = intermediate_x2 // 2
        base_arch = self.model.config.architectures[0]
        if base_arch == "GptOssForCausalLM":
            w1_b = gate_up_b[:, ::2, :].contiguous()
            w3_b = gate_up_b[:, 1::2, :].contiguous()
        else:
            w1_b = gate_up_b[:, :intermediate, :].contiguous()
            w3_b = gate_up_b[:, intermediate:, :].contiguous()

        # In the 3D layout w1 and w3 share the same rank-r mid
        # representation, so they reuse the same lora_a tensor. The 2D
        # wrapper's set_lora copies whatever it gets here into independent
        # per-slice buffers, so the sharing is purely a CPU-side memory
        # optimization and does not affect numerics.
        down_proj_lora.lora_a = [gate_up_a, down_a, gate_up_a]
        down_proj_lora.lora_b = [w1_b, down_b, w3_b]
        # Drop the redundant base_layer entry to avoid double pin_memory
        # and to keep the activation path looking up only the wrapper key.
        lora_model.loras.pop(module_name + ".base_layer", None)

    def _slice_moe_lora_ep(
        self,
        lora_model: LoRAModel,
        module: FusedMoEWithLoRA,
        module_name: str,
    ) -> None:
        """Slice the cached LoRA tensors down to this rank's local experts.

        The 2D MoE checkpoint enters as a list of per-(w1/w2/w3) tensors of
        shape (num_experts, rank, in) / (num_experts, out, rank). When EP
        is active each rank only owns local_num_experts; without this slice
        the CPU LoRAModel keeps the full global weight and set_lora has to
        re-slice on every activation.
        """
        if not module.base_layer.use_ep:
            return
        module_lora = self._get_lora_layer_weights(lora_model, module_name)
        if module_lora is None or not isinstance(module_lora.lora_a, list):
            return

        local_num_experts = module.base_layer.local_num_experts
        global_num_experts = module.base_layer.global_num_experts
        ep_rank = module.base_layer.ep_rank
        expert_start = ep_rank * local_num_experts
        expert_end = expert_start + local_num_experts

        new_lora_a: list[torch.Tensor | None] = []
        new_lora_b: list[torch.Tensor | None] = []
        for a, b in zip(module_lora.lora_a, module_lora.lora_b):
            if a is not None and b is not None and a.shape[0] == global_num_experts:
                a = a[expert_start:expert_end].contiguous()
                b = b[expert_start:expert_end].contiguous()
            new_lora_a.append(a)
            new_lora_b.append(b)
        module_lora.lora_a = new_lora_a
        module_lora.lora_b = new_lora_b

    def _get_lora_layer_weights(
        self, lora_model: LoRAModel, module_name: str
    ) -> LoRALayerWeights | None:
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(module_name):
            # If it's a pool model, and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.removeprefix("model.")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pool model, successfully loaded the LoRA weights "
                    "after removing the prefix 'model.'."
                )
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Deactivate ``adapter_id`` if it is currently active.

        Returns:
            True if the adapter was active and has now been deactivated,
            False if it was not active.
        """
        is_active = adapter_id in self._active_adapters
        if is_active:
            self._deactivate_adapter(adapter_id)
            self._active_adapters.pop(adapter_id, None)
        return is_active

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Register ``adapter`` with the manager.

        Returns:
            True if the adapter was newly registered, False if an adapter
            with the same id is already registered.

        Raises:
            RuntimeError: If ``self.capacity`` adapters are already
                registered.
        """
        logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
        already_registered = adapter.id in self._registered_adapters
        if not already_registered:
            if len(self._registered_adapters) >= self.capacity:
                raise RuntimeError("No free adapter slots.")
            self._add_adapter(adapter)
        return not already_registered

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` unless it equals the most recently applied one."""
        if self._last_mapping == mapping:
            return
        self._set_adapter_mapping(mapping)
        self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        """Deactivate (if active) and unregister ``adapter_id``.

        Deactivation is attempted whether or not the adapter is registered.

        Returns:
            True if the adapter was registered and has been removed,
            False otherwise.
        """
        self.deactivate_adapter(adapter_id)
        registered = adapter_id in self._registered_adapters
        if registered:
            self._registered_adapters.pop(adapter_id, None)
        return registered

    def list_adapters(self) -> dict[int, LoRAModel]:
        """Return a shallow copy of the registered adapters, keyed by id."""
        registered = self._registered_adapters
        return dict(registered)

    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        """Return the registered adapter with ``adapter_id``, or None."""
        registered = self._registered_adapters
        return registered.get(adapter_id)

__init__

__init__(
    model: SupportsLoRA,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: device,
    vllm_config: VllmConfig | None = None,
)

Create a LoRAModelManager and adapter for a given model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `SupportsLoRA` | The model to be adapted. | *required* |
| `max_num_seqs` | `int` | The maximum number of sequences the model can run in a single batch. | *required* |
| `max_num_batched_tokens` | `int` | The maximum number of tokens the model can run in a single batch. | *required* |
| `vocab_size` | `int` | The vocab size of the model. | *required* |
| `lora_config` | `LoRAConfig` | The LoRA configuration. | *required* |
Source code in vllm/lora/model_manager.py
def __init__(
    self,
    model: SupportsLoRA,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    vllm_config: VllmConfig | None = None,
):
    """Create a LoRAModelManager and adapter for a given model.

    Args:
        model: the model to be adapted.
        max_num_seqs: the maximum number of sequences model can run in a
            single batch.
        max_num_batched_tokens: the maximum number of tokens model can run
            in a single batch. Stored rounded up to a multiple of 8.
        vocab_size: the vocab size of the model.
        lora_config: the LoRA configuration.
        device: device stored on the manager as ``self.device``.
        vllm_config: optional engine config, forwarded to
            ``_init_punica_wrapper``.

    Raises:
        AssertionError: if the model exposes no supported LoRA modules, or
            if ``self.capacity`` is smaller than ``self.lora_slots``.
    """
    self.model: SupportsLoRA = model
    self.supported_lora_modules = get_supported_lora_modules(self.model)
    assert self.supported_lora_modules, (
        f"No supported LoRA modules found in {self.model.__class__.__name__}."
    )

    self._registered_adapters: dict[int, LoRAModel] = {}
    # Dict instead of a set for compatibility with LRUCache.
    self._active_adapters: dict[int, None] = {}
    self.adapter_type = "LoRA"
    self.lora_config = lora_config
    self.device = device
    self.max_num_seqs = max_num_seqs
    # NOTE(review): `capacity` / `lora_slots` are defined outside this view;
    # presumably derived from `lora_config`, which would explain why this
    # check comes after `self.lora_config` is set — confirm.
    assert self.capacity >= self.lora_slots
    # Round the token budget up to the next multiple of 8.
    self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
    # One entry per LoRA slot, initially None (presumably the adapter id
    # occupying each slot — verify against activation code).
    self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
    self.vocab_size = vocab_size

    self.is_pooling_model = is_pooling_model(self.model)
    # Packed module full name -> qualified sub-module names; written by
    # `_register_packed_modules`.
    self.packed_modules: dict[str, list[str]] = {}
    self.modules: dict[str, BaseLayerWithLoRA] = {}
    # Mapping last applied by `set_adapter_mapping`; lets it skip
    # re-applying an identical mapping.
    self._last_mapping: LoRAMapping | None = None
    is_moe = is_moe_model(self.model)
    # Whether the underlying model class declares 3D fused MoE weights.
    self._model_is_3d_moe = is_moe and self.model.is_3d_moe_weight
    # When the engine is started with enable_mixed_moe_lora_format=True
    # we force the universal 2D wrapper (FusedMoEWithLoRA) regardless of
    # the model's 3D flag, so 2D and 3D adapters can coexist.
    self._enable_mixed_moe_lora_format = (
        is_moe and lora_config.enable_mixed_moe_lora_format
    )
    # The 3D code path is used only when mixed mode is off.
    self._is_3d_moe_model = (
        self._model_is_3d_moe and not self._enable_mixed_moe_lora_format
    )
    self.packed_modules_mapping = process_packed_modules_mapping(
        self.model, force_2d_moe=self._enable_mixed_moe_lora_format
    )
    self._is_non_gated_moe = is_moe and self.model.is_non_gated_moe
    # NOTE(review): this passes the caller's un-rounded value, not
    # `self.max_num_batched_tokens` — confirm that is intended.
    self._init_punica_wrapper(max_num_batched_tokens, vllm_config)
    self._create_lora_modules()

    # Expose this manager on the model object.
    self.model.lora_manager = self

_convert_3d_to_2d_moe_lora

_convert_3d_to_2d_moe_lora(
    lora_model: LoRAModel,
    module: FusedMoEWithLoRA,
    module_name: str,
) -> None

Convert a 3D-format MoE LoRA checkpoint into the 2D pack layout that FusedMoEWithLoRA.set_lora expects.

On disk, the 3D PEFT layout stores two flat tensor pairs per layer:
  • {module}.base_layer.lora_{A,B}: gate_up_proj, with shapes (rank * num_experts, hidden) / (intermediate * 2, rank * num_experts)
  • {module}.lora_{A,B}: down_proj, with shapes (rank * num_experts, intermediate) / (hidden, rank * num_experts)

The 2D wrapper expects three stacked per-expert tensors, [w1, w2, w3], with (num_experts, rank, in) for lora_a and (num_experts, out, rank) for lora_b. In the 3D layout w1 (gate_proj) and w3 (up_proj) share the rank-r intermediate representation, so both halves use the same lora_a tensor.

Only invoked when enable_mixed_moe_lora_format=True and the source LoRARequest declares is_3d_lora_weight=True.

Source code in vllm/lora/model_manager.py
def _convert_3d_to_2d_moe_lora(
    self,
    lora_model: LoRAModel,
    module: FusedMoEWithLoRA,
    module_name: str,
) -> None:
    """Convert a 3D-format MoE LoRA checkpoint into the 2D pack layout
    that `FusedMoEWithLoRA.set_lora` expects.

    On disk the 3D PEFT layout stores two flat tensor pairs per layer:
      - `{module}.base_layer.lora_{A,B}`: gate_up_proj, with shapes
            `(rank * num_experts, hidden)` / `(intermediate * 2,
            rank * num_experts)`
      - `{module}.lora_{A,B}`: down_proj, with shapes
            `(rank * num_experts, intermediate)` / `(hidden,
            rank * num_experts)`
    The 2D wrapper expects three stacked per-expert tensors,
    `[w1, w2, w3]`, with `(num_experts, rank, in)` for lora_a and
    `(num_experts, out, rank)` for lora_b. In the 3D layout w1
    (gate_proj) and w3 (up_proj) share the rank-r intermediate
    representation, so both halves use the same lora_a tensor.

    Only invoked when `enable_mixed_moe_lora_format=True` and the
    source LoRARequest declares `is_3d_lora_weight=True`.
    """
    gate_up_proj_lora = self._get_lora_layer_weights(
        lora_model, module_name + ".base_layer"
    )
    down_proj_lora = self._get_lora_layer_weights(lora_model, module_name)
    if gate_up_proj_lora is None or down_proj_lora is None:
        # Either the adapter omits the experts entirely or the file
        # layout differs from what this path supports; leave the entry
        # untouched so set_lora can raise a clear error if needed.
        return

    local_num_experts = module.base_layer.local_num_experts
    global_num_experts = module.base_layer.global_num_experts
    ep_rank = module.base_layer.ep_rank
    expert_start = ep_rank * local_num_experts
    expert_end = expert_start + local_num_experts

    # Reshape and EP-slice into per-expert 3D tensors. This mirrors
    # `_stack_moe_lora_weights`; for non-EP runs the slice is a no-op.
    gate_up_a = gate_up_proj_lora.lora_a.reshape(
        global_num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
    )[expert_start:expert_end].contiguous()
    gate_up_b = (
        gate_up_proj_lora.lora_b.reshape(
            gate_up_proj_lora.lora_b.shape[0], -1, global_num_experts
        )[..., expert_start:expert_end]
        .permute(2, 0, 1)
        .contiguous()
    )
    down_a = down_proj_lora.lora_a.reshape(
        global_num_experts, -1, down_proj_lora.lora_a.shape[-1]
    )[expert_start:expert_end].contiguous()
    down_b = (
        down_proj_lora.lora_b.reshape(
            down_proj_lora.lora_b.shape[0], -1, global_num_experts
        )[..., expert_start:expert_end]
        .permute(2, 0, 1)
        .contiguous()
    )

    # Split the fused gate_up_proj output dim into separate w1 / w3
    # halves. GPT-OSS interleaves them along the output dim, all other
    # 3D MoE checkpoints we know about concatenate them.
    intermediate_x2 = gate_up_b.shape[1]
    if intermediate_x2 % 2 != 0:
        raise ValueError(
            "Expected gate_up_proj LoRA-B output dim to be 2 * intermediate, "
            f"got {intermediate_x2}."
        )
    intermediate = intermediate_x2 // 2
    base_arch = self.model.config.architectures[0]
    if base_arch == "GptOssForCausalLM":
        w1_b = gate_up_b[:, ::2, :].contiguous()
        w3_b = gate_up_b[:, 1::2, :].contiguous()
    else:
        w1_b = gate_up_b[:, :intermediate, :].contiguous()
        w3_b = gate_up_b[:, intermediate:, :].contiguous()

    # In the 3D layout w1 and w3 share the same rank-r mid
    # representation, so they reuse the same lora_a tensor. The 2D
    # wrapper's set_lora copies whatever it gets here into independent
    # per-slice buffers, so the sharing is purely a CPU-side memory
    # optimization and does not affect numerics.
    down_proj_lora.lora_a = [gate_up_a, down_a, gate_up_a]
    down_proj_lora.lora_b = [w1_b, down_b, w3_b]
    # Drop the redundant base_layer entry to avoid double pin_memory
    # and to keep the activation path looking up only the wrapper key.
    lora_model.loras.pop(module_name + ".base_layer", None)

_get_punica_wrapper

_get_punica_wrapper(
    module_name: str,
) -> PunicaWrapperBase | None

Determine whether this module supports LoRA and which wrapper to use.

Source code in vllm/lora/model_manager.py
def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
    """
    Determine whether this module supports LoRA and which wrapper to use.
    """
    # For language model (early return)
    if not self.supports_mm:
        return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY]

    # For multimodal model
    # NOTE Sort by prefix length (descending) to match the longest prefix first
    # e.g., 'visual.merger' should match 'visual.merger' instead of 'visual.'
    for prefix in sorted(self.punica_wrapper_mapping.keys(), key=len, reverse=True):
        if module_name.startswith(prefix):
            return self.punica_wrapper_mapping[prefix]

    return None

_match_target_modules

_match_target_modules(module_name: str) -> bool

Check if a module should have LoRA applied.

This method first checks if the module is in vLLM's supported LoRA modules, then applies deployment-time restrictions based on LoRAConfig.target_modules.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module_name` | `str` | Full dot-separated module name (e.g., `"model.layers.0.self_attn.o_proj"`). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `bool` | True if LoRA should be applied to this module, False otherwise. |

Source code in vllm/lora/model_manager.py
def _match_target_modules(self, module_name: str) -> bool:
    """Decide whether LoRA should be applied to ``module_name``.

    Two filters are applied in order: the module must be one of vLLM's
    supported LoRA modules (``self.supported_lora_modules``), and it must
    then pass the deployment-time restriction in
    ``LoRAConfig.target_modules`` (packed modules are resolved through
    ``self.packed_modules_mapping``).

    Args:
        module_name: Full dot-separated module name (e.g.,
            "model.layers.0.self_attn.o_proj")

    Returns:
        True if LoRA should be applied to this module, False otherwise.
    """
    if is_supported_lora_module(module_name, self.supported_lora_modules):
        # Supported by vLLM; the user-configured target list has the final say.
        return is_in_target_modules(
            module_name,
            self.lora_config.target_modules,
            self.packed_modules_mapping,
        )
    return False

_pad_lora_pairs_to_triplets staticmethod

_pad_lora_pairs_to_triplets(
    loras: list[LoRALayerWeights | None],
) -> list[LoRALayerWeights | None]

Pad LoRA weight pairs to triplets for non-gated MoE.

For non-gated MoE, each expert has 2 entries (w1, w2) that need to be padded to triplets (w1, w2, None) to match pack_moe expectations.

Source code in vllm/lora/model_manager.py
@staticmethod
def _pad_lora_pairs_to_triplets(
    loras: list[LoRALayerWeights | None],
) -> list[LoRALayerWeights | None]:
    """Turn per-expert (w1, w2) pairs into (w1, w2, None) triplets.

    Non-gated MoE experts carry only two weights each, while pack_moe
    expects three entries per expert; a None placeholder stands in for w3.

    Args:
        loras: Flat list laid out as [w1_e0, w2_e0, w1_e1, w2_e1, ...].

    Returns:
        A new list laid out as [w1_e0, w2_e0, None, w1_e1, w2_e1, None, ...].
    """
    assert len(loras) % 2 == 0, "Expected pairs of LoRA weights for non-gated MoE."
    # Even positions hold w1, odd positions hold w2.
    return [
        entry
        for w1, w2 in zip(loras[0::2], loras[1::2])
        for entry in (w1, w2, None)
    ]

_slice_moe_lora_ep

_slice_moe_lora_ep(
    lora_model: LoRAModel,
    module: FusedMoEWithLoRA,
    module_name: str,
) -> None

Slice the cached LoRA tensors down to this rank's local experts.

The 2D MoE checkpoint enters as a list of per-(w1/w2/w3) tensors of shape (num_experts, rank, in) / (num_experts, out, rank). When EP is active each rank only owns local_num_experts; without this slice the CPU LoRAModel keeps the full global weight and set_lora has to re-slice on every activation.

Source code in vllm/lora/model_manager.py
def _slice_moe_lora_ep(
    self,
    lora_model: LoRAModel,
    module: FusedMoEWithLoRA,
    module_name: str,
) -> None:
    """Trim this module's cached MoE LoRA tensors to the local experts.

    Does nothing unless the base layer uses expert parallelism and the
    module's LoRA weights are stored as per-(w1/w2/w3) lists. Otherwise,
    every (lora_a, lora_b) pair whose leading dimension equals
    ``global_num_experts`` is replaced by a contiguous copy of the rows
    ``[ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts)``.
    Pairs with a None side or a different leading size are kept unchanged,
    so calling this twice does not slice twice. Without this, the CPU
    LoRAModel would keep the full global weight and set_lora would have to
    re-slice on every activation.

    NOTE(review): the slice assumes each EP rank owns one contiguous block
    of exactly ``local_num_experts`` experts starting at
    ``ep_rank * local_num_experts``; confirm this matches the base layer's
    expert placement (e.g. uneven splits or non-linear placement).

    Args:
        lora_model: Model whose cached weights are modified in place.
        module: The FusedMoE LoRA wrapper for ``module_name``.
        module_name: Name used to look up the weights in ``lora_model``.
    """
    base = module.base_layer
    if not base.use_ep:
        return
    weights = self._get_lora_layer_weights(lora_model, module_name)
    if weights is None or not isinstance(weights.lora_a, list):
        return

    first_expert = base.ep_rank * base.local_num_experts
    end_expert = first_expert + base.local_num_experts
    expected_rows = base.global_num_experts

    pairs = [
        (
            (a[first_expert:end_expert].contiguous(),
             b[first_expert:end_expert].contiguous())
            if a is not None and b is not None and a.shape[0] == expected_rows
            else (a, b)
        )
        for a, b in zip(weights.lora_a, weights.lora_b)
    ]
    weights.lora_a = [a for a, _ in pairs]
    weights.lora_b = [b for _, b in pairs]

activate_adapter

activate_adapter(lora_id: int) -> bool

Move LoRA into a GPU buffer to be used in the forward pass.

Source code in vllm/lora/model_manager.py
def activate_adapter(
    self,
    lora_id: int,
) -> bool:
    """Move LoRA into a GPU buffer to be used in the forward pass.

    The adapter is assigned to the first free (None) entry of
    ``self.lora_index_to_id``. Each module then receives the adapter's
    weights for that slot via ``set_lora``, or ``reset_lora`` when the
    adapter has no weights for that module.

    Args:
        lora_id: Id of a registered adapter.

    Returns:
        False if the adapter is already active (nothing is done), True
        after it has been loaded into a slot.

    Raises:
        ValueError: If no slot in ``self.lora_index_to_id`` is free.
        Whatever ``self._registered_adapters[lora_id]`` raises for an
        unknown id (KeyError for a plain dict). In that case neither
        ``_active_adapters`` nor the slot table is modified.
    """
    if lora_id in self._active_adapters:
        return False
    index = next(
        (i for i, slot_id in enumerate(self.lora_index_to_id) if slot_id is None),
        None,
    )
    if index is None:
        raise ValueError("No free lora slots")
    # Resolve the adapter *before* marking it active. Otherwise an unknown
    # id would leave a dangling entry in _active_adapters with no slot, and
    # later activations of that id would silently return False.
    lora_model = self._registered_adapters[lora_id]
    self._active_adapters[lora_id] = None
    logger.debug(
        "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
    )
    self.lora_index_to_id[index] = lora_model.id
    for module_name, module in self.modules.items():
        module_lora = self._get_lora_layer_weights(lora_model, module_name)
        if not module_lora:
            # Clear the slot so a previous occupant's weights are not reused.
            module.reset_lora(index)
            logger.debug(
                "No LoRA weights found for module %s, skipping.", module_name
            )
            continue

        module.set_lora(
            index,
            module_lora.lora_a,
            module_lora.lora_b,
        )
        logger.debug("Successfully loaded LoRA weights for module %s.", module_name)
    return True

create_dummy_lora

create_dummy_lora(
    lora_id: int,
    rank: int,
    embedding_modules: dict[str, str] | None = None,
) -> LoRAModel

Create zero-initialized LoRAModel for warmup.

Source code in vllm/lora/model_manager.py
def create_dummy_lora(
    self,
    lora_id: int,
    rank: int,
    embedding_modules: dict[str, str] | None = None,
) -> LoRAModel:
    """Create zero-initialized LoRAModel for warmup.

    Builds ``LoRAModel(lora_id, rank, {})`` and adds one dummy weight entry
    for every module of ``self.model`` that passes
    ``_match_target_modules``, is a ``BaseLayerWithLoRA``, and has a punica
    wrapper. Shapes and dtype come from the module's stacked LoRA buffers,
    or from base-layer attributes for non-lm_head embedding modules. Every
    dummy weight is requested with device ``"cpu"``.

    Args:
        lora_id: Id given to the returned LoRAModel.
        rank: LoRA rank of the dummy weights. For FusedMoE3DWithLoRA the
            rank passed to the weight factory is multiplied by
            ``*_lora_a_stacked[0].shape[1]``.
        embedding_modules: Mapping whose keys are matched against the last
            component of module names to detect embedding-like modules.
            Must not be None whenever a non-packed module is visited; this
            is enforced with ``assert``.

    Returns:
        The populated LoRAModel.
    """
    model = LoRAModel(lora_id, rank, {})
    for module_name, module in self.model.named_modules():
        # Only targeted LoRA wrappers that have a punica wrapper get entries.
        if (
            not self._match_target_modules(module_name)
            or not isinstance(module, BaseLayerWithLoRA)
            or self._get_punica_wrapper(module_name) is None
        ):
            continue
        parts = module_name.split(".")
        if module_name not in self.packed_modules:
            # Checked for every non-packed module (not only embeddings), so
            # callers must pass embedding_modules whenever one is present.
            assert embedding_modules is not None
            if parts[-1] in embedding_modules:
                # Special-case lm_head: wrapped by LogitsProcessorWithLoRA.
                # LoRA input dim is hidden_size, output dim is vocab size.
                # LogitsProcessorWithLoRA handles extra vocab size directly.
                if parts[-1] == "lm_head":
                    input_dim = module.lora_a_stacked[0].shape[-1]
                    output_dim = module.lora_b_stacked[0].shape[-2]
                else:
                    # Prefer explicit size attributes on the base layer and
                    # fall back to its weight shape when they are absent.
                    input_dim = (
                        module.base_layer.org_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    input_dim,
                    output_dim,
                    rank,
                    module.lora_a_stacked[0].dtype,
                    "cpu",
                )
                model.loras[module_name] = lora
            elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                # Case for 3D moe model
                # Two entries are stored: w2 under the module name and w13
                # under "<module_name>.base_layer".
                # w2
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    module.w2_input_size,
                    module.w2_output_size,
                    rank * module.w2_lora_a_stacked[0].shape[1],  # rank*num_experts
                    module.w2_lora_a_stacked[0].dtype,
                    "cpu",
                )
                model.loras[module_name] = lora
                # w13
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    module.w13_input_size,
                    module.w13_output_size,
                    rank
                    * module.w13_lora_a_stacked[0].shape[1],  # rank*num_experts
                    module.w13_lora_a_stacked[0].dtype,
                    "cpu",
                )
                model.loras[module_name + ".base_layer"] = lora
            else:
                # Plain wrapper: dims come from the first stacked buffers.
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    module.lora_a_stacked[0].shape[-1],
                    module.lora_b_stacked[0].shape[-2],
                    rank,
                    module.lora_a_stacked[0].dtype,
                    "cpu",
                )
                model.loras[module_name] = lora
        else:
            # Packed module: build one dummy per slice, then pack them into
            # a single entry stored under module_name.
            # NOTE: recomputes the same value as `parts` above.
            parts = module_name.split(".")
            replacements = self.packed_modules_mapping[parts[-1]]
            # NOTE(review): n_slices is read before the FusedMoEWithLoRA
            # truncation below. If the module has no `n_slices` attribute and
            # the truncation shortens `replacements`, the HACK check further
            # down regenerates the untruncated count of slice names. Confirm
            # that FusedMoEWithLoRA defines n_slices.
            n_slices = getattr(module, "n_slices", len(replacements))
            if module.__class__.__name__ == "FusedMoEWithLoRA":
                # Keep only as many replacements as lora_a_stacked holds per
                # LoRA slot.
                replacements = replacements[
                    : len(module.lora_a_stacked) // self.lora_slots
                ]
            subloras: list[LoRALayerWeights | None] = []
            # HACK: overrides replacements for qkvz = qkv + z case.
            # Any better methods to handle this case?
            if n_slices != len(replacements):
                replacements = [f"slice_{i}" for i in range(n_slices)]
            for i, r in enumerate(replacements):
                # Slice i takes its dims/dtype from the i-th stacked buffers.
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name + "." + r,
                    module.lora_a_stacked[i].shape[-1],
                    module.lora_b_stacked[i].shape[-2],
                    rank,
                    module.lora_a_stacked[i].dtype,
                    "cpu",
                )
                subloras.append(lora)
            if module.__class__.__name__ == "FusedMoEWithLoRA":
                # For non-gated MoE, pad subloras to 3 elements per expert
                # to match pack_moe expectations (w1, w2, None for w3)
                if self._is_non_gated_moe and len(subloras) > 0:
                    subloras = self._pad_lora_pairs_to_triplets(subloras)
                lora = PackedLoRALayerWeights.pack_moe(
                    subloras, module_name, is_non_gated_moe=self._is_non_gated_moe
                )
            else:
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
    return model

get_dummy_lora_warmup_rank

get_dummy_lora_warmup_rank(default_rank: int) -> int

Return a dummy LoRA rank compatible with wrapped modules.

Dummy LoRAs keep warmup memory low by using a small rank. Fully sharded MoE wrappers additionally require the dummy rank to be divisible by tensor parallel size because they shard W13 along the rank axis.

Source code in vllm/lora/model_manager.py
def get_dummy_lora_warmup_rank(self, default_rank: int) -> int:
    """Pick the rank to use for dummy (warmup) LoRAs.

    Warmup uses a small rank to keep memory low. When fully sharded LoRAs
    are enabled, every wrapper flagged ``fully_sharded`` shards W13 along
    the rank axis. The rank must therefore be divisible by each such
    wrapper's ``tp_size``, so ``default_rank`` is rounded up to the least
    common multiple of those sizes.

    Args:
        default_rank: Preferred dummy rank.

    Returns:
        ``default_rank`` if it already satisfies the constraint (or no
        constraint applies), otherwise the next multiple of the required
        LCM.

    Raises:
        ValueError: If the rounded-up rank exceeds
            ``lora_config.max_lora_rank``.
    """
    if not self.lora_config.fully_sharded_loras:
        return default_rank

    # math.lcm() with no arguments is 1, which means "no constraint".
    multiple = math.lcm(
        *(
            module.tp_size
            for module in self.modules.values()
            if getattr(module, "fully_sharded", False)
        )
    )
    if multiple == 1 or default_rank % multiple == 0:
        return default_rank

    # Ceil-divide via negated floor division, then scale back up.
    rounded_rank = -(-default_rank // multiple) * multiple
    max_rank = self.lora_config.max_lora_rank
    if rounded_rank <= max_rank:
        return rounded_rank
    raise ValueError(
        "Unable to choose a dummy LoRA warmup rank compatible with "
        "fully sharded MoE modules: "
        f"default_rank={default_rank}, "
        f"required_multiple={multiple}, "
        f"max_lora_rank={max_rank}"
    )

pin_adapter

pin_adapter(lora_id: int) -> bool

Pin a LoRAModel in the manager cache.

Source code in vllm/lora/model_manager.py
def pin_adapter(self, lora_id: int) -> bool:
    """Pin a LoRAModel in the manager cache.

    Not available on this manager: the call always raises. Pinning is
    offered by ``LRUCacheLoRAModelManager``.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Pinning is not supported in LoRAModelManager. "
        "Use LRUCacheLoRAModelManager for pinning"
    )
    raise NotImplementedError(message)

remove_all_adapters

remove_all_adapters()

Remove all LoRAModels from the manager.

Source code in vllm/lora/model_manager.py
def remove_all_adapters(self):
    """Forget every registered and active LoRAModel.

    Empties the registry, marks all ``self.lora_slots`` slots as free
    (None), and empties the active set, in that order. Module weight
    buffers are not touched here; only the bookkeeping is reset.
    """
    self._registered_adapters.clear()
    self.lora_index_to_id = [None for _ in range(self.lora_slots)]
    self._active_adapters.clear()

create_lora_manager

create_lora_manager(
    model: Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    vllm_config: VllmConfig,
    device: device,
    lora_manager_cls: type[
        LoRAModelManager
    ] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager

Create a LoRA model manager for a given model.

Source code in vllm/lora/model_manager.py
def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    vllm_config: VllmConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA model manager for a given model.

    Args:
        model: The model to manage; must implement ``SupportsLoRA``.
        max_num_seqs: Forwarded to the manager constructor.
        max_num_batched_tokens: Forwarded to the manager constructor.
        vocab_size: Forwarded to the manager constructor.
        lora_config: Forwarded to the manager constructor.
        vllm_config: Forwarded to the manager constructor.
        device: Forwarded to the manager constructor.
        lora_manager_cls: Manager class to instantiate (defaults to
            ``LoRAModelManager``).
        **kwargs: Extra keyword arguments forwarded to ``lora_manager_cls``.

    Returns:
        A new ``lora_manager_cls`` instance.

    Raises:
        ValueError: If ``model`` is not an instance of ``SupportsLoRA``.
    """
    if isinstance(model, SupportsLoRA):
        return lora_manager_cls(
            model=model,
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            vocab_size=vocab_size,
            lora_config=lora_config,
            vllm_config=vllm_config,
            device=device,
            **kwargs,
        )
    raise ValueError(f"Model {type(model)} is not supported for LoRA.")