class XPUWorker(Worker):
"""A XPU worker class."""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
super().__init__(
vllm_config, local_rank, rank, distributed_init_method, is_driver_worker
)
device_config = self.device_config
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()
def init_device(self):
# In DP mode, XPU workers see all visible devices.
# Offset local_rank by the local DP shard.
parallel_config = self.parallel_config
if (
parallel_config.distributed_executor_backend
not in ("ray", "external_launcher")
and parallel_config.data_parallel_backend != "ray"
and parallel_config.nnodes_within_dp == 1
):
dp_local_rank = parallel_config.data_parallel_rank_local
if dp_local_rank is None:
dp_local_rank = parallel_config.data_parallel_index
tp_pp_world_size = (
parallel_config.pipeline_parallel_size
* parallel_config.tensor_parallel_size
)
self.local_rank += dp_local_rank * tp_pp_world_size
visible_device_count = torch.accelerator.device_count()
assert self.local_rank < visible_device_count, (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
assert parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({parallel_config.local_world_size}) must "
f"be less than or equal to the number of visible devices "
f"({visible_device_count})."
)
device = self.device_config.device
if (
isinstance(device, torch.device)
and device.type == "xpu"
and current_platform.is_xpu()
):
self.device = torch.device(f"xpu:{self.local_rank}")
torch.accelerator.set_device_index(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)
torch.accelerator.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank
).total_memory
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
ENV_LOCAL_WORLD_SIZE = os.getenv(
"LOCAL_WORLD_SIZE", str(self.parallel_config.world_size)
)
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
os.environ["LOCAL_RANK"] = str(self.local_rank)
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
current_platform.dist_backend,
)
# global all_reduce needed for overall oneccl warm up
if torch.distributed.is_xccl_available():
torch.distributed.all_reduce(torch.zeros(1).xpu())
if self.use_v2_model_runner:
logger.info_once("Using V2 Model Runner")
# Set random seed.
set_random_seed(self.model_config.seed)
# Now take memory snapshot after NCCL is initialized
gc.collect()
torch.accelerator.empty_cache()
# take current memory snapshot
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
self.requested_memory = request_memory(init_snapshot, self.cache_config)
logger.debug("worker init memory snapshot: %r", self.init_snapshot)
logger.debug(
"worker requested memory: %sGiB", format_gib(self.requested_memory)
)
# Initialize workspace manager
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
init_workspace_manager(self.device, num_ubatches)
# Construct the model runner
model_runner = XPUModelRunnerV2 if self.use_v2_model_runner else XPUModelRunner
self.model_runner = model_runner( # type: ignore
self.vllm_config, self.device
)
if self.rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)
def profile(self, is_start: bool = True, profile_prefix: str | None = None):
if self.profiler_config is None or self.profiler_config.profiler is None:
raise RuntimeError(
"Profiling is not enabled. Please set --profiler-config to enable "
"profiling. Example: "
"'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
"=YOUR_DIR_PATH_TO_DUMP_TRACE'"
)
if is_start and self.profiler is None:
from vllm.distributed.utils import get_worker_rank_suffix
rank_suffix = get_worker_rank_suffix(global_rank=self.rank)
trace_name = (
f"{profile_prefix}_{rank_suffix}" if profile_prefix else rank_suffix
)
self.profiler = TorchProfilerWrapper(
self.profiler_config,
worker_name=trace_name,
local_rank=self.local_rank,
activities=["CPU", "XPU"],
)
logger.debug("Starting torch profiler with trace name: %s", trace_name)
super().profile(is_start=is_start, profile_prefix=profile_prefix)