Skip to content

RLHF IPC Fsdp Ep

Source: https://github.com/vllm-project/vllm/blob/main/examples/rl/rlhf_ipc_fsdp_ep.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
RLHF with FSDP2 training and vLLM expert-parallel inference using **CUDA IPC**
weight transfer and **packed** tensors.

Layout (4 GPUs, TP=1, DP=4, EP):
  * One Ray placement group per GPU.
  * Each PG holds one FSDP training worker and one vLLM ``LLM`` instance
    (sync API) using fractional GPUs so both fit on the same device.
  * The 4 ``LLM`` instances form a DP group via env-var-based SPMD
    coordination (``VLLM_DP_RANK``, ``VLLM_DP_SIZE``, etc.), the same
    mechanism used by ``examples/offline_inference/data_parallel.py``.
  * A ``DataParallelInferenceEngine`` actor spawns all 4 LLM actors,
    waits for initialization, and orchestrates generation / weight-sync.

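Uses the built-in ``ray`` send_mode: every FSDP worker calls
``trainer_send_weights`` with the handles of all LLM actors. All ranks create
CUDA IPC handles, the handles are all-gathered across the FSDP group, and
rank 0 sends the merged payload to vLLM.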

This example was run on 4xH100.
"""

from __future__ import annotations

import os
from dataclasses import asdict

import ray
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp import fully_shard
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.ipc_engine import (
    IPCTrainerSendWeightsArgs,
    IPCWeightTransferEngine,
    IPCWeightTransferInitInfo,
)
from vllm.utils.network_utils import get_ip, get_open_port

def _gpu_fraction_from_env(var_name: str, default: str) -> float:
    """Return the float stored in env var ``var_name``, or ``default`` if unset."""
    return float(os.environ.get(var_name, default))


# Per-bundle GPU shares for the colocated trainer and vLLM engine; both can be
# overridden from the environment. main() checks they sum to at most 1.0.
TRAIN_GPU_FRACTION = _gpu_fraction_from_env("RLHF_IPC_TRAIN_GPU_FRACTION", "0.42")
VLLM_GPU_FRACTION = _gpu_fraction_from_env("RLHF_IPC_VLLM_GPU_FRACTION", "0.42")

MODEL_NAME = "Qwen/Qwen3-30B-A3B"

# Parallel layout: one FSDP rank and one single-GPU vLLM DP rank per GPU.
FSDP_WORLD_SIZE = 4
INFERENCE_TP_SIZE = 1
INFERENCE_DP_SIZE = 4


class MyLLM(LLM):
    """``LLM`` that joins a vLLM data-parallel group through environment variables.

    The ``dp_*`` keyword arguments are exported as ``VLLM_DP_*`` variables
    before the base constructor runs; every other argument is forwarded to
    ``LLM`` unchanged.
    """

    def __init__(
        self,
        *args,
        dp_rank: int = 0,
        dp_size: int = 1,
        dp_master_ip: str = "127.0.0.1",
        dp_master_port: int = 0,
        **kwargs,
    ):
        # NOTE(review): presumably cleared so vLLM's Ray executor controls
        # device visibility itself rather than inheriting this actor's mask.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)

        # Executor placement settings followed by the SPMD DP coordination
        # settings. The local DP rank is set equal to the global DP rank,
        # which presumes a single-node run — TODO confirm for multi-node.
        env_overrides = {
            "VLLM_RAY_PER_WORKER_GPUS": str(VLLM_GPU_FRACTION),
            "VLLM_RAY_BUNDLE_INDICES": "0",
            "VLLM_ALLOW_INSECURE_SERIALIZATION": "1",
            "VLLM_DP_RANK": str(dp_rank),
            "VLLM_DP_RANK_LOCAL": str(dp_rank),
            "VLLM_DP_SIZE": str(dp_size),
            "VLLM_DP_MASTER_IP": dp_master_ip,
            "VLLM_DP_MASTER_PORT": str(dp_master_port),
        }
        os.environ.update(env_overrides)

        super().__init__(*args, **kwargs)

    def ready(self):
        """Readiness probe; returns ``True`` once construction has finished."""
        return True


def _unfuse_experts(name, tensor):
    """Yield ``(name, tensor)`` pairs for one state-dict entry in vLLM's layout.

    HF's Qwen3MoeExperts (and other recent HF MoE impls) packs all experts
    into two fused 3-D tensors per layer:
      experts.gate_up_proj  shape (E, 2*I, H)
      experts.down_proj     shape (E, H, I)
    vLLM's Qwen3MoE load_weights still expects the older per-expert HF layout
    (experts.<i>.gate_proj.weight, experts.<i>.up_proj.weight,
    experts.<i>.down_proj.weight), so those tensors are split per expert.
    The split order matches HF's forward:
      gate, up = linear(x, gate_up_proj[i]).chunk(2, dim=-1)
    -> rows [:I] of gate_up_proj[i] are gate, rows [I:] are up.
    Every other entry is yielded unchanged.
    """
    if name.endswith(".experts.gate_up_proj") and tensor.dim() == 3:
        prefix = name[: -len(".gate_up_proj")]
        num_experts, two_inter, _ = tensor.shape
        inter = two_inter // 2
        for i in range(num_experts):
            expert = tensor[i]
            yield f"{prefix}.{i}.gate_proj.weight", expert[:inter].contiguous()
            yield f"{prefix}.{i}.up_proj.weight", expert[inter:].contiguous()
    elif name.endswith(".experts.down_proj") and tensor.dim() == 3:
        prefix = name[: -len(".down_proj")]
        for i in range(tensor.shape[0]):
            yield f"{prefix}.{i}.down_proj.weight", tensor[i].contiguous()
    else:
        yield name, tensor


@ray.remote(num_cpus=0, num_gpus=TRAIN_GPU_FRACTION)
class FSDPTrainWorker:
    """One FSDP2 worker per GPU; colocated with vLLM DP rank via placement group.

    The worker joins an NCCL process group of ``fsdp_world_size`` ranks, loads
    the HF checkpoint in bfloat16, and applies FSDP2 ``fully_shard`` to every
    decoder layer and then to the root module.
    """

    def __init__(
        self,
        model_name: str,
        rank: int,
        fsdp_world_size: int,
        fsdp_master_addr: str,
        fsdp_master_port: int,
    ):
        """Join the FSDP process group and build the sharded model.

        Args:
            model_name: HF model id or local checkpoint path.
            rank: This worker's rank in the FSDP process group.
            fsdp_world_size: Total number of FSDP ranks.
            fsdp_master_addr: Rendezvous address for ``init_process_group``.
            fsdp_master_port: Rendezvous port for ``init_process_group``.
        """
        self.rank = rank

        os.environ["MASTER_ADDR"] = fsdp_master_addr
        os.environ["MASTER_PORT"] = str(fsdp_master_port)

        # Bind the device before creating the NCCL group so all CUDA/NCCL
        # state is created on it. Index 0 assumes Ray exposes exactly one GPU
        # to this actor — TODO confirm if the placement layout changes.
        torch.accelerator.set_device_index(0)
        dist.init_process_group(backend="nccl", rank=rank, world_size=fsdp_world_size)

        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        )

        # Capture parameter metadata from the unsharded HF module in a single
        # pass. These are HF names/shapes (fused expert tensors stay fused).
        self.weight_names = []
        self.weight_dtype_names = []
        self.weight_shapes = []
        for name, param in model.named_parameters():
            self.weight_names.append(name)
            self.weight_dtype_names.append(str(param.dtype).split(".")[-1])
            self.weight_shapes.append(list(param.shape))

        for layer in model.model.layers:
            fully_shard(layer)
        fully_shard(model)

        self.model = model

    def get_rank(self):
        """Return this worker's FSDP rank (also used as a readiness probe)."""
        return self.rank

    def get_weight_metadata(self):
        """Return ``(names, dtype_names, shapes)`` captured before sharding.

        These describe the HF parameter layout. Fused MoE expert tensors
        appear here in fused form, even though
        ``gather_and_broadcast_weights_ipc`` sends them un-fused per expert.
        """
        return self.weight_names, self.weight_dtype_names, self.weight_shapes

    def gather_and_broadcast_weights_ipc(
        self,
        llm_handle,
        packed: bool = True,
        packed_buffer_size_bytes: int = 1024 * 1024 * 1024,
    ):
        """All-gather full params; all ranks create IPC handles, rank 0 sends.

        All ranks must call trainer_send_weights so they participate in the
        all_gather_object collective inside _all_gather_and_merge_handles.
        Only rank 0 actually sends the payload to vLLM (gated by _is_rank_zero).

        Args:
            llm_handle: LLM actor handle(s) forwarded to the IPC engine.
            packed: Whether to pack tensors into shared buffers for transfer.
            packed_buffer_size_bytes: Packing buffer size (default 1 GiB).
        """

        def _full_param_iter():
            # Materialize one full (unsharded) tensor at a time and drop the
            # state-dict entry immediately to keep peak memory bounded.
            params = self.model.state_dict()
            for name in list(params.keys()):
                param = params.pop(name)
                if isinstance(param, DTensor):
                    tensor = param.full_tensor().detach().contiguous()
                else:
                    tensor = param.detach().contiguous()
                del param
                yield from _unfuse_experts(name, tensor)
                del tensor

        trainer_args = IPCTrainerSendWeightsArgs(
            send_mode="ray",
            llm_handle=llm_handle,
            packed=packed,
            packed_buffer_size_bytes=packed_buffer_size_bytes,
        )
        IPCWeightTransferEngine.trainer_send_weights(
            iterator=_full_param_iter(),
            trainer_args=trainer_args,
        )


@ray.remote(num_cpus=1)
class DataParallelInferenceEngine:
    """Manages a pool of DP-sharded vLLM LLM actors.

    Spawns one MyLLM actor per placement group (DP rank ``r`` goes to
    ``pgs[r]``), waits for all engines to finish initializing, and exposes
    fan-out helpers for generation and weight synchronization.
    """

    def __init__(
        self,
        model: str,
        pgs: list,
        dp_master_ip: str,
        dp_master_port: int,
    ):
        """Create and initialize one LLM actor per placement group.

        Args:
            model: Model id or local path passed to every ``MyLLM``.
            pgs: Placement groups; their count is the DP size.
            dp_master_ip: Address of the DP coordination master.
            dp_master_port: Port of the DP coordination master.
        """
        dp_size = len(pgs)
        # Wrap MyLLM as a Ray actor class once rather than once per rank.
        llm_actor_cls = ray.remote(num_cpus=0, num_gpus=0)(MyLLM)
        self.llm_actors = []
        for r, pg in enumerate(pgs):
            sched = PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
            )
            actor = llm_actor_cls.options(scheduling_strategy=sched).remote(
                model=model,
                enforce_eager=True,
                tensor_parallel_size=INFERENCE_TP_SIZE,
                distributed_executor_backend="ray",
                enable_expert_parallel=True,
                gpu_memory_utilization=0.35,
                weight_transfer_config=WeightTransferConfig(backend="ipc"),
                enable_sleep_mode=True,
                load_format="dummy",
                dp_rank=r,
                dp_size=dp_size,
                dp_master_ip=dp_master_ip,
                dp_master_port=dp_master_port,
            )
            self.llm_actors.append(actor)

        # ``ready`` runs only after each actor's constructor has returned, so
        # this blocks until every engine is initialized.
        ray.get([actor.ready.remote() for actor in self.llm_actors])

    def _call_all(self, method_name: str, *args, **kwargs) -> list:
        """Invoke ``method_name`` on every LLM actor concurrently; wait for all."""
        return ray.get(
            [
                getattr(actor, method_name).remote(*args, **kwargs)
                for actor in self.llm_actors
            ]
        )

    def get_llm_actors(self):
        """Return the list of LLM actor handles, in DP-rank order."""
        return self.llm_actors

    def generate(self, prompts: list[str], sampling_params):
        """Distribute prompts round-robin across DP ranks and collect results.

        Prompt ``i`` is handled by rank ``i % dp_size``; outputs are returned
        in the same order as ``prompts``.

        Raises:
            RuntimeError: if prompts are given but there are no LLM actors.
        """
        num_prompts = len(prompts)
        if num_prompts == 0:
            return []
        if not self.llm_actors:
            raise RuntimeError("DataParallelInferenceEngine has no LLM actors")
        dp_size = len(self.llm_actors)

        # (actor, original prompt indices) for every rank that received work.
        work = []
        for r, actor in enumerate(self.llm_actors):
            idx = list(range(r, num_prompts, dp_size))
            if idx:
                work.append((actor, idx))

        refs = [
            actor.generate.remote([prompts[i] for i in idx], sampling_params)
            for actor, idx in work
        ]
        results = ray.get(refs)

        ordered = [None] * num_prompts
        for (_, idx), outputs in zip(work, results):
            for local_i, orig_i in enumerate(idx):
                ordered[orig_i] = outputs[local_i]
        return ordered

    def init_weight_transfer(self):
        """Initialize the IPC weight-transfer engine on every LLM actor."""
        self._call_all(
            "init_weight_transfer_engine",
            dict(init_info=asdict(IPCWeightTransferInitInfo())),
        )

    def start_weight_update(self, is_checkpoint_format: bool = True):
        """Begin a weight update on every LLM actor."""
        self._call_all(
            "start_weight_update", is_checkpoint_format=is_checkpoint_format
        )

    def finish_weight_update(self):
        """Finalize the in-progress weight update on every LLM actor."""
        self._call_all("finish_weight_update")

    def sleep(self, level: int = 0):
        """Put every LLM actor to sleep at the given ``level``."""
        self._call_all("sleep", level=level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake every LLM actor, optionally restricted to ``tags``."""
        self._call_all("wake_up", tags=tags)


def _print_outputs(title: str, outputs) -> None:
    """Print each prompt and its first completion under a ``title`` banner."""
    separator = "-" * 60
    print(separator)
    print(title)
    print(separator)
    for output in outputs:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated: {output.outputs[0].text!r}")
        print(separator)


def main():
    """Run the demo end to end.

    Generates with dummy weights, syncs the real weights from the FSDP workers
    into vLLM via packed CUDA IPC, then generates again.

    Raises:
        ValueError: if the GPU fractions or parallel sizes are inconsistent.
    """
    # Validate static configuration before starting Ray. Explicit raises are
    # used because ``assert`` is stripped under ``python -O``; ``not (<= )``
    # keeps NaN fractions rejected as well.
    if not (TRAIN_GPU_FRACTION + VLLM_GPU_FRACTION <= 1.0):
        raise ValueError(
            "Train + vLLM GPU fractions must sum to at most 1.0 per bundle."
        )
    # FSDP worker r is scheduled onto placement group r, and each PG hosts one
    # trainer plus one LLM actor, so the two sizes must match.
    if FSDP_WORLD_SIZE != INFERENCE_DP_SIZE:
        raise ValueError(
            f"FSDP_WORLD_SIZE ({FSDP_WORLD_SIZE}) must equal "
            f"INFERENCE_DP_SIZE ({INFERENCE_DP_SIZE}): each placement group "
            "colocates one FSDP worker with one vLLM DP rank."
        )

    ray.init(
        runtime_env={
            "env_vars": {
                "VLLM_ALLOW_INSECURE_SERIALIZATION": "1",
            }
        }
    )

    local_model_path = snapshot_download(MODEL_NAME)
    print(f"[init] Model downloaded to {local_model_path}")

    fsdp_master_addr = get_ip()
    fsdp_master_port = get_open_port()
    dp_master_port = get_open_port()
    dp_master_ip = get_ip()

    # Create one placement group per DP rank (one GPU each).
    pgs = [placement_group([{"GPU": 1, "CPU": 1}]) for _ in range(INFERENCE_DP_SIZE)]
    ray.get([pg.ready() for pg in pgs])
    print(f"[init] {len(pgs)} placement groups ready.")

    # Launch FSDP training workers, one per PG.
    scheduling = [
        PlacementGroupSchedulingStrategy(
            placement_group=pgs[r],
            placement_group_capture_child_tasks=True,
        )
        for r in range(FSDP_WORLD_SIZE)
    ]

    fsdp_workers = [
        FSDPTrainWorker.options(scheduling_strategy=scheduling[r]).remote(
            local_model_path,
            r,
            FSDP_WORLD_SIZE,
            fsdp_master_addr,
            fsdp_master_port,
        )
        for r in range(FSDP_WORLD_SIZE)
    ]
    ray.get([w.get_rank.remote() for w in fsdp_workers])
    print(f"[init] {FSDP_WORLD_SIZE} FSDP workers ready.")

    # Launch DP inference engine (spawns and initializes all LLM actors).
    inference_engine = DataParallelInferenceEngine.remote(
        model=local_model_path,
        pgs=pgs,
        dp_master_ip=dp_master_ip,
        dp_master_port=dp_master_port,
    )
    llm_actors = ray.get(inference_engine.get_llm_actors.remote())
    print(f"[init] {INFERENCE_DP_SIZE} LLM actors ready.")

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)

    print("[generate] Generating with dummy weights...")
    outputs = ray.get(inference_engine.generate.remote(prompts, sampling_params))
    _print_outputs("BEFORE weight sync (dummy weights):", outputs)

    # --- Weight transfer ---
    print("[transfer] Initializing IPC weight transfer...")
    ray.get(inference_engine.init_weight_transfer.remote())

    # Two-phase sleep/wake pattern:
    # 1. sleep(level=1) — offload weights to CPU, discard KV cache
    # 2. wake_up(tags=["weights"]) — bring weights back to GPU (KV cache still free)
    # 3. IPC weight transfer — overwrite weights, plenty of room without KV cache
    # 4. wake_up(tags=["kv_cache"]) — re-allocate KV cache for inference
    print("[sync] Sleeping engines (offload weights + free KV cache)...")
    ray.get(inference_engine.sleep.remote(level=1))

    print("[sync] Waking weights (KV cache stays free)...")
    ray.get(inference_engine.wake_up.remote(tags=["weights"]))

    print("[sync] Starting weight update...")
    ray.get(inference_engine.start_weight_update.remote(is_checkpoint_format=True))

    print("[sync] Packed IPC transfer FSDP → vLLM...")
    # Every FSDP rank must participate: the IPC handle all-gather is a
    # collective, even though only rank 0 sends the payload.
    ray.get(
        [
            w.gather_and_broadcast_weights_ipc.remote(llm_actors, packed=True)
            for w in fsdp_workers
        ]
    )

    ray.get(inference_engine.finish_weight_update.remote())
    print("[sync] Weight transfer complete.")

    print("[sync] Waking KV cache + scheduling...")
    ray.get(inference_engine.wake_up.remote(tags=["kv_cache", "scheduling"]))

    print("[generate] Generating with synced weights...")
    outputs_updated = ray.get(
        inference_engine.generate.remote(prompts, sampling_params)
    )
    _print_outputs("AFTER weight sync (real weights):", outputs_updated)


if __name__ == "__main__":
    main()