vllm.model_executor.layers.fused_moe.routed_experts_capturer ¶
RoutedExpertsCapturer ¶
Worker-side capturer for routed experts, lives on GPU.
Layer-level hooks call capture() from inside the forward pass with the per-layer topk_ids tensor. The tensor is sliced to the tokens owned by this DP rank and written into a preallocated device buffer. At the end of the step, GPUModelRunner reads the device buffer, issues a D2H copy into a pinned CPU buffer, and hands the result to the scheduler via RoutedExpertsLists.
The device / pinned-CPU transit buffers use torch.int32 (not a narrow uint8/uint16 sized by num_experts). This keeps the SP all-gather path free of dtype casts, matches the router's native topk_ids indices dtype more closely, and costs only a few MB per worker (max_num_batched_tokens * num_layers * top_k * 4 bytes). The scheduler-side slot buffer (RoutedExpertsManager.routed_experts_by_slot) still uses the narrow dtype -- numpy fancy-index assignment in store_batch narrows the data on the way in.
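As a rough check of the "few MB per worker" figure, the formula above can be evaluated directly. This is a minimal sketch; the helper name and the configuration values are hypothetical and chosen only for illustration.

```python
def transit_buffer_bytes(max_num_batched_tokens: int, num_layers: int, top_k: int) -> int:
    """Size in bytes of one int32 transit buffer (4 bytes per element)."""
    int32_bytes = 4
    return max_num_batched_tokens * num_layers * top_k * int32_bytes


# Hypothetical configuration, not taken from any particular model.
size = transit_buffer_bytes(max_num_batched_tokens=8192, num_layers=48, top_k=8)
print(f"{size / 2**20:.1f} MiB")  # prints "12.0 MiB"
```

The device buffer and the pinned CPU buffer each cost this much, so the per-worker total is about twice the printed value under these assumptions.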
Invariants
- One instance per worker; shape is fixed at init and covers the worst-case step (max_num_batched_tokens tokens).
- clear_buffer() is called at the start of every step, so unused slots stay zero.
- device_buffer.dtype is torch.int32.
Source code in vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
capture ¶
Capture expert routing decisions for a specific layer.
Under data parallelism, topk_ids may have three different batch layouts depending on where the DP combine happens and whether Sequence Parallelism (SP) is active for the MoE layer. With n denoting the number of rows in topk_ids:
- n == total (naive dispatch): all DP ranks' tokens are concatenated before routing; we slice out this rank's span using the cumulative per-rank counts.
- n == token_num_per_dp (modular-kernel path): DP combine happens inside quant_method.apply; select_experts only ever sees this rank's tokens, so we take the whole tensor.
- n == ceil(token_num_per_dp / tp_size) (SP + modular-kernel path): tokens were split along dim=0 across the TP group by _sequence_parallel_context (moe_runner_base.py:_sequence_parallel_context), so each TP rank only sees its shard. We all-gather along dim=0 to reconstruct this DP rank's full routing tensor. SP pads with ceil-div (see _compute_sp_num_tokens in forward_context.py), so the gathered tensor may contain a few trailing padding rows, which are trimmed by the downstream [:token_num_per_dp] slice.
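The three layouts above can be sketched in plain Python, with lists standing in for tensors. This is an illustration of the dispatch logic only, not the actual implementation: the function name, the order of the checks, and the gather callable (a stand-in for the TP all-gather along dim=0) are all assumptions.

```python
import math


def resolve_rank_rows(rows, token_counts, dp_rank, tp_size=1, gather=None):
    """Return the rows of topk_ids that belong to this DP rank.

    rows:         per-token rows visible to this rank (list stand-in for a tensor)
    token_counts: number of tokens owned by each DP rank
    gather:       hypothetical callable that concatenates all TP shards along dim 0
    """
    n = len(rows)
    own = token_counts[dp_rank]
    total = sum(token_counts)

    if n == total:
        # Naive dispatch: slice this rank's span using cumulative counts.
        start = sum(token_counts[:dp_rank])
        return rows[start:start + own]
    if n == own:
        # Modular-kernel path: the tensor already holds only this rank's tokens.
        return rows
    if n == math.ceil(own / tp_size):
        # SP + modular-kernel path: rebuild the full tensor, then trim padding.
        return gather(rows)[:own]
    raise ValueError(f"unexpected topk_ids batch size {n}")


counts = [3, 5]  # hypothetical per-rank token counts
naive = resolve_rank_rows([[i] for i in range(8)], counts, dp_rank=1)
modular = resolve_rank_rows([[i] for i in range(3, 8)], counts, dp_rank=1)
# Second TP shard holds two real rows plus one ceil-div padding row.
sp = resolve_rank_rows(
    [[3], [4], [5]], counts, dp_rank=1, tp_size=2,
    gather=lambda shard: shard + [[6], [7], [-1]],
)
```

All three calls recover the same five rows for DP rank 1; in the SP case the padding row appended by the simulated gather is dropped by the final slice.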
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| layer_id | int | The layer index. | required |
| topk_ids | Tensor | Tensor of shape (batch_size, num_routed_experts). | required |
Source code in vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
clear_buffer ¶
Zero the device buffer. Called at the start of every step so slots belonging to finished / preempted tokens don't leak into the next step.
Source code in vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
get_device_buffer ¶
get_device_buffer() -> Tensor
Return the underlying device buffer so the model runner can issue the D2H copy. The tensor is shared; callers must either clone or fully drain it before the next forward pass runs clear_buffer().
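The aliasing hazard described above can be illustrated with a NumPy array standing in for the device tensor. This is only an analogy under that assumption: the real buffer is a torch tensor on the GPU, and the shapes and values here are made up.

```python
import numpy as np

# NumPy stand-in for the device buffer: (tokens, layers, top_k).
device_buffer = np.zeros((4, 2, 2), dtype=np.int32)
device_buffer[:3] = 7  # pretend capture() wrote three tokens this step

alias = device_buffer            # shared storage, like the returned tensor
snapshot = device_buffer.copy()  # analogous to cloning / finishing the D2H copy

device_buffer.fill(0)            # the next step's clear_buffer()

print(alias[:3].max())     # the alias was zeroed along with the buffer
print(snapshot[:3].min())  # the snapshot still holds this step's data
```

Only the snapshot survives the clear, which is why the caller has to finish copying before the next forward pass begins.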
Source code in vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
RoutedExpertsManager ¶
Scheduler-side slot-indexed buffer for routed experts.
Lives on CPU in the scheduler process. Each slot corresponds to block_id * block_size + offset_in_block where block_id is drawn from the physical KV-cache block pool, so routing data is tied to physical blocks and naturally survives preemption for prefix-cached blocks (prefix hits re-expose the same slots).
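The slot formula above can be written out as a small helper. This is a minimal sketch: the function name is hypothetical, and the block IDs and block size are illustrative values.

```python
def slot_mapping_for(block_ids, num_tokens, block_size):
    """Map token positions 0..num_tokens-1 to slots in the slot buffer."""
    slots = []
    for pos in range(num_tokens):
        block_id = block_ids[pos // block_size]  # physical block holding this token
        offset_in_block = pos % block_size
        slots.append(block_id * block_size + offset_in_block)
    return slots


# A request whose six tokens live in physical blocks 7 and 2 (block_size=4).
mapping = slot_mapping_for([7, 2], num_tokens=6, block_size=4)
print(mapping)  # → [28, 29, 30, 31, 8, 9]
```

Because the slot depends only on the physical block ID and the offset, a later request that reuses block 7 through a prefix-cache hit maps to the same slots 28 to 31.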
Data flow per step
- Worker D2Hs its device capture buffer into RoutedExpertsLists and returns it via ModelRunnerOutput.
- Scheduler calls store_batch() with that step's (routing_data, slot_mapping): a single CPU->CPU fancy-index assign, ~few MB per step.
- On request completion / abort / preemption, the scheduler calls get() with the request's block IDs to recover the full per-token routing.
Memory: routed_experts_by_slot is sized for the whole block pool (num_blocks * block_size slots). For large block pools this can reach multiple GB; see the init log for the exact size.
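To estimate the slot buffer's footprint before reading the init log, the sizing can be approximated as below. This is a sketch under stated assumptions: each slot is taken to hold num_layers * top_k expert IDs, and the narrow dtype is assumed to be uint8 when every expert ID fits in one byte and uint16 otherwise. The helper name and all numbers are hypothetical.

```python
def slot_buffer_bytes(num_blocks, block_size, num_layers, top_k, num_experts):
    """Approximate size in bytes of the scheduler-side slot buffer."""
    # Assumption: 1 byte per ID if IDs fit in 0..255, else 2 bytes.
    itemsize = 1 if num_experts <= 256 else 2
    num_slots = num_blocks * block_size
    return num_slots * num_layers * top_k * itemsize


# Hypothetical block pool and model shape.
small_ids = slot_buffer_bytes(100_000, 16, 48, 8, num_experts=128)
wide_ids = slot_buffer_bytes(100_000, 16, 48, 8, num_experts=512)
print(f"{small_ids / 2**30:.2f} GiB, {wide_ids / 2**30:.2f} GiB")
```

Doubling the element width doubles the footprint, which is the motivation for keeping this buffer narrow even though the transit buffers use int32.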
Source code in vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
get ¶
Read routed experts data for a completed / preempted request.
Reconstructs a per-token slot_mapping from the request's block IDs and returns the routing slice. Because numpy fancy indexing returns a copy (not a view), the returned ndarray is safe to hold across subsequent store_batch() calls. Do not replace the fancy index with a slice without re-verifying this.
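The copy-versus-view distinction can be checked directly in NumPy. The buffer shape and values below are illustrative only.

```python
import numpy as np

# Tiny stand-in for the slot buffer: (slots, layers, top_k).
slot_buffer = np.arange(8, dtype=np.uint8).reshape(8, 1, 1)

fancy = slot_buffer[np.array([2, 3, 4])]  # integer-array index: a copy
sliced = slot_buffer[2:5]                 # basic slice: a view

slot_buffer[3] = 99  # a later write, as a subsequent store would do

print(fancy[1, 0, 0])   # → 3, unaffected by the write
print(sliced[1, 0, 0])  # → 99, the view sees the write
print(np.shares_memory(fancy, slot_buffer))   # → False
print(np.shares_memory(sliced, slot_buffer))  # → True
```

A slice would be cheaper, but it would silently alias the slot buffer, so a caller holding the result could see it change under later writes.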
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| block_ids | list[int] | Block IDs from the attention KV-cache group. | required |
| num_tokens | int | Number of tokens that have gone through a forward pass and therefore have routing data written to their slots (typically | required |
| token_start | int | Skip the first | 0 |
Returns:
| Type | Description |
|---|---|
| ndarray | Array of shape (num_tokens - token_start, num_layers, num_experts_per_tok). |
Source code in vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
store_batch ¶
Persist one step's routed experts into the slot buffer.
Equivalent to slot_buffer[slot_mapping] = data; numpy fancy indexing handles repeated / out-of-order indices. Called once per scheduler step in update_from_output.
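The fancy-index assignment, including the narrowing from the int32 transit dtype, can be sketched as follows. The shapes, slot numbers, and the choice of uint8 for the slot buffer are assumptions made for illustration.

```python
import numpy as np

num_slots, num_layers, top_k = 16, 2, 2
# Slot buffer in a narrow dtype (uint8 assumed here).
slot_buffer = np.zeros((num_slots, num_layers, top_k), dtype=np.uint8)

# One step's routing data for three tokens, as int32 from the worker.
data = (np.arange(12, dtype=np.int32) + 1).reshape(3, num_layers, top_k)
slot_mapping = np.array([9, 4, 12])  # out-of-order slots are fine

# The assignment casts int32 values to the buffer's dtype on the way in.
slot_buffer[slot_mapping] = data

print(slot_buffer.dtype)          # → uint8
print(slot_buffer[9].tolist())    # → [[1, 2], [3, 4]]
print(slot_buffer[12].tolist())   # → [[9, 10], [11, 12]]
```

The cast is not range-checked, so this relies on every expert ID fitting in the narrow dtype.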
Source code in vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
_get_num_experts_per_tok ¶
_get_num_experts_per_tok(hf_config) -> int
Resolve the per-token expert count from the HF config.
Different model families store this under different attribute names (e.g. num_experts_per_tok for DeepSeek, top_k_experts for Gemma 4).
Source code in vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
get_num_experts ¶
get_num_experts(hf_config) -> int
Resolve num_experts across HuggingFace config naming conventions.
Different MoE model families expose this under different keys:
- num_experts: Mixtral, Qwen2-MoE, Qwen3-MoE
- n_routed_experts: DeepSeek-V2/V3
- num_local_experts: Mixtral (older exports)
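A lookup over these keys might be written as an attribute fallback chain. This is a sketch, not the function's actual body: the helper name, the order in which the keys are tried, and the error raised when none is present are all assumptions.

```python
from types import SimpleNamespace

# Keys listed above; the lookup order is an assumption.
_CANDIDATE_KEYS = ("num_experts", "n_routed_experts", "num_local_experts")


def resolve_num_experts(hf_config) -> int:
    """Return the first expert-count attribute found on the config."""
    for key in _CANDIDATE_KEYS:
        value = getattr(hf_config, key, None)
        if value is not None:
            return int(value)
    raise ValueError(f"none of {_CANDIDATE_KEYS} found on config")


# SimpleNamespace objects standing in for HF config objects.
deepseek_like = SimpleNamespace(n_routed_experts=256)
mixtral_like = SimpleNamespace(num_local_experts=8)
print(resolve_num_experts(deepseek_like))  # → 256
print(resolve_num_experts(mixtral_like))   # → 8
```

Treating None the same as a missing attribute guards against configs that define a key but leave it unset.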