[minor] Add custom calibration backend registry #1281
Conversation
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
> **Note: Reviews paused.** It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the repository settings.
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID:
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough

Adds a module-level FP8 scale-sweep calibrator extension: a public …

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Client
    participant Reg as FP8 Registry
    participant Cal as mse_calibrate
    participant Fac as Calibrator Factory
    participant Mod as TensorQuantizer/Module

    Client->>Reg: _register_fp8_sweep_calibrator(backend, factory)
    Note over Reg: store factory by backend
    rect rgba(100,150,240,0.5)
        Client->>Cal: mse_calibrate(..., fp8_scale_sweep=True)
        Cal->>Mod: enumerate enabled quantizers
        Mod->>Cal: provide backend, initial_amax, axis, quant_func
        Cal->>Reg: lookup backend
        alt backend registered
            Reg-->>Cal: factory
            Cal->>Fac: call factory(initial_amax, axis, partial(quant_func, quantizer=Mod))
            Fac-->>Cal: custom _Calibrator
            Cal->>Mod: set module._calibrator = custom _Calibrator
        else not registered
            Reg-->>Cal: no entry
            Cal->>Mod: leave/assign default MseCalibrator
        end
    end
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
🧹 Nitpick comments (2)
modelopt/torch/quantization/model_calib.py (1)
**64-65: Type hint mismatch with actual usage.** The registry is typed as `dict[str, type]` but stores callable factory functions, not necessarily types/classes. The docstring correctly describes it as a callable with signature `(amax, axis, quant_func)`.

Suggested fix for type accuracy:
```diff
+from typing import Callable
+
 # Registry for backends that provide a custom calibrator factory for mse_calibrate().
-_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, type] = {}
+_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, Callable] = {}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_calib.py` around lines 64-65: the registry _FP8_SWEEP_CALIBRATOR_REGISTRY is annotated as dict[str, type] but actually stores callable factory functions; update its type hint to reflect a callable signature (import Callable/Any from typing), e.g. dict[str, Callable[[float, int, Callable[..., Any]], Any]] or similar to match the described factory signature (amax, axis, quant_func), and adjust any related references (e.g., usages in mse_calibrate) to satisfy the new annotation.

tests/unit/torch/quantization/test_mse_calibrator.py (1)
**596-617: Consider verifying the factory receives correct arguments.** The test verifies the factory is called once, but doesn't validate that the arguments passed to the factory are correct (e.g., that `amax` is a tensor, `quant_func` is callable). This is optional since the test already confirms the dispatch path works.

Optional enhancement to validate arguments:
```diff
     def my_factory(amax, axis, quant_func):
-        factory_calls.append(amax)
+        factory_calls.append((amax, axis, quant_func))
         return _RecordingCalibrator(amax=amax, axis=axis, quant_func=quant_func)

     register_fp8_sweep_calibrator("_test_dispatch", my_factory)
     self._quantize_and_calibrate("_test_dispatch", fp8_scale_sweep=True)
     assert len(factory_calls) == 1
+    amax, axis, quant_func = factory_calls[0]
+    assert isinstance(amax, torch.Tensor)
+    assert callable(quant_func)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/quantization/test_mse_calibrator.py` around lines 596 - 617, Update test_mse_calibrate_dispatches_to_registered_factory to assert the factory receives the expected argument types/values: inside my_factory (registered via register_fp8_sweep_calibrator) capture the passed amax, axis, and quant_func into factory_calls (or a separate list) and add assertions after self._quantize_and_calibrate that the captured amax is a tensor (or nd-array as expected), axis matches the expected axis value, and quant_func is callable; reference the test function name, my_factory, _RecordingCalibrator, and register_fp8_sweep_calibrator when locating and modifying the test.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 64-65: The registry _FP8_SWEEP_CALIBRATOR_REGISTRY is annotated as
dict[str, type] but actually stores callable factory functions; update its type
hint to reflect a callable signature (import Callable/Any from typing) e.g.
dict[str, Callable[[float, int, Callable[..., Any]], Any]] or similar to match
the described factory signature (amax, axis, quant_func), and adjust any related
references (e.g., usages in mse_calibrate) to satisfy the new annotation.
In `@tests/unit/torch/quantization/test_mse_calibrator.py`:
- Around line 596-617: Update
test_mse_calibrate_dispatches_to_registered_factory to assert the factory
receives the expected argument types/values: inside my_factory (registered via
register_fp8_sweep_calibrator) capture the passed amax, axis, and quant_func
into factory_calls (or a separate list) and add assertions after
self._quantize_and_calibrate that the captured amax is a tensor (or nd-array as
expected), axis matches the expected axis value, and quant_func is callable;
reference the test function name, my_factory, _RecordingCalibrator, and
register_fp8_sweep_calibrator when locating and modifying the test.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: f432804b-9113-44fe-aa54-f758151cde31
📒 Files selected for processing (2)
modelopt/torch/quantization/model_calib.py
tests/unit/torch/quantization/test_mse_calibrator.py
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1281      +/-   ##
==========================================
- Coverage   76.61%   76.15%   -0.47%
==========================================
  Files         459      459
  Lines       49153    49164      +11
==========================================
- Hits        37661    37439     -222
- Misses      11492    11725     +233
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
LGTM!
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)
**374-379: Validate the factory output before storing it.** This path trusts the factory result blindly, but the later calibration flow assumes `_Calibrator` behavior. A quick check here would turn plugin mistakes into an immediate, actionable error instead of a delayed `AttributeError` deeper in calibration.

Suggested fix:
```diff
-        module._calibrator = backend_factory(
+        calibrator = backend_factory(
             initial_amax,
             module._calibrator._axis,
             partial(_mse_quant_func, quantizer=module),
         )
+        if not isinstance(calibrator, _Calibrator):
+            raise TypeError(
+                "Registered FP8 sweep calibrator must return _Calibrator, "
+                f"got {type(calibrator).__name__}"
+            )
+        module._calibrator = calibrator
         continue
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_calib.py` around lines 374 - 379, The code assigns module._calibrator = backend_factory(...) without validating it; change this to check that the returned object implements the expected _Calibrator interface (e.g., has required attributes/methods like _axis and whatever methods the calibration flow expects) after calling backend_factory in the branch that creates a new calibrator for module using _mse_quant_func; if the check fails, raise a descriptive error (including the factory identity and the missing attribute names) so plugin mistakes fail fast instead of causing AttributeError later in the calibration pipeline.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 365-372: The code currently assumes module.backend is a str or
None before looking up _FP8_SWEEP_CALIBRATOR_REGISTRY, but an unhashable/non-str
value can cause a TypeError; update the fp8_scale_sweep branch to first fetch
_backend = getattr(module, "backend", None) and only perform the registry lookup
when isinstance(_backend, str) (otherwise treat backend_factory as None) so that
unhashable values skip the registry lookup and fall back to the default
calibrator (referencing fp8_scale_sweep, module, _backend, and
_FP8_SWEEP_CALIBRATOR_REGISTRY).
---
Nitpick comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 374-379: The code assigns module._calibrator =
backend_factory(...) without validating it; change this to check that the
returned object implements the expected _Calibrator interface (e.g., has
required attributes/methods like _axis and whatever methods the calibration flow
expects) after calling backend_factory in the branch that creates a new
calibrator for module using _mse_quant_func; if the check fails, raise a
descriptive error (including the factory identity and the missing attribute
names) so plugin mistakes fail fast instead of causing AttributeError later in
the calibration pipeline.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: f981ca33-0f21-458e-a8d5-a19da0795815
📒 Files selected for processing (1)
modelopt/torch/quantization/model_calib.py
```python
_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, CalibratorFactory] = {}


def register_fp8_sweep_calibrator(backend: str, calibrator_factory: CalibratorFactory) -> None:
```
do we intend to make it an API that our external users would use?
Advanced users can use it for new calibration algorithms, though I'm not sure we have such a use case externally.
OK, then I would recommend that we keep it as a private API and prefix it with `_`:
`def _register_fp8_sweep_calibrator`
If we want to expose it as a public API, then we also need documentation about it, and would need to implement at least one calibrator_factory for users to choose from.
What do you think?
Thanks for the suggestion, made the update.
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
```python
__all__ = [
    "CalibratorFactory",
    "_register_fp8_sweep_calibrator",
```
For a private API, we do not want to put it into `__all__`.
Names in `__all__` are automatically imported if a user writes:
`from model_calib import *`
Without putting it into `__all__`, a user can still explicitly import it using:
`from model_calib import _register_fp8_sweep_calibrator`
Only public APIs should be put into `__all__`.
🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)
**55-64: Inconsistent API visibility: private function exported in `__all__`.** `_register_fp8_sweep_calibrator` uses a private name prefix (`_`) but is exported in `__all__`. Based on the past review discussion, this was intended to be a private API. Either:

- Remove it from `__all__` to keep it truly private, or
- Rename it to `register_fp8_sweep_calibrator` (without underscore) if it should be public

Suggested fix (keep private):
```diff
 __all__ = [
     "CalibratorFactory",
-    "_register_fp8_sweep_calibrator",
     "awq",
     "local_hessian_calibrate",
     "max_calibrate",
     "sequential_calibrate",
     "smoothquant",
     "svdquant",
 ]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_calib.py` around lines 55 - 64, The symbol _register_fp8_sweep_calibrator is a private function (leading underscore) but is currently exported in the module-level __all__ list; remove "_register_fp8_sweep_calibrator" from the __all__ list to keep it private (or alternatively rename the function to register_fp8_sweep_calibrator if you intend it to be public). Locate the __all__ definition and either delete the entry "_register_fp8_sweep_calibrator" from the list or rename the function declaration and all internal usages to register_fp8_sweep_calibrator to make the API consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 55-64: The symbol _register_fp8_sweep_calibrator is a private
function (leading underscore) but is currently exported in the module-level
__all__ list; remove "_register_fp8_sweep_calibrator" from the __all__ list to
keep it private (or alternatively rename the function to
register_fp8_sweep_calibrator if you intend it to be public). Locate the __all__
definition and either delete the entry "_register_fp8_sweep_calibrator" from the
list or rename the function declaration and all internal usages to
register_fp8_sweep_calibrator to make the API consistent.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 81d9b849-f6d0-4c17-9f79-51dace44347a
📒 Files selected for processing (2)
modelopt/torch/quantization/model_calib.py
tests/unit/torch/quantization/test_mse_calibrator.py
✅ Files skipped from review due to trivial changes (1)
- tests/unit/torch/quantization/test_mse_calibrator.py
**What does this PR do?**

Type of change: ?

**Usage**

```python
# Add a code snippet demonstrating how to use this
```

**Testing**

**Before your PR is "Ready for review"**

- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

**Additional Information**
Summary by CodeRabbit
New Features
Tests