
[minor] Add custom calibration backend registry#1281

Open
Fridah-nv wants to merge 6 commits into main from fridah/calib-registry

Conversation

Contributor

@Fridah-nv Fridah-nv commented Apr 16, 2026

What does this PR do?

Type of change: ?

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added a public backend-specific calibrator registration API to support FP8 scale-sweep calibration, allowing backends to supply custom calibrators used during FP8 tuning.
  • Tests

    • Added unit tests confirming registry insertion/overwrite, that registered calibrators are invoked when FP8 scale-sweep is enabled, are not invoked when disabled, and that calibration falls back to defaults when no backend is registered.

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv requested review from realAsma and sugunav14 April 16, 2026 20:53
@Fridah-nv Fridah-nv requested a review from a team as a code owner April 16, 2026 20:53
Contributor

coderabbitai bot commented Apr 16, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: f1cda356-e3dd-4b52-9c96-531653f21e9a

📥 Commits

Reviewing files that changed from the base of the PR and between 57b33f3 and 33c3528.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/model_calib.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/model_calib.py

📝 Walkthrough


Adds a module-level FP8 scale-sweep calibrator extension: a public CalibratorFactory type alias, a _FP8_SWEEP_CALIBRATOR_REGISTRY, and _register_fp8_sweep_calibrator. mse_calibrate() now consults this registry when fp8_scale_sweep=True and can replace module calibrators with backend-specific factories.

Changes

  • FP8 sweep calibrator registration — modelopt/torch/quantization/model_calib.py: added an exported CalibratorFactory type alias (listed in __all__); introduced _FP8_SWEEP_CALIBRATOR_REGISTRY and the registration helper _register_fp8_sweep_calibrator(backend, calibrator_factory); updated mse_calibrate() so that, when fp8_scale_sweep=True, it looks up a module's backend in the registry and, if found, replaces the module's _calibrator with the factory-produced calibrator (skipping the built-in NVFP4-specific replacement).
  • Tests for registration and dispatch — tests/unit/torch/quantization/test_mse_calibrator.py: added TestRegisterFP8SweepCalibrator, which snapshots/restores module-level registries around each test; verifies registry insert/overwrite behavior; asserts that a registered factory is invoked exactly when fp8_scale_sweep=True and not when disabled; and checks that unregistered backends fall back to the default MseCalibrator.

Sequence Diagram(s)

sequenceDiagram
    participant Client as Client
    participant Reg as FP8 Registry
    participant Cal as mse_calibrate
    participant Fac as Calibrator Factory
    participant Mod as TensorQuantizer/Module

    Client->>Reg: _register_fp8_sweep_calibrator(backend, factory)
    Note over Reg: store factory by backend

    rect rgba(100,150,240,0.5)
    Client->>Cal: mse_calibrate(..., fp8_scale_sweep=True)
    Cal->>Mod: enumerate enabled quantizers
    Mod->>Cal: provide backend, initial_amax, axis, quant_func
    Cal->>Reg: lookup backend
    alt backend registered
        Reg-->>Cal: factory
        Cal->>Fac: call factory(initial_amax, axis, partial(quant_func, quantizer=Mod))
        Fac-->>Cal: custom _Calibrator
        Cal->>Mod: set module._calibrator = custom _Calibrator
    else not registered
        Reg-->>Cal: no entry
        Cal->>Mod: leave/assign default MseCalibrator
    end
    end
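A rough sketch of the dispatch branch the diagram describes — simplified stand-ins for the registry, quantizer module, and calibrator classes; the real code operates on modelopt's TensorQuantizer and _Calibrator types:

```python
from functools import partial

# Stand-in for _FP8_SWEEP_CALIBRATOR_REGISTRY.
_FP8_SWEEP_CALIBRATOR_REGISTRY = {}


class MseCalibrator:
    def __init__(self, amax, axis, quant_func):
        self.kind = "default"


class CustomCalibrator:
    def __init__(self, amax, axis, quant_func):
        self.kind = "custom"


def pick_calibrator(module, fp8_scale_sweep):
    """Mirror the lookup: a registered backend wins, otherwise use the default."""
    factory = None
    if fp8_scale_sweep:
        backend = getattr(module, "backend", None)
        if isinstance(backend, str):  # non-str backends skip the registry lookup
            factory = _FP8_SWEEP_CALIBRATOR_REGISTRY.get(backend)
    if factory is None:
        factory = MseCalibrator
    # Bind the quantizer into the quant function, as in the diagram.
    quant_func = partial(lambda x, quantizer: x, quantizer=module)
    return factory(module.initial_amax, module.axis, quant_func)


class FakeQuantizer:
    backend = "custom"
    initial_amax = 1.0
    axis = None


_FP8_SWEEP_CALIBRATOR_REGISTRY["custom"] = CustomCalibrator
print(pick_calibrator(FakeQuantizer(), fp8_scale_sweep=True).kind)   # custom
print(pick_calibrator(FakeQuantizer(), fp8_scale_sweep=False).kind)  # default
```

Note the `isinstance(backend, str)` guard: it makes the registry lookup safe even when a module carries an unusual backend attribute, falling back to the default calibrator instead of raising.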

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — docstring coverage is 52.94%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (3 passed)
  • Description Check ✅ Passed — check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — the title accurately describes the main change: adding a custom calibration backend registry for FP8 sweep calibrators.
  • Security Anti-Patterns ✅ Passed — the pull request introduces a safe calibrator registry mechanism with no security anti-patterns detected.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

github-actions bot commented Apr 16, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1281/

Built to branch gh-pages at 2026-04-19 06:06 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
modelopt/torch/quantization/model_calib.py (1)

64-65: Type hint mismatch with actual usage.

The registry is typed as dict[str, type] but stores callable factory functions, not necessarily types/classes. The docstring correctly describes it as a callable with signature (amax, axis, quant_func).

Suggested fix for type accuracy
+from typing import Callable
+
 # Registry for backends that provide a custom calibrator factory for mse_calibrate().
-_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, type] = {}
+_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, Callable] = {}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/model_calib.py` around lines 64 - 65, The
registry _FP8_SWEEP_CALIBRATOR_REGISTRY is annotated as dict[str, type] but
actually stores callable factory functions; update its type hint to reflect a
callable signature (import Callable/Any from typing) e.g. dict[str,
Callable[[float, int, Callable[..., Any]], Any]] or similar to match the
described factory signature (amax, axis, quant_func), and adjust any related
references (e.g., usages in mse_calibrate) to satisfy the new annotation.
tests/unit/torch/quantization/test_mse_calibrator.py (1)

596-617: Consider verifying the factory receives correct arguments.

The test verifies the factory is called once, but doesn't validate that the arguments passed to the factory are correct (e.g., that amax is a tensor, quant_func is callable). This is optional since the test already confirms the dispatch path works.

Optional enhancement to validate arguments
         def my_factory(amax, axis, quant_func):
-            factory_calls.append(amax)
+            factory_calls.append((amax, axis, quant_func))
             return _RecordingCalibrator(amax=amax, axis=axis, quant_func=quant_func)
 
         register_fp8_sweep_calibrator("_test_dispatch", my_factory)
         self._quantize_and_calibrate("_test_dispatch", fp8_scale_sweep=True)
 
         assert len(factory_calls) == 1
+        amax, axis, quant_func = factory_calls[0]
+        assert isinstance(amax, torch.Tensor)
+        assert callable(quant_func)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/quantization/test_mse_calibrator.py` around lines 596 - 617,
Update test_mse_calibrate_dispatches_to_registered_factory to assert the factory
receives the expected argument types/values: inside my_factory (registered via
register_fp8_sweep_calibrator) capture the passed amax, axis, and quant_func
into factory_calls (or a separate list) and add assertions after
self._quantize_and_calibrate that the captured amax is a tensor (or nd-array as
expected), axis matches the expected axis value, and quant_func is callable;
reference the test function name, my_factory, _RecordingCalibrator, and
register_fp8_sweep_calibrator when locating and modifying the test.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 64-65: The registry _FP8_SWEEP_CALIBRATOR_REGISTRY is annotated as
dict[str, type] but actually stores callable factory functions; update its type
hint to reflect a callable signature (import Callable/Any from typing) e.g.
dict[str, Callable[[float, int, Callable[..., Any]], Any]] or similar to match
the described factory signature (amax, axis, quant_func), and adjust any related
references (e.g., usages in mse_calibrate) to satisfy the new annotation.

In `@tests/unit/torch/quantization/test_mse_calibrator.py`:
- Around line 596-617: Update
test_mse_calibrate_dispatches_to_registered_factory to assert the factory
receives the expected argument types/values: inside my_factory (registered via
register_fp8_sweep_calibrator) capture the passed amax, axis, and quant_func
into factory_calls (or a separate list) and add assertions after
self._quantize_and_calibrate that the captured amax is a tensor (or nd-array as
expected), axis matches the expected axis value, and quant_func is callable;
reference the test function name, my_factory, _RecordingCalibrator, and
register_fp8_sweep_calibrator when locating and modifying the test.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: f432804b-9113-44fe-aa54-f758151cde31

📥 Commits

Reviewing files that changed from the base of the PR and between 6ded36b and c7b5044.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/model_calib.py
  • tests/unit/torch/quantization/test_mse_calibrator.py


codecov bot commented Apr 16, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.15%. Comparing base (e9a4989) to head (33c3528).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1281      +/-   ##
==========================================
- Coverage   76.61%   76.15%   -0.47%     
==========================================
  Files         459      459              
  Lines       49153    49164      +11     
==========================================
- Hits        37661    37439     -222     
- Misses      11492    11725     +233     
  • examples: 41.88% <41.66%> (-0.15% ⬇️)
  • gpu: 58.90% <75.00%> (-0.51% ⬇️)
  • regression: 14.99% <41.66%> (+0.08% ⬆️)
  • unit: 52.97% <100.00%> (+0.01% ⬆️)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@realAsma
Contributor

LGTM!

Comment thread modelopt/torch/quantization/model_calib.py Outdated
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)

374-379: Validate the factory output before storing it.

This path trusts the factory result blindly, but the later calibration flow assumes _Calibrator behavior. A quick check here would turn plugin mistakes into an immediate, actionable error instead of a delayed AttributeError deeper in calibration.

Suggested fix
-                        module._calibrator = backend_factory(
+                        calibrator = backend_factory(
                             initial_amax,
                             module._calibrator._axis,
                             partial(_mse_quant_func, quantizer=module),
                         )
+                        if not isinstance(calibrator, _Calibrator):
+                            raise TypeError(
+                                "Registered FP8 sweep calibrator must return _Calibrator, "
+                                f"got {type(calibrator).__name__}"
+                            )
+                        module._calibrator = calibrator
                         continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/model_calib.py` around lines 374 - 379, The code
assigns module._calibrator = backend_factory(...) without validating it; change
this to check that the returned object implements the expected _Calibrator
interface (e.g., has required attributes/methods like _axis and whatever methods
the calibration flow expects) after calling backend_factory in the branch that
creates a new calibrator for module using _mse_quant_func; if the check fails,
raise a descriptive error (including the factory identity and the missing
attribute names) so plugin mistakes fail fast instead of causing AttributeError
later in the calibration pipeline.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 365-372: The code currently assumes module.backend is a str or
None before looking up _FP8_SWEEP_CALIBRATOR_REGISTRY, but an unhashable/non-str
value can cause a TypeError; update the fp8_scale_sweep branch to first fetch
_backend = getattr(module, "backend", None) and only perform the registry lookup
when isinstance(_backend, str) (otherwise treat backend_factory as None) so that
unhashable values skip the registry lookup and fall back to the default
calibrator (referencing fp8_scale_sweep, module, _backend, and
_FP8_SWEEP_CALIBRATOR_REGISTRY).

---

Nitpick comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 374-379: The code assigns module._calibrator =
backend_factory(...) without validating it; change this to check that the
returned object implements the expected _Calibrator interface (e.g., has
required attributes/methods like _axis and whatever methods the calibration flow
expects) after calling backend_factory in the branch that creates a new
calibrator for module using _mse_quant_func; if the check fails, raise a
descriptive error (including the factory identity and the missing attribute
names) so plugin mistakes fail fast instead of causing AttributeError later in
the calibration pipeline.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: f981ca33-0f21-458e-a8d5-a19da0795815

📥 Commits

Reviewing files that changed from the base of the PR and between c7b5044 and 0e43a87.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/model_calib.py

Comment thread modelopt/torch/quantization/model_calib.py
_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, CalibratorFactory] = {}


def register_fp8_sweep_calibrator(backend: str, calibrator_factory: CalibratorFactory) -> None:
Collaborator


Do we intend to make this an API that our external users would use?

Contributor Author


Advanced users could use it for new calibration algorithms, though I'm not sure we have such a use case externally.

Collaborator

@shengliangxu shengliangxu Apr 18, 2026


OK, then I would recommend that we keep it as a private API and prefix it with _:

def _register_fp8_sweep_calibrator

If we want to expose it as a public API, then we also need documentation for it, and we would need to implement at least one calibrator_factory for users to choose from.

What do you think?

Contributor Author


Thanks for the suggestion, made the update.

kevalmorabia97 and others added 2 commits April 18, 2026 01:39
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>

__all__ = [
"CalibratorFactory",
"_register_fp8_sweep_calibrator",
Collaborator

@shengliangxu shengliangxu Apr 18, 2026


For a private API, we do not want to put it into __all__.

Names listed in __all__ are imported automatically when a user writes:

from model_calib import *

Without putting it into __all__, a user can still import it explicitly:

from model_calib import _register_fp8_sweep_calibrator

Only public APIs should go into __all__.
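The export behavior described here can be demonstrated with a toy module — the module and function names below are hypothetical, chosen only to illustrate the star-import rule:

```python
import sys
import types

# Build a toy module with one public and one private name.
mod = types.ModuleType("calib_demo")
exec(
    "__all__ = ['public_api']\n"
    "def public_api(): return 'public'\n"
    "def _private_helper(): return 'private'\n",
    mod.__dict__,
)
sys.modules["calib_demo"] = mod

ns = {}
exec("from calib_demo import *", ns)
print("public_api" in ns)        # True: listed in __all__
print("_private_helper" in ns)   # False: star import skips it

ns2 = {}
exec("from calib_demo import _private_helper", ns2)
print(ns2["_private_helper"]())  # prints "private": explicit import still works
```

So keeping `_register_fp8_sweep_calibrator` out of `__all__` hides it from `import *` without blocking deliberate, explicit imports.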

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)

55-64: Inconsistent API visibility: private function exported in __all__.

_register_fp8_sweep_calibrator uses a private name prefix (_) but is exported in __all__. Based on the past review discussion, this was intended to be a private API. Either:

  1. Remove it from __all__ to keep it truly private, or
  2. Rename it to register_fp8_sweep_calibrator (without underscore) if it should be public
Suggested fix (keep private)
 __all__ = [
     "CalibratorFactory",
-    "_register_fp8_sweep_calibrator",
     "awq",
     "local_hessian_calibrate",
     "max_calibrate",
     "sequential_calibrate",
     "smoothquant",
     "svdquant",
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/model_calib.py` around lines 55 - 64, The symbol
_register_fp8_sweep_calibrator is a private function (leading underscore) but is
currently exported in the module-level __all__ list; remove
"_register_fp8_sweep_calibrator" from the __all__ list to keep it private (or
alternatively rename the function to register_fp8_sweep_calibrator if you intend
it to be public). Locate the __all__ definition and either delete the entry
"_register_fp8_sweep_calibrator" from the list or rename the function
declaration and all internal usages to register_fp8_sweep_calibrator to make the
API consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 55-64: The symbol _register_fp8_sweep_calibrator is a private
function (leading underscore) but is currently exported in the module-level
__all__ list; remove "_register_fp8_sweep_calibrator" from the __all__ list to
keep it private (or alternatively rename the function to
register_fp8_sweep_calibrator if you intend it to be public). Locate the __all__
definition and either delete the entry "_register_fp8_sweep_calibrator" from the
list or rename the function declaration and all internal usages to
register_fp8_sweep_calibrator to make the API consistent.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 81d9b849-f6d0-4c17-9f79-51dace44347a

📥 Commits

Reviewing files that changed from the base of the PR and between 0e43a87 and b2a7b32.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/model_calib.py
  • tests/unit/torch/quantization/test_mse_calibrator.py
✅ Files skipped from review due to trivial changes (1)
  • tests/unit/torch/quantization/test_mse_calibrator.py

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
