Autotune FlashInfer operations. FlashInfer has many implementations for the same operation; autotuning runs benchmarks for each implementation and stores the results. The results are cached transparently, and future calls to FlashInfer will use the best implementation. Without autotuning, FlashInfer relies on heuristics, which may be significantly slower.
Tuning is performed only on rank 0. The resulting cache is broadcast to every rank so all ranks dispatch the same kernel tactic.
Source code in vllm/model_executor/warmup/kernel_warmup.py
def flashinfer_autotune(runner: "GPUModelRunner") -> None:
    """
    Autotune FlashInfer operations.

    FlashInfer has many implementations for the same operation;
    autotuning runs benchmarks for each implementation and stores
    the results. The results are cached transparently and
    future calls to FlashInfer will use the best implementation.
    Without autotuning, FlashInfer will rely on heuristics, which may
    be significantly slower.

    Two modes, selected by ``_FLASHINFER_USE_PERSISTENT_CACHE``:

    * Disabled: every rank runs one dummy forward pass inside
      ``fi_utils.autotune()`` and then waits on a world-group barrier.
    * Enabled: tuning is performed only on rank 0, writing to the cache
      file returned by ``_resolve_flashinfer_autotune_file``. The
      resulting cache is broadcast to every rank so all ranks dispatch
      the same kernel tactic.

    Args:
        runner: Model runner whose ``_dummy_run`` drives the tuning pass
            and whose ``scheduler_config.max_num_batched_tokens`` sets
            the token count used for it.
    """
    import vllm.utils.flashinfer as fi_utils
    from vllm.distributed.parallel_state import get_world_group

    # We skip EPLB here since we don't want to record dummy metrics.
    # When autotuning with number of tokens m, flashinfer will autotune
    # operations for all number of tokens up to m, so we only need to
    # run with the max number of tokens.
    dummy_run_kwargs = {
        "num_tokens": runner.scheduler_config.max_num_batched_tokens,
        "skip_eplb": True,
        "is_profile": True,
    }

    if not _FLASHINFER_USE_PERSISTENT_CACHE:
        with torch.inference_mode(), fi_utils.autotune():
            runner._dummy_run(**dummy_run_kwargs)
        get_world_group().barrier()
        return

    world = get_world_group()
    is_leader = world.rank_in_group == 0
    cache_path = _resolve_flashinfer_autotune_file(runner)

    if is_leader:
        logger.info("Using FlashInfer autotune cache file: %s", cache_path)

    with torch.inference_mode():
        if is_leader:
            with fi_utils.autotune(tune_mode=True, cache=str(cache_path)):
                runner._dummy_run(**dummy_run_kwargs)
        else:
            # Non-leader ranks run the same pass without tuning;
            # presumably this keeps collective ops inside the model in
            # step with rank 0 -- TODO confirm.
            runner._dummy_run(**dummy_run_kwargs)

    # Broadcast autotune cache from rank 0 to all other ranks so every
    # rank loads the same set of chosen tactics.
    tune_results: bytes | None = None
    if is_leader and cache_path.exists():
        with open(cache_path, "rb") as f:
            tune_results = f.read()
    tune_results = world.broadcast_object(tune_results, src=0)

    if tune_results is None:
        # Every rank receives the same broadcast value, so all ranks take
        # this branch together and none is left waiting at the barrier.
        logger.warning(
            "No FlashInfer autotune cache entries found. "
            "Falling back to default tactics."
        )
        return

    # The leader already has the file on disk; any other rank with
    # local_rank 0 writes its own copy of the broadcast bytes.
    if not is_leader and world.local_rank == 0:
        with open(cache_path, "wb") as f:
            f.write(tune_results)
    # Wait until the writers are done before any rank reads the file.
    world.barrier()

    from flashinfer.autotuner import AutoTuner

    AutoTuner.get().load_configs(str(cache_path))
    logger.info(
        "FlashInfer autotune cache loaded on rank %d from %s.",
        world.rank_in_group,
        cache_path,
    )
|