Skip to content

vllm.v1.worker.mamba_utils

cleanup_mamba_state_idx

cleanup_mamba_state_idx(
    scheduler_output: SchedulerOutput,
    mamba_state_idx: dict[str, int],
) -> None

Pop stale mamba_state_idx entries for finished/preempted/resumed reqs.

Force-preempted requests (e.g., during reset_prefix_cache / KV cache flush) appear in resumed_req_ids without a corresponding entry in preempted_req_ids, leaving stale entries that can point to block indices beyond the new (smaller) block allocation.

Source code in vllm/v1/worker/mamba_utils.py
def cleanup_mamba_state_idx(
    scheduler_output: SchedulerOutput,
    mamba_state_idx: dict[str, int],
) -> None:
    """Drop `mamba_state_idx` entries of finished/preempted/resumed requests.

    Resumed requests are included on purpose: a force-preempted request
    (e.g., during reset_prefix_cache / KV cache flush) shows up in
    resumed_req_ids without ever being listed in preempted_req_ids, so its
    old entry would otherwise survive and could point to a block index
    beyond the new (smaller) block allocation.

    Args:
        scheduler_output: Supplies `finished_req_ids`, `preempted_req_ids`
            (a falsy value is treated as empty) and
            `scheduled_cached_reqs.resumed_req_ids`.
        mamba_state_idx: Mapping from request id to state block index;
            mutated in place. Ids without an entry are ignored.
    """
    stale_id_groups = (
        scheduler_output.finished_req_ids,
        # A falsy value (e.g. None) is replaced by an empty set.
        scheduler_output.preempted_req_ids or set(),
        scheduler_output.scheduled_cached_reqs.resumed_req_ids,
    )
    for stale_id in itertools.chain.from_iterable(stale_id_groups):
        mamba_state_idx.pop(stale_id, None)

postprocess_mamba

postprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    num_spec_tokens: int,
    num_reqs: int,
    *,
    forward_context: dict[str, Any] | None = None,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...] | None = None,
    copy_bufs: MambaCopyBuffers | None = None,
)

Post-model-execute mamba prefix-caching bookkeeping, dispatched by cache_config.mamba_cache_mode:

- "align": if a block is converted from partial to full this step, copy the running state into the new full block.
- "all" with num_spec_tokens > 0: record, per request, the block index of the last token scheduled this step, so the next step can anchor its in-place writes when accepted drafts leave the sequence at a non-block-aligned position.

Source code in vllm/v1/worker/mamba_utils.py
def postprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    num_spec_tokens: int,
    num_reqs: int,
    *,
    forward_context: dict[str, Any] | None = None,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...] | None = None,
    copy_bufs: MambaCopyBuffers | None = None,
):
    """
    Mamba prefix-caching bookkeeping that runs after model execution.

    What happens depends on cache_config.mamba_cache_mode:
      - "align": when this step turns a partial block into a full one,
        the running state is copied into that newly full block.
      - "all" with num_spec_tokens > 0: for every request, remember the
        block index of the last token scheduled this step so the next
        step can anchor its in-place writes when accepted drafts leave
        the sequence at a non-block-aligned position.
      - anything else: nothing is done.

    Args:
        scheduler_output: Read for `num_scheduled_tokens` and, in "align"
            mode, `scheduled_spec_decode_tokens`.
        kv_cache_config: Forwarded to the copy-metadata collector
            ("align") or used to look up the mamba spec ("all").
        cache_config: Only `mamba_cache_mode` is read.
        input_batch: Provides `req_ids` and, in "align" mode,
            `num_accepted_tokens_cpu`, which may be overwritten with 1.
        requests: Per-request state, indexed by request id.
        mamba_state_idx: Request id -> state block index. Read in "align"
            mode; updated in place in "all" mode.
        num_spec_tokens: Number of speculative tokens; "all" mode is a
            no-op unless this is positive.
        num_reqs: Number of leading `req_ids` visited in "all" mode
            (not used in "align" mode).
        forward_context: Must not be None in "align" mode.
        mamba_state_copy_funcs: Must not be None in "align" mode.
        copy_bufs: Must not be None in "align" mode; its `offset` is
            reset to 0 before copy metadata is collected.
    """
    cache_mode = cache_config.mamba_cache_mode

    if cache_mode == "align":
        assert forward_context is not None
        assert mamba_state_copy_funcs is not None
        assert copy_bufs is not None
        scheduled_by_req = scheduler_output.num_scheduled_tokens
        drafts_by_req = scheduler_output.scheduled_spec_decode_tokens
        accepted_cpu = input_batch.num_accepted_tokens_cpu
        group_ids = copy_bufs.mamba_group_ids
        spec = copy_bufs.mamba_spec
        copy_bufs.offset = 0
        for row, req_id in enumerate(input_batch.req_ids):
            state = requests[req_id]
            computed = state.num_computed_tokens
            draft_count = len(drafts_by_req.get(req_id, []))
            scheduled = scheduled_by_req[req_id]
            accepted = accepted_cpu[row]
            # Computed plus scheduled tokens, excluding the draft tokens.
            running_len = computed + scheduled - draft_count
            new_computed = running_len + accepted - 1
            # Round the new computed-token count down to a block boundary.
            block_size = spec.block_size
            full_blocks = new_computed // block_size
            aligned_len = full_blocks * block_size
            # TODO: how to ensure all blocks that cache_blocks called are
            # cached here?
            if aligned_len < running_len:
                # The block boundary lies before the running state.
                continue
            src_block = mamba_state_idx[req_id]
            dest_block = full_blocks - 1
            collect_mamba_copy_meta(
                copy_bufs,
                kv_cache_config,
                mamba_state_copy_funcs,
                group_ids,
                src_block,
                dest_block,
                aligned_len - running_len,
                state,
                forward_context,
            )
            if src_block == dest_block:
                accepted_cpu[row] = 1
        do_mamba_copy_block(copy_bufs)
        return

    if cache_mode != "all" or num_spec_tokens <= 0:
        return

    _, spec = get_mamba_groups(kv_cache_config)
    block_size = spec.block_size
    full_decode_len = num_spec_tokens + 1
    scheduled_by_req = scheduler_output.num_scheduled_tokens
    for req_id in input_batch.req_ids[:num_reqs]:
        query_len = scheduled_by_req.get(req_id, 0)
        if query_len != full_decode_len:
            # Any other query length: forget the recorded index, if any.
            mamba_state_idx.pop(req_id, None)
            continue
        # Position of the last token scheduled this step.
        last_pos = requests[req_id].num_computed_tokens + query_len - 1
        mamba_state_idx[req_id] = max(0, last_pos // block_size)

preprocess_mamba

preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    copy_bufs: MambaCopyBuffers,
)

