vllm.compilation.decorators

_should_ignore_torch_compile

_should_ignore_torch_compile(cls: type[_T]) -> bool

Check if the class should be ignored for torch.compile.

Source code in vllm/compilation/decorators.py
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
    """
    Check if the class should be ignored for torch.compile.
    """
    return getattr(cls, IGNORE_COMPILE_KEY, False)

_support_torch_compile

_support_torch_compile(
    cls: type[_T],
    dynamic_arg_dims: dict[
        str, int | list[int] | dict[int, str]
    ],
    mark_unbacked_dims: dict[str, int | list[int]]
    | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    is_encoder: bool = False,
) -> type[_T]

Internal implementation of support_torch_compile decorator.

Source code in vllm/compilation/decorators.py
def _support_torch_compile(
    cls: type[_T],
    dynamic_arg_dims: dict[str, int | list[int] | dict[int, str]],
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    is_encoder: bool = False,
) -> type[_T]:
    """Internal implementation of support_torch_compile decorator."""

    if TorchCompileWithNoGuardsWrapper in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWithNoGuardsWrapper
    cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)

    old_init = cls.__init__

    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(
        self: _T,
        *args,
        vllm_config: VllmConfig | None = None,
        prefix: str = "",
        **kwargs: Any,
    ) -> None:
        if vllm_config is None:
            vllm_config = get_current_vllm_config()

        # NOTE: to support multimodal models (such as encoder),
        # we may not have vllm_config so we may need to patch it
        sig = inspect.signature(old_init)
        # Check that any positional arguments match the old_init method signature
        annotations = [p.annotation for p in sig.parameters.values()]
        for arg, annotation in zip(args, annotations):
            if annotation is inspect._empty:
                continue
            if not isinstance(arg, annotation):
                init = f"'{type(self).__name__}.__init__'"
                arg_type = f"'{type(arg).__name__}'"
                raise TypeError(
                    f"{init} received a positional argument of type {arg_type}, "
                    "but no parameter of that type was found in the method signature. "
                    f"Please either annotate {init} or pass it as a keyword argument."
                )
        if "vllm_config" in sig.parameters:
            kwargs["vllm_config"] = vllm_config
        if "prefix" in sig.parameters:
            kwargs["prefix"] = prefix
        old_init(self, *args, **kwargs)

        self.vllm_config = vllm_config
        self.compilation_config = self.vllm_config.compilation_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE, the upper-level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            self.compilation_config.mode
            in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
            or _should_ignore_torch_compile(self.__class__)
            or not enable_compile
        )
        if self.do_not_compile:
            return

        self._dynamic_arg_dims = dynamic_arg_dims

        self.was_aot_compile_fn_loaded_from_disk = False
        compilation_counter.num_models_seen += 1
        self.compiled = False

        # Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
        TorchCompileWithNoGuardsWrapper.__init__(
            self,
            compile_prefix=cls.__name__ if is_encoder else "",
            is_encoder=is_encoder,
        )

    cls.__init__ = __init__

    def _mark_dynamic_inputs(
        mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
    ) -> None:
        def mark_dynamic(
            arg: torch.Tensor, dim_shape_pairs: list[tuple[int, str | None]]
        ) -> None:
            if ds_type == DynamicShapesType.UNBACKED:
                if is_torch_equal_or_newer("2.10.0"):
                    for dim, shape_id in dim_shape_pairs:
                        if shape_id is not None:
                            if not _SUPPORTS_SHAPE_ID:
                                raise RuntimeError(
                                    f"shape_id='{shape_id}' requires PyTorch >= 2.11.0"
                                )
                            torch._dynamo.decorators.mark_unbacked(
                                arg,
                                dim,
                                hint_override=arg.size()[dim],
                                shape_id=shape_id,
                            )
                        else:
                            torch._dynamo.decorators.mark_unbacked(
                                arg,
                                dim,
                                hint_override=arg.size()[dim],
                            )
                else:
                    # For older versions, we can't use hint_override or shape_id
                    dims = [dim for dim, _ in dim_shape_pairs]
                    torch._dynamo.decorators.mark_unbacked(arg, dims)
            else:
                dims = [dim for dim, _ in dim_shape_pairs]
                torch._dynamo.mark_dynamic(arg, dims)

        sig = inspect.signature(mod.__class__.forward)  # type: ignore[attr-defined]
        bound_args = sig.bind(mod, *args, **kwargs)
        bound_args.apply_defaults()

        # Normalize dynamic_arg_dims to dict[str, dict[int, str | None]]
        normalized_dims: dict[str, dict[int, str | None]] = {}
        for k, v in dynamic_arg_dims.items():
            if isinstance(v, dict):
                normalized_dims[k] = {dim: shape_id for dim, shape_id in v.items()}
            elif isinstance(v, int):
                normalized_dims[k] = {v: None}
            else:
                normalized_dims[k] = {d: None for d in v}

        for k, dim_to_shape_id in normalized_dims.items():
            arg = bound_args.arguments.get(k)

            if arg is not None:
                dims = list(dim_to_shape_id.keys())

                if isinstance(arg, torch.Tensor):
                    dim_shape_pairs = [
                        (arg.ndim + d if d < 0 else d, dim_to_shape_id.get(d))
                        for d in dims
                    ]
                    mark_dynamic(arg, dim_shape_pairs)
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        dim_shape_pairs = [
                            (tensor.ndim + d if d < 0 else d, dim_to_shape_id.get(d))
                            for d in dims
                        ]
                        mark_dynamic(tensor, dim_shape_pairs)
                else:
                    raise ValueError(
                        f"Unsupported dynamic dimensions {dims} "
                        f"for argument {k} with type {type(arg)}."
                    )

        if mark_unbacked_dims:
            for k, dims_val in mark_unbacked_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    dims = [dims_val] if isinstance(dims_val, int) else list(dims_val)
                    if isinstance(arg, torch.Tensor):
                        dims = [arg.ndim + d if d < 0 else d for d in dims]
                        if is_torch_equal_or_newer("2.10.0"):
                            for dim in dims:
                                torch._dynamo.decorators.mark_unbacked(
                                    arg, dim, hint_override=arg.size()[dim]
                                )
                        else:
                            torch._dynamo.decorators.mark_unbacked(arg, dims)

    def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # If skip_compiled is set, bypass compiled model call. This is used e.g. for
        # enc-dec models where tensor shapes/types vary across invocations, preventing
        # the capture of a single computational graph.
        if is_forward_context_available() and get_forward_context().skip_compiled:
            return self.forward(*args, **kwargs)

        # if aot_compiled_fn is set, call it with partition wrapper context.
        # The partition wrapper must be active at runtime for CUDA graph
        # capture to work correctly with inductor graph partitioning.
        if getattr(self, "aot_compiled_fn", None) is not None:
            with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
                return self.aot_compiled_fn(self, *args, **kwargs)

        ds_type = self.compilation_config.dynamic_shapes_config.type
        cache_dir = None
        aot_compilation_path = None
        if envs.VLLM_USE_AOT_COMPILE:
            """
            When using torch.compile in AOT mode, we store the cache artifacts
            under VLLM_CACHE_ROOT/torch_compile_cache/torch_aot_compile/{hash}
            The {hash} contains all of the factors except for the source files
            being traced through, because we don't actually know which source
            files to check at this point (before dynamo runs).
            On loading we will actually look at the source files being traced
            through. If any source file has changed (compared with the
            serialized backend artifacts), then we need to generate a new AOT
            compile artifact from scratch.
            """
            from .caching import aot_compile_hash_factors

            factors: list[str] = aot_compile_hash_factors(self.vllm_config)

            factors.append(_model_hash_key(self.forward))
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                "torch_aot_compile",
                hash_key,
            )

            # Hash-level dir; shared across ranks on the same node.
            self.compilation_config.local_cache_dir = cache_dir
            inductor_cache = os.path.join(cache_dir, "inductor_cache")
            os.makedirs(inductor_cache, exist_ok=True)
            # Process-wide: post-load execution, CUDA-graph capture, and later
            # autotune/recompile all need to write under {hash}/inductor_cache/.
            # Unconditional because torch's cache_dir() may have pre-filled the
            # /tmp default during import, making setdefault a no-op.
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache

            rank = self.vllm_config.parallel_config.rank
            dp_rank = self.vllm_config.parallel_config.data_parallel_index
            cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
            aot_compilation_path = os.path.join(cache_dir, "model")
            if not envs.VLLM_DISABLE_COMPILE_CACHE:
                loaded_fn = _try_load_aot_compiled_fn(self, aot_compilation_path)
                if loaded_fn is not None:
                    self.aot_compiled_fn = loaded_fn
                    self.was_aot_compile_fn_loaded_from_disk = True
                    with (
                        monitor_profiling_run(),
                        maybe_use_cudagraph_partition_wrapper(self.vllm_config),
                    ):
                        output = self.aot_compiled_fn(self, *args, **kwargs)
                    return output

        if self.compiled:
            assert (
                not envs.VLLM_USE_AOT_COMPILE
                or self.vllm_config.compilation_config.backend == "eager"
            )
            return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)  # type: ignore[arg-type]

        # This is the path for the first compilation.
        # the first compilation needs to have dynamic shapes marked
        _mark_dynamic_inputs(
            self,
            ds_type,
            *args,
            **kwargs,
        )

        original_code_object = self.original_code_object()
        logger.debug("Start compiling function %s", original_code_object)

        # we do not want to delete the original code object entries since
        # we depend on them now to look up cached compiled functions.
        # torch._dynamo.eval_frame.remove_from_cache(original_code_object)

        # collect all relevant files traced by Dynamo,
        # so that the compilation cache can trigger re-compilation
        # properly when any of these files change.

        # 1. the file containing the top-level forward function
        self.compilation_config.traced_files.add(original_code_object.co_filename)

        # 2. every time Dynamo sees a function call, it will inline
        # the function by calling InliningInstructionTranslator.inline_call_
        # we hijack this function to know all the functions called
        # during Dynamo tracing, and their corresponding files
        inline_call = InliningInstructionTranslator.inline_call_

        def patched_inline_call(self_: Any) -> Any:
            code = self_.f_code
            self.compilation_config.traced_files.add(code.co_filename)
            return inline_call(self_)

        # Disable the C++ compilation of symbolic shape guards. C++-fication
        # of symbolic shape guards can reduce guard overhead. But, since
        # vLLM skips guards anyway, setting this flag to False can improve
        # compile time.
        dynamo_config_patches = {}
        try:
            _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
            dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
        except AttributeError:
            # Note: this config is not available in torch 2.6, we can skip
            # if the config doesn't exist
            logger.debug("enable_cpp_symbolic_shape_guards config not available")

        # Prepare backed_size_oblivious config patch if needed
        fx_config_patches = {}
        if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
            fx_config_patches["backed_size_oblivious"] = True

        # Prepare inductor config patches
        # assume_32bit_indexing is only available in torch 2.10.0+
        inductor_config_patches = {}
        if is_torch_equal_or_newer("2.10.0"):
            inductor_config_patches["assume_32bit_indexing"] = (
                self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
            )

        with (
            patch.object(
                InliningInstructionTranslator, "inline_call_", patched_inline_call
            ),
            torch._dynamo.config.patch(**dynamo_config_patches),
            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
            torch.fx.experimental._config.patch(**fx_config_patches),
            torch._inductor.config.patch(**inductor_config_patches),
        ):
            use_aot_compile = envs.VLLM_USE_AOT_COMPILE
            if self.vllm_config.compilation_config.backend == "eager":
                logger.warning("Detected eager backend, disabling AOT compile.")
                use_aot_compile = False
            if use_aot_compile:
                # store the path for saving after warmup
                self._aot_compilation_path = aot_compilation_path
                self._aot_cache_dir = cache_dir
                with monitor_torch_compile(
                    self.vllm_config, is_encoder=self._is_encoder
                ):
                    self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                    compilation_counter.num_aot_compiles += 1
                    # All compilation is done at this point, save the
                    # AOT artifact.
                    self.save_aot_compiled_function()

                with monitor_profiling_run():
                    output = self.aot_compiled_fn(self, *args, **kwargs)
            else:
                with monitor_torch_compile(
                    self.vllm_config,
                    "torch.compile and initial profiling/warmup "
                    "run together took %.2f s in total",
                    is_encoder=self._is_encoder,
                ):
                    output = TorchCompileWithNoGuardsWrapper.__call__(
                        self,  # type: ignore[arg-type]
                        *args,
                        **kwargs,
                    )

        self.compiled = True
        return output

    # triggers VllmSerializableFunction.serialize()
    def save_aot_compiled_function(self: type[_T]) -> None:
        if envs.VLLM_DISABLE_COMPILE_CACHE:
            return

        if self.was_aot_compile_fn_loaded_from_disk:
            logger.debug("AOT compiled function was loaded from cache, skipping save")
            return

        assert (
            self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
        )

        try:
            os.makedirs(self._aot_cache_dir, exist_ok=True)
            # File saving should be atomic, so we will save to a temporary location
            # first. Should be upstreamed to PyTorch 2.12 as well.
            tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
            self.aot_compiled_fn.save_compiled_function(tmp_file)
            os.replace(tmp_file, self._aot_compilation_path)
            compilation_counter.num_aot_artifacts_saved += 1
            logger.info_once(
                "saved AOT compiled function to %s",
                self._aot_compilation_path,
            )
        except Exception as e:
            logger.warning(
                "unable to save AOT compiled function to %s: %s",
                self._aot_compilation_path,
                e,
            )

    cls.__call__ = __call__
    cls.save_aot_compiled_function = save_aot_compiled_function
    return cls

_try_load_aot_compiled_fn

_try_load_aot_compiled_fn(
    model: Any, aot_compilation_path: str
) -> Any | None

Try to load an AOT-compiled function from disk.

Returns the loaded callable on success, or None on failure. Re-raises on failure when VLLM_FORCE_AOT_LOAD is set.
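
Inside the decorator's __call__ path, the helper is used roughly as follows (a trimmed sketch of the load-or-recompile fallback; the real code also wraps the call in profiling and cudagraph-partition contexts):

if not envs.VLLM_DISABLE_COMPILE_CACHE:
    loaded_fn = _try_load_aot_compiled_fn(self, aot_compilation_path)
    if loaded_fn is not None:
        # Cache hit: reuse the deserialized AOT artifact and skip recompilation.
        self.aot_compiled_fn = loaded_fn
        self.was_aot_compile_fn_loaded_from_disk = True
        return self.aot_compiled_fn(self, *args, **kwargs)
# Cache miss or load failure: fall through and compile from scratch.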

Source code in vllm/compilation/decorators.py
def _try_load_aot_compiled_fn(
    model: Any,
    aot_compilation_path: str,
) -> Any | None:
    """Try to load an AOT-compiled function from disk.

    Returns the loaded callable on success, or None on failure.
    Re-raises on failure when ``VLLM_FORCE_AOT_LOAD`` is set.
    """
    try:
        with monitor_torch_compile(model.vllm_config, is_encoder=model._is_encoder):
            with (
                set_current_vllm_config(model.vllm_config),
                open(aot_compilation_path, "rb") as f,
            ):
                loaded_fn = torch.compiler.load_compiled_function(
                    f, f_globals=model.forward.__globals__
                )
            _verify_source_unchanged(loaded_fn.source_info(), model.vllm_config)
            ds_config = model.compilation_config.dynamic_shapes_config
            if not ds_config.evaluate_guards:
                loaded_fn.disable_guard_check()
            # Eagerly load compiled artifacts now that traced_files
            # is populated by _verify_source_unchanged.
            with maybe_use_cudagraph_partition_wrapper(model.vllm_config):
                loaded_fn._artifacts.compiled_fn.finalize_loading(model.vllm_config)
            compilation_counter.num_aot_artifacts_loaded += 1
            logger.info(
                "Directly load AOT compilation from path %s", aot_compilation_path
            )
        return loaded_fn
    except Exception as e:
        if os.path.exists(aot_compilation_path):
            if isinstance(e, EOFError):
                message = "Compile cache file corrupted."
            else:
                message = str(e)
            logger.warning(
                "Compiling model again due to a load failure from %s, reason: %s",
                aot_compilation_path,
                message,
            )
        if envs.VLLM_FORCE_AOT_LOAD:
            raise e
        return None

ignore_torch_compile

ignore_torch_compile(cls: type[_T]) -> type[_T]

A decorator to ignore the support_torch_compile decorator on a class. This is useful when a parent class has a support_torch_compile decorator, but we don't want to compile the class cls that inherits from the parent class. This only ignores compiling the forward method of the class the decorator is applied to.

If the parent has ignore_torch_compile but the child has support_torch_compile, the child will still be compiled.

If the class has one or more submodules that have the support_torch_compile decorator applied, compilation will not be ignored for those submodules.
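
For example (a minimal sketch; CompiledParent and UncompiledChild are hypothetical class names):

import torch
from torch import nn

from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile


@support_torch_compile
class CompiledParent(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor: ...


@ignore_torch_compile
class UncompiledChild(CompiledParent):
    # Inherits the compilation machinery from CompiledParent, but this
    # class's forward is not compiled because IGNORE_COMPILE_KEY is set.
    def forward(self, x: torch.Tensor) -> torch.Tensor: ...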

Source code in vllm/compilation/decorators.py
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
    """
    A decorator to ignore support_torch_compile decorator
    on the class. This is useful when a parent class has
    a support_torch_compile decorator, but we don't want to
    compile the class `cls` that inherits the parent class.
    This only ignores compiling the forward of the class the
    decorator is applied to.

    If the parent has ignore_torch_compile but the child has
    support_torch_compile, the child will still be compiled.

    If the class has one or more submodules
    that have support_torch_compile decorator applied, compile will
    not be ignored for those submodules.
    """
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls

maybe_use_cudagraph_partition_wrapper

maybe_use_cudagraph_partition_wrapper(
    vllm_config: VllmConfig,
) -> Generator[None, None, None]

Context manager to set/unset customized cudagraph partition wrappers.

If we're using Inductor-based graph partitioning, we currently have the whole fx.Graph before Inductor lowering and the piecewise splitting happens after all graph passes and fusions. Here, we add a custom hook for Inductor to wrap each partition with our static graph wrapper class to maintain more control over static graph capture and replay.
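
Within this module, the context manager wraps every invocation of the compiled or AOT-loaded function, for example (a sketch based on the usage in _support_torch_compile's __call__):

with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
    return self.aot_compiled_fn(self, *args, **kwargs)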

Source code in vllm/compilation/decorators.py
@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(
    vllm_config: VllmConfig,
) -> Generator[None, None, None]:
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    if (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    ):
        from torch._inductor.utils import CUDAGraphWrapperMetadata

        from vllm.compilation.cuda_graph import CUDAGraphOptions
        from vllm.platforms import current_platform

        static_graph_wrapper_class = resolve_obj_by_qualname(
            current_platform.get_static_graph_wrapper_cls()
        )

        def customized_cudagraph_wrapper(
            f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
        ) -> Any:
            partition_id = metadata.partition_index
            num_partitions = metadata.num_partitions
            return static_graph_wrapper_class(
                runnable=f,
                vllm_config=vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=CUDAGraphOptions(
                    debug_log_enable=partition_id == 0,
                    gc_disable=partition_id != 0,
                    weak_ref_output=partition_id == num_partitions - 1,
                ),
            )

        torch._inductor.utils.set_customized_partition_wrappers(
            customized_cudagraph_wrapper
        )

    yield

    if (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    ):
        torch._inductor.utils.set_customized_partition_wrappers(None)

should_torch_compile_mm_encoder

should_torch_compile_mm_encoder(
    vllm_config: VllmConfig,
) -> bool

Callable to be passed to @support_torch_compile's enable_if argument.
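
A minimal usage sketch (VisionEncoderBlock is a hypothetical module name):

import torch
from torch import nn

from vllm.compilation.decorators import (
    should_torch_compile_mm_encoder,
    support_torch_compile,
)


# Compiled only when compilation_config.compile_mm_encoder is enabled.
@support_torch_compile(
    enable_if=should_torch_compile_mm_encoder,
    is_encoder=True,
)
class VisionEncoderBlock(nn.Module):
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: ...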

Source code in vllm/compilation/decorators.py
def should_torch_compile_mm_encoder(vllm_config: VllmConfig) -> bool:
    """Callable to be passed to `@support_torch_compile`'s `enable_if` argument."""
    return vllm_config.compilation_config.compile_mm_encoder

support_torch_compile

support_torch_compile(
    *, enable_if: Callable[[VllmConfig], bool] | None = None
) -> Callable[[type[_T]], type[_T]]
support_torch_compile(
    *,
    dynamic_arg_dims: dict[
        str, int | list[int] | dict[int, str]
    ]
    | None,
) -> Callable[[type[_T]], type[_T]]
support_torch_compile(
    *, mark_unbacked_dims: dict[str, int | list[int]] | None
) -> Callable[[type[_T]], type[_T]]
support_torch_compile(
    *,
    dynamic_arg_dims: dict[
        str, int | list[int] | dict[int, str]
    ]
    | None,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]
support_torch_compile(cls: type[_T]) -> type[_T]
support_torch_compile(
    cls: type[_T] | None = None,
    *,
    dynamic_arg_dims: dict[
        str, int | list[int] | dict[int, str]
    ]
    | None = None,
    mark_unbacked_dims: dict[str, int | list[int]]
    | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    is_encoder: bool = False,
) -> Callable[[type[_T]], type[_T]] | type[_T]

A decorator to add support for compiling the forward method of a class.

Usage 1: use directly as a decorator without arguments:

@support_torch_compile
class MyModel(nn.Module):
    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...

Usage 2: use as a decorator with arguments:

@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
class MyModel(nn.Module):
    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...

dynamic_arg_dims is a dictionary that maps argument names to the dynamic dimensions of the argument. The value can be:

  • int: a single dimension index (e.g., 0)
  • list[int]: multiple dimension indices (e.g., [0, 1])
  • dict[int, str]: a dimension-to-shape_id mapping for shape relations (e.g., {0: "b"}); dimensions with the same shape_id share the same unbacked symbol (see the example below).
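
For example, the dict[int, str] form can tie two arguments to the same symbol (a sketch with a hypothetical model; shape_id is only honored for unbacked dynamic shapes on PyTorch 2.11 or newer):

@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": {0: "num_tokens"},
        "inputs_embeds": {0: "num_tokens"},
    }
)
class MyModel(nn.Module):
    def forward(
        self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor
    ) -> torch.Tensor: ...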

If dynamic_arg_dims is None, it is inferred from the type annotations of the forward method, based on the following default rules:

  • if the argument is annotated as torch.Tensor or Optional[torch.Tensor], the first dimension will be marked as dynamic.
  • if the argument is annotated as IntermediateTensors, the first dimension of all the tensors in the intermediate tensors will be marked as dynamic.

At runtime, when dimensions of tensors are actually marked as dynamic, the behavior depends on the value of each argument:

  • if it is a single integer (can be negative), the corresponding dimension of the argument will be marked as dynamic.
  • if it is None, ignored.
  • if it is IntermediateTensors, all the tensors in the intermediate tensors will be marked as dynamic.
  • otherwise, it will raise an error.

NOTE: if an argument is None, it should always be passed as None during the lifetime of the model, otherwise, it cannot be captured as a single computation graph.

enable_if is a function that takes a VllmConfig object as input and returns a boolean value indicating whether to compile the model or not. This is useful if you want to compile the model only when certain conditions are met.

mark_unbacked_dims is a dictionary that maps argument names to the dynamic dimension(s) that should be decorated with mark_unbacked. This is useful to keep Dynamo from specializing on 0/1 values for dummy inputs, such as during vision model compilation.
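
For example (a hypothetical vision module; dim 0 of pixel_values is marked unbacked so Dynamo will not specialize when a dummy batch happens to have size 0 or 1):

@support_torch_compile(mark_unbacked_dims={"pixel_values": 0})
class VisionTower(nn.Module):
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: ...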

is_encoder marks this module as a portion of a multimodal encoder. When True, the compile range upper bound is set to MAX_INT32 instead of max_num_batched_tokens, since encoder input shapes are unpredictable. This is typically used for vision encoder sub-modules in multimodal models.

shape_invariants is a function that gets compiled right before forward. The function should contain the torch._check calls needed to establish the relationships between different input sizes, for example torch._check(input_ids.size()[0] == inputs_embeds.size()[0]). This enforces constraints on the symbolic shapes without hardcoding specific values. It is needed for some models to avoid data-dependent errors and maximize performance when unbacked shapes are used.

Source code in vllm/compilation/decorators.py
def support_torch_compile(
    cls: type[_T] | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int] | dict[int, str]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    is_encoder: bool = False,
) -> Callable[[type[_T]], type[_T]] | type[_T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The value can be:
    - int: a single dimension index (e.g., 0)
    - list[int]: multiple dimension indices (e.g., [0, 1])
    - dict[int, str]: dimension to shape_id mapping for shape relations
      (e.g., {0: "b"}). Dimensions with the same shape_id share the same
      unbacked symbol.

    if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors`, the first
        dimension of all the tensors in the intermediate tensors
        will be marked as dynamic.

    During runtime, when we actually mark dimensions of tensors,
     it depends on the value of arguments:

    - if it is a single integer (can be negative), the corresponding dimension
        of the argument will be marked as dynamic.
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, all the tensors in the intermediate
        tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
    dim to be decorated with `mark_unbacked`.  This is useful if we would like to
    enforce that dynamo does not specialize on 0/1 values in the case of dummy input
    such as for vision model compilation

    `is_encoder` marks this module as a portion of a multimodal encoder.
    When True, the compile range upper bound is set to MAX_INT32 instead of
    max_num_batched_tokens, since encoder input shapes are unpredictable.
    This is typically used for vision encoder sub-modules in multimodal models.

    `shape_invariants` is a function that gets compiled right before forward.
    The function should have the torch._check calls that are needed to set
    the relationships between different input sizes. For example:
            torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
    This enforces constraints on the symbolic shapes without hardcoding
    specific values. It is needed for some models to avoid data dependent
    errors and maximize perf when unbacked shapes are used.
    """

    def cls_decorator_helper(cls: type[_T]) -> type[_T]:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                if v.annotation in [
                    torch.Tensor,
                    torch.Tensor | None,
                    torch.FloatTensor,
                    torch.FloatTensor | None,
                    IntermediateTensors,
                    IntermediateTensors | None,
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(
                ("Inferred dynamic dimensions for forward method of %s: %s"),
                cls,
                list(inferred_dynamic_arg_dims.keys()),
            )

        if len(inferred_dynamic_arg_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly."
            )

        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}"
                )

        return _support_torch_compile(
            cls,
            inferred_dynamic_arg_dims,
            mark_unbacked_dims,
            enable_if,
            is_encoder,
        )

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper