diff --git a/CHANGELOG.md b/CHANGELOG.md index a456453c..19cd9f5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [2.12.1] - 2026-04-21 + +### Fixed + +- `tensorizer` is now compatible with Python 3.14 + ## [2.12.0] - 2025-08-20 ### Added @@ -509,6 +515,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `get_gpu_name` - `no_init_or_tensor` +[2.12.1]: https://github.com/coreweave/tensorizer/compare/v2.12.0...v2.12.1 [2.12.0]: https://github.com/coreweave/tensorizer/compare/v2.11.1...v2.12.0 [2.11.1]: https://github.com/coreweave/tensorizer/compare/v2.11.0...v2.11.1 [2.11.0]: https://github.com/coreweave/tensorizer/compare/v2.10.1...v2.11.0 diff --git a/tensorizer/_crypt/_encryption.py b/tensorizer/_crypt/_encryption.py index dc519a44..d150ea24 100644 --- a/tensorizer/_crypt/_encryption.py +++ b/tensorizer/_crypt/_encryption.py @@ -4,6 +4,7 @@ import enum import io import mmap +import sys import typing from concurrent.futures import ThreadPoolExecutor from contextlib import AbstractContextManager @@ -31,6 +32,9 @@ import libnacl from libnacl import nacl +if sys.version_info >= (3, 14): + import annotationlib + try: from ._cgroup_cpu_count import effective_cpu_count from ._exceptions import CryptographyError @@ -346,6 +350,27 @@ def init_sodium_memzero(): sodium_memzero = init_sodium_memzero() +def _pop_annotations(dct: dict) -> dict: + """ + Extract annotations from a metaclass namespace dict. + + In Python 3.14+ (PEP 649), annotations are lazily evaluated and stored + as an *annotate function* instead of the ``__annotations__`` attribute. + """ + if "__annotations__" in dct: + return dct.pop("__annotations__") + elif ( + sys.version_info >= (3, 14) + and (annotate := annotationlib.get_annotate_from_class_namespace(dct)) + is not None + ): + return annotationlib.call_annotate_function( + annotate, format=annotationlib.Format.VALUE + ) + else: + return {} + + class Constants(type): @staticmethod def _get_constant(name, typ) -> int: @@ -355,7 +380,7 @@ def _get_constant(name, typ) -> int: return getter() def __new__(cls, name: str, bases: tuple, dct: dict) -> NamedTuple: - annotations = dct.pop("__annotations__", {}) + annotations = _pop_annotations(dct) entries = {} for constant_name, constant_type in dct.items(): if constant_name.startswith("_"): @@ -364,7 +389,7 @@ def __new__(cls, name: str, bases: tuple, dct: dict) -> NamedTuple: constant_name.lower(), constant_type ) constant_class = typing.cast( - type, NamedTuple(name, **{k: annotations[k] for k in entries}) + type, NamedTuple(name, [(k, annotations[k]) for k in entries]) ) return constant_class(**entries) diff --git a/tensorizer/_version.py b/tensorizer/_version.py index 95a6d3a7..96fc614c 100644 --- a/tensorizer/_version.py +++ b/tensorizer/_version.py @@ -1 +1 @@ -__version__ = "2.12.0" +__version__ = "2.12.1" diff --git a/tests/requirements.txt b/tests/requirements.txt index a4387394..a0be55ba 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -2,3 +2,4 @@ transformers>=4.27.1 moto[s3,server]>=4.1.4,<5.0.0 redis>=5.0.0 hiredis>=2.2.0 +accelerate>=1.9.0 diff --git a/tests/test_serialization.py b/tests/test_serialization.py index ff50bdf7..6df018e3 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -19,6 +19,7 @@ from unittest.mock import patch import torch +import transformers import tensorizer @@ -161,6 +162,21 @@ def model_digest( return {k: TensorInfo.from_tensor(v) for k, v in orig_sd.items()} +# Key suffixes that older serialized artifacts may contain but that +# newer transformers versions no longer register. +# If the transformers version is one that no longer registers these, +# `check_deserialized` ignores any mismatches caused by them being present. +_LEGACY_IGNORED_KEY_SUFFIXES: Tuple[str, ...] +if int(transformers.__version__.partition(".")[0]) < 5: + _LEGACY_IGNORED_KEY_SUFFIXES = () +else: + _LEGACY_IGNORED_KEY_SUFFIXES = (".attn.attention.masked_bias",) + + +def _is_legacy_ignored_key(k: str) -> bool: + return k.endswith(_LEGACY_IGNORED_KEY_SUFFIXES) + + def check_deserialized( test_case: unittest.TestCase, deserialized: TensorDeserializer, @@ -171,13 +187,19 @@ def check_deserialized( orig_sd = model_digest(model_name, include_non_persistent_buffers) if not allow_subset: - test_case.assertEqual( + deserialized_filtered = { + k for k in deserialized.keys() if not _is_legacy_ignored_key(k) + } + test_case.assertCountEqual( orig_sd.keys(), - deserialized.keys(), + deserialized_filtered, "List of deserialized keys doesn't match list of original keys", ) for k, v in deserialized.items(): + if _is_legacy_ignored_key(k): + continue + test_case.assertIn( k, orig_sd,