Copy the mamba state of the previous step to the last (1 + num_speculative_blocks) block.

Source code in vllm/v1/worker/mamba_utils.py
def preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    copy_bufs: MambaCopyBuffers,
):
    """
    Move each request's mamba state from the previous step to the last
    (1 + num_speculative_blocks) block.

    Stale `mamba_state_idx` entries are removed first. For every request
    in the batch the current state block index is then computed and
    stored in `mamba_state_idx`; a copy is queued only when a previous
    state exists and sits in a different block. Queuing a copy also
    resets that request's `input_batch.num_accepted_tokens_cpu` entry
    to 1.

    Requires `cache_config.enable_prefix_caching` to be truthy (asserted).
    `copy_bufs.offset` is reset to 0 before copy metadata is collected.
    """
    group_ids = copy_bufs.mamba_group_ids
    spec = copy_bufs.mamba_spec
    spec_blocks = spec.num_speculative_blocks
    # TODO(Chen): we need to optimize this function a lot
    assert cache_config.enable_prefix_caching
    block_size = spec.block_size
    cleanup_mamba_state_idx(scheduler_output, mamba_state_idx)

    copy_bufs.offset = 0
    for row, req_id in enumerate(input_batch.req_ids):
        state = requests[req_id]
        src_idx = mamba_state_idx.get(req_id)
        if src_idx is None:
            # New / resumed request without a recorded state: derive the
            # index from the computed tokens. With num_computed_tokens == 0
            # this yields -1, i.e. "no previous state".
            src_idx = (state.num_computed_tokens - 1) // block_size

        scheduled = scheduler_output.num_scheduled_tokens[req_id]
        total_blocks: int = (
            cdiv(state.num_computed_tokens + scheduled, block_size) + spec_blocks
        )

        # The current running state always lives in the last
        # (1 + num_speculative_blocks) block.
        # Corner case: with block_size = 4 and num_speculative_tokens = 2,
        # a request [A, B, C] carrying 2 draft tokens [draft 1, draft 2]
        # is laid out as:
        #   Block 0: [A, B, C, draft 1]
        #   Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
        #   Block 2: speculative block
        #   Block 3: speculative block
        # and block 1 is the one that holds the running state.
        dst_idx = total_blocks - 1 - spec_blocks
        mamba_state_idx[req_id] = dst_idx
        if src_idx in (-1, dst_idx):
            # No previous state, or it is already in the right block.
            continue
        collect_mamba_copy_meta(
            copy_bufs,
            kv_cache_config,
            mamba_state_copy_funcs,
            group_ids,
            src_idx,
            dst_idx,
            input_batch.num_accepted_tokens_cpu[row] - 1,
            state,
            forward_context,
        )
        input_batch.num_accepted_tokens_cpu[row] = 1
    do_mamba_copy_block(copy_bufs)