-
Notifications
You must be signed in to change notification settings - Fork 278
Add persistent program cache for Program.compile #1912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -85,7 +85,12 @@ cdef class Program: | |
| self._h_nvvm.reset() | ||
|
|
||
| def compile( | ||
| self, target_type: str, name_expressions: tuple | list = (), logs = None | ||
| self, | ||
| target_type: str, | ||
| name_expressions: tuple | list = (), | ||
| logs=None, | ||
| *, | ||
| cache: "ProgramCacheResource | None" = None, | ||
| ) -> ObjectCode: | ||
| """Compile the program to the specified target type. | ||
|
|
||
|
|
@@ -98,13 +103,99 @@ cdef class Program: | |
| Used for template instantiation and similar cases. | ||
| logs : object, optional | ||
| Object with a ``write`` method to receive compilation logs. | ||
| cache : :class:`~cuda.core.utils.ProgramCacheResource`, optional | ||
| If provided, the compiled binary is looked up in ``cache`` via a | ||
| key derived from the program's code, options, and ``target_type``. | ||
| On a hit the cached bytes are wrapped in a fresh | ||
| :class:`~cuda.core.ObjectCode` (with the same ``target_type`` | ||
| and ``ProgramOptions.name``) and returned without re-compiling; | ||
| on a miss the compile output is stored as raw bytes (the cache | ||
| extracts ``bytes(object_code.code)``). Passing a non-empty | ||
| ``name_expressions`` together with ``cache=`` raises | ||
| ``ValueError``: NVRTC populates | ||
| ``ObjectCode.symbol_mapping`` at compile time and that mapping | ||
| is not carried in the binary the cache stores, so cache hits | ||
| would silently miss ``get_kernel(name_expression)`` lookups. | ||
| Options that require an ``extra_digest`` (``include_path``, | ||
| ``pre_include``, ``pch``, ``use_pch``, ``pch_dir``, NVVM | ||
| ``use_libdevice=True``, or NVRTC ``options.name`` with a | ||
| directory component) raise ``ValueError`` via | ||
| :func:`~cuda.core.utils.make_program_cache_key`; for those | ||
| compiles, use the manual ``make_program_cache_key(...)`` | ||
| pattern directly. | ||
|
|
||
| Returns | ||
| ------- | ||
| :class:`~cuda.core.ObjectCode` | ||
| The compiled object code. | ||
| """ | ||
| return Program_compile(self, target_type, name_expressions, logs) | ||
| if cache is None: | ||
| return _program_compile_uncached(self, target_type, name_expressions, logs) | ||
|
|
||
| # ``name_expressions`` is incompatible with the cache: NVRTC | ||
| # populates ``ObjectCode.symbol_mapping`` from name-expression | ||
| # mangling at compile time, and that mapping isn't carried in | ||
| # the binary bytes the cache stores. Without this guard the | ||
| # first call (cache miss) would return an ObjectCode with | ||
| # symbol_mapping populated, while every subsequent call (hit) | ||
| # would return one without -- silently breaking later | ||
| # ``get_kernel(name_expression)`` lookups that work on the | ||
| # uncached path. Fail loud here instead. | ||
| if name_expressions: | ||
| raise ValueError( | ||
| "Program.compile(cache=...) does not support name_expressions: " | ||
| "ObjectCode.symbol_mapping is populated by NVRTC at compile " | ||
| "time and is not preserved across a cache round-trip, so cache " | ||
| "hits would silently break get_kernel(name_expression) lookups " | ||
| "that the uncached path supports. Compile without cache= when " | ||
| "name_expressions are needed, or look up mangled symbols by " | ||
| "hand from the cached ObjectCode." | ||
| ) | ||
|
|
||
| # Deferred import to avoid a circular import between _program and | ||
| # cuda.core.utils._program_cache (the cache module already imports | ||
| # ProgramOptions from this module). Import from the leaf module so | ||
| # tests that monkeypatch make_program_cache_key via that path | ||
| # intercept reliably. | ||
| from cuda.core.utils._program_cache import make_program_cache_key | ||
|
|
||
| # ``self._code`` is always stored as bytes (see ``Program_init``), | ||
| # but ``make_program_cache_key`` only accepts bytes when | ||
| # ``code_type == "nvvm"`` -- c++/ptx must be ``str``. Decode back | ||
| # to the original str for the NVRTC/linker paths so the generated | ||
| # key matches keys callers build by passing the str source | ||
| # directly. | ||
| code_for_key = self._code if self._code_type == "nvvm" else self._code.decode("utf-8") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consideration: UTF-8 decode introduces a new failure mode in the cache path. If the source code contains non-UTF-8 bytes (e.g. Latin-1 encoded comments), this This is likely rare in practice (CUDA source is almost always ASCII/UTF-8), but worth documenting in the |
||
|
|
||
| key = make_program_cache_key( | ||
| code=code_for_key, | ||
| code_type=self._code_type, | ||
| options=self._options, | ||
| target_type=target_type, | ||
| ) | ||
| hit_bytes = cache.get(key) | ||
| if hit_bytes is not None: | ||
| # The uncached NVRTC path warns when the active driver can't | ||
| # load freshly-generated PTX; that loadability is a property | ||
| # of the driver, not of how the bytes were produced, so the | ||
|
Comment on lines
+162
to
+180
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consideration: When the cache hits, no compilation occurs, so
|
||
| # warning applies equally to cached PTX. Mirror it here so a | ||
| # cache hit doesn't silently hide an incompatibility that the | ||
| # uncached call would have surfaced. | ||
|
leofang marked this conversation as resolved.
|
||
| if ( | ||
| self._backend == "NVRTC" | ||
| and target_type == "ptx" | ||
| and not _can_load_generated_ptx() | ||
| ): | ||
| warn( | ||
| "The CUDA driver version is older than the backend version. " | ||
| "The generated ptx will not be loadable by the current driver.", | ||
| stacklevel=2, | ||
| category=RuntimeWarning, | ||
| ) | ||
| return ObjectCode._init(hit_bytes, target_type, name=self._options.name) | ||
| compiled = _program_compile_uncached(self, target_type, name_expressions, logs) | ||
| cache[key] = compiled | ||
| return compiled | ||
|
|
||
| @property | ||
| def pch_status(self) -> str | None: | ||
|
|
@@ -503,6 +594,19 @@ class ProgramOptions: | |
| # Private Classes and Helper Functions | ||
| # ============================================================================= | ||
|
|
||
|
|
||
| def _program_compile_uncached(program, target_type, name_expressions, logs): | ||
| """Run ``Program_compile`` without the cache wrapper. | ||
|
|
||
| Module-level Python function so tests can monkeypatch it from | ||
| ``cuda.core._program`` to avoid invoking NVRTC when exercising the cache | ||
| wrapper in :meth:`Program.compile`. ``Program`` itself is a ``cdef class`` | ||
| and its methods cannot be reassigned from Python, so the seam must live | ||
| outside the class. | ||
| """ | ||
| return Program_compile(program, target_type, name_expressions, logs) | ||
|
|
||
|
|
||
| # Module-level state for NVVM lazy loading | ||
| _nvvm_module = None | ||
| _nvvm_import_attempted = False | ||
|
|
@@ -577,8 +681,16 @@ cdef inline void _process_define_macro(list options, object macro) except *: | |
| raise RuntimeError(f"Expected define_macro {union_type}, list[{union_type}], got {macro}") | ||
|
|
||
|
|
||
| cpdef bint _can_load_generated_ptx() except? -1: | ||
| """Check if the driver can load PTX generated by the current NVRTC version.""" | ||
| def _can_load_generated_ptx(): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| """Check if the driver can load PTX generated by the current NVRTC version. | ||
|
|
||
| Defined as plain ``def`` (not ``cpdef``) so monkeypatching via | ||
| ``cuda.core._program._can_load_generated_ptx`` reliably intercepts | ||
| in-module callers. Cython early-binds ``cpdef`` calls within the same | ||
| module to the C entry point, which would silently bypass the test | ||
| seam used by ``Program.compile(cache=...)`` to mirror the PTX | ||
| loadability warning on cache hit. | ||
| """ | ||
| drv = driver_version() | ||
| nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion()) | ||
| return (nvrtc_major, nvrtc_minor, 0) <= drv | ||
|
|
@@ -618,6 +730,7 @@ cdef inline int Program_init(Program self, object code, str code_type, object op | |
|
|
||
| self._options = options = check_or_create_options(ProgramOptions, options, "Program options") | ||
| code_type = code_type.lower() | ||
| self._code_type = code_type | ||
| self._compile_lock = threading.Lock() | ||
| self._use_libdevice = False | ||
| self._libdevice_added = False | ||
|
|
@@ -638,16 +751,18 @@ cdef inline int Program_init(Program self, object code, str code_type, object op | |
| HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram( | ||
| &nvrtc_prog, code_ptr, name_ptr, 0, NULL, NULL)) | ||
| self._h_nvrtc = create_nvrtc_program_handle(nvrtc_prog) | ||
| self._nvrtc_code = code_bytes | ||
| self._code = code_bytes | ||
| self._backend = "NVRTC" | ||
| self._linker = None | ||
|
|
||
| elif code_type == "ptx": | ||
| assert_type(code, str) | ||
| if options.extra_sources is not None: | ||
| raise ValueError("extra_sources is not supported by the PTX backend.") | ||
| code_bytes = code.encode() | ||
| self._code = code_bytes | ||
| self._linker = Linker( | ||
| ObjectCode._init(code.encode(), code_type), options=_translate_program_options(options) | ||
| ObjectCode._init(code_bytes, code_type), options=_translate_program_options(options) | ||
| ) | ||
| self._backend = self._linker.backend | ||
|
|
||
|
|
@@ -657,10 +772,13 @@ cdef inline int Program_init(Program self, object code, str code_type, object op | |
| code = code.encode("utf-8") | ||
| elif not isinstance(code, (bytes, bytearray)): | ||
| raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray") | ||
| self._code = bytes(code) # Coerce bytearray -> bytes so retention type is stable | ||
|
|
||
| code_ptr = <const char*>(<bytes>code) | ||
| # Use self._code (strictly bytes) for the C pointer so a bytearray | ||
| # input doesn't trip the `<bytes>code` cast at runtime. | ||
| code_ptr = <const char*>self._code | ||
| name_ptr = <const char*>options._name | ||
| code_len = len(code) | ||
| code_len = len(self._code) | ||
|
|
||
| with nogil: | ||
| HANDLE_RETURN_NVVM(NULL, cynvvm.nvvmCreateProgram(&nvvm_prog)) | ||
|
|
@@ -832,7 +950,7 @@ cdef object Program_compile_nvrtc(Program self, str target_type, object name_exp | |
| HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcSetPCHHeapSize(required)) | ||
|
|
||
| cdef cynvrtc.nvrtcProgram retry_prog | ||
| cdef const char* code_ptr = <const char*>self._nvrtc_code | ||
| cdef const char* code_ptr = <const char*>self._code | ||
| cdef const char* name_ptr = <const char*>self._options._name | ||
| with nogil: | ||
| HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram( | ||
|
|
||
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from cuda.core._memoryview import ( | ||
| StridedMemoryView, | ||
| args_viewable_as_strided_memory, | ||
| ) | ||
| from cuda.core.utils._program_cache import ( | ||
| FileStreamProgramCache, | ||
| InMemoryProgramCache, | ||
| ProgramCacheResource, | ||
| make_program_cache_key, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "FileStreamProgramCache", | ||
| "InMemoryProgramCache", | ||
| "ProgramCacheResource", | ||
| "StridedMemoryView", | ||
| "args_viewable_as_strided_memory", | ||
| "make_program_cache_key", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Persistent program cache for cuda.core. | ||
|
|
||
| Public surface: | ||
|
|
||
| * :class:`ProgramCacheResource` -- bytes-in / bytes-out ABC. | ||
| * :class:`InMemoryProgramCache` -- thread-safe LRU dict-backed cache. | ||
| * :class:`FileStreamProgramCache` -- atomic, multi-process directory cache. | ||
| * :func:`make_program_cache_key` -- key derivation for arbitrary | ||
| ``Program`` configurations. | ||
|
|
||
| The package is split into submodules by concern. Tests that need to | ||
| monkeypatch internals (Windows flag, version probes, helpers, ...) | ||
| should reach into the owning submodule (e.g. | ||
| ``_program_cache._file_stream._IS_WINDOWS``, | ||
| ``_program_cache._keys._linker_backend_and_version``) rather than the | ||
| package object: the symbols re-exported here are only convenience | ||
| aliases and don't intercept calls within the submodules. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from ._abc import ProgramCacheResource | ||
| from ._file_stream import FileStreamProgramCache | ||
| from ._in_memory import InMemoryProgramCache | ||
| from ._keys import make_program_cache_key | ||
|
|
||
| __all__ = [ | ||
| "FileStreamProgramCache", | ||
| "InMemoryProgramCache", | ||
| "ProgramCacheResource", | ||
| "make_program_cache_key", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self: I need to address this after 1.0 is out, xref: cupy/cupy#9801