Skip to content

vllm.v1.worker.gpu.spec_decode.eagle.cudagraph

DecodeEagleCudaGraphManager

Bases: EagleCudaGraphManagerBase

Eagle CudaGraphManager for decode draft generation, building its own attention metadata from scratch.

Source code in vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
class DecodeEagleCudaGraphManager(EagleCudaGraphManagerBase):
    """CUDA graph manager for Eagle's decode-time draft generation.

    The attention metadata used while capturing is built here from scratch
    via ``prepare_inputs_to_capture`` rather than being supplied by the
    caller.
    """

    def capture(
        self,
        forward_fn: Callable,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture graphs for ``forward_fn`` with freshly prepared inputs.

        Args:
            forward_fn: Invoked positionally as ``forward_fn(num_reqs,
                num_tokens, attn_metadata, slot_mappings,
                num_tokens_across_dp, cg_mode)``.
            model_state: Forwarded to ``prepare_inputs_to_capture``.
            input_buffers: Forwarded to ``prepare_inputs_to_capture``.
            block_tables: Forwarded to ``prepare_inputs_to_capture``.
            attn_groups: Forwarded to ``prepare_inputs_to_capture``.
            kv_cache_config: Forwarded to ``prepare_inputs_to_capture``.
            progress_bar_desc: Passed through to the parent ``capture``.
        """

        def create_forward_fn(
            desc: BatchExecutionDescriptor,
        ) -> tuple[Callable[[CUDAGraphMode], None], CapturedAttentionState]:
            """Build the per-descriptor forward callable and its attention state."""
            token_count = desc.num_tokens
            # When the descriptor has no (truthy) request count, derive one from
            # the token count, capped at the manager's request limit.
            req_count = desc.num_reqs or min(token_count, self.max_num_reqs)

            # A per-rank token-count tensor is only built when dp_size > 1.
            dp_token_counts = None
            if self.dp_size > 1:
                dp_token_counts = torch.full(
                    (self.dp_size,), token_count, dtype=torch.int32, device="cpu"
                )

            captured_state = prepare_inputs_to_capture(
                req_count,
                token_count,
                model_state,
                input_buffers,
                block_tables,
                attn_groups,
                kv_cache_config,
            )
            metadata, slots = captured_state

            def run_forward(cg_mode):
                return forward_fn(
                    req_count,
                    token_count,
                    metadata,
                    slots,
                    dp_token_counts,
                    cg_mode,
                )

            return run_forward, captured_state

        super().capture(create_forward_fn, progress_bar_desc)

EagleCudaGraphManagerBase

Bases: CudaGraphManager

Base CudaGraphManager for Eagle with a dedicated graph pool.

Source code in vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
class EagleCudaGraphManagerBase(CudaGraphManager):
    """Common base for Eagle CUDA graph managers that own a private graph pool."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
    ):
        """Initialize the parent manager, then switch to a dedicated graph pool.

        All four arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)

        # NOTE(review): this depends on the truthiness of `cudagraph_mode`;
        # confirm that the "disabled" mode actually evaluates as falsy.
        if not cudagraph_mode:
            return

        # Eagle gets its own pool so its memory cannot overlap with the main
        # model's cudagraph. The parent class uses a shared global pool, and
        # Eagle's internal allocations (e.g., gumbel_sample temporaries) can
        # conflict with the main model's allocations when both share it.
        self.pool = torch.cuda.graph_pool_handle()

PrefillEagleCudaGraphManager

Bases: EagleCudaGraphManagerBase

Eagle CudaGraphManager for prefill, using pre-built attention states from the target model's capture.

Source code in vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
class PrefillEagleCudaGraphManager(EagleCudaGraphManagerBase):
    """CUDA graph manager for Eagle's prefill step.

    No attention metadata is built here; capture reuses the attention states
    already produced by the target model's capture.
    """

    def capture(
        self,
        forward_fn: Callable,
        full_cg_attn_states: dict[BatchExecutionDescriptor, CapturedAttentionState],
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture graphs for ``forward_fn`` using pre-built attention states.

        Args:
            forward_fn: Invoked positionally as ``forward_fn(num_reqs,
                num_tokens, attn_metadata, slot_mappings,
                num_tokens_across_dp, cg_mode)``.
            full_cg_attn_states: Attention state per batch descriptor. It is
                indexed directly, so a missing descriptor raises ``KeyError``.
            progress_bar_desc: Passed through to the parent ``capture``.
        """

        def create_forward_fn(
            desc: BatchExecutionDescriptor,
        ) -> tuple[Callable[[CUDAGraphMode], None], CapturedAttentionState]:
            """Build the per-descriptor forward callable and its attention state."""
            token_count = desc.num_tokens
            # When the descriptor has no (truthy) request count, derive one from
            # the token count, capped at the manager's request limit.
            req_count = desc.num_reqs or min(token_count, self.max_num_reqs)

            # A per-rank token-count tensor is only built when dp_size > 1.
            dp_token_counts = None
            if self.dp_size > 1:
                dp_token_counts = torch.full(
                    (self.dp_size,), token_count, dtype=torch.int32, device="cpu"
                )

            captured_state = full_cg_attn_states[desc]
            metadata, slots = captured_state

            def run_forward(cg_mode):
                return forward_fn(
                    req_count,
                    token_count,
                    metadata,
                    slots,
                    dp_token_counts,
                    cg_mode,
                )

            return run_forward, captured_state

        super().capture(create_forward_fn, progress_bar_desc)