Exclude small-k and small-n Matmul nodes from Int8 quantization#1256
nv-samcheng wants to merge 4 commits into NVIDIA:main
Conversation
Note: Reviews paused. This branch appears to be under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review; this behavior is configurable.
📝 Walkthrough: Exclude small MatMul/Gemm ops from quantization when inferred N or K < 16; add a new …
Sequence Diagram(s): omitted
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 4 passed
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/unit/onnx/quantization/test_graph_utils.py (1)
119-182: Add targeted tests for Gemm (transB=1) and inference-based exclusion.

Nice coverage for MatMul shape inference. Please add one case validating K extraction when op="Gemm" with transB=1, plus one test for _exclude_matmuls_by_inference (shared inp_b Variable case) to lock in the new runtime-output extension path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/onnx/quantization/test_graph_utils.py` around lines 119 - 182, Add two unit tests in tests/unit/onnx/quantization/test_graph_utils.py: one that constructs a Gemm model with op="Gemm" and attribute transB=1 and asserts _get_inp_b_k_dim on its node returns the correct K (e.g., when B is constant with shape [..., K, N] transposed), and a second test that exercises _exclude_matmuls_by_shape_inference where multiple MatMul/Gemm nodes share the same inp_b Variable (use calibration_shapes only for "A" and provide an output_map or runtime-output scenario so the code path that reads K from runtime-output is used) and assert the expected node id is excluded; reference helpers _make_matmul_model, _get_matmul_nodes, _get_inp_b_k_dim, and _exclude_matmuls_by_shape_inference to locate relevant setup and ensure names/ids match existing tests (e.g., "MatMul_0").
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/quantization/graph_utils.py`:
- Around line 1235-1261: The _get_inp_b_k_dim function currently always reads K
from axis -2 which is wrong for Gemm when transB=1; update _get_inp_b_k_dim to
detect transB (default 0 for MatMul) from the node (check for attribute "transB"
on matmul_node) and compute k_axis = -1 if transB > 0 else -2, then use k_axis
when indexing into inp_b.values.shape, inp_b_info.type.tensor_type.shape.dim,
and output_map[inp_b.name].shape so all three fallback paths respect
transposition; also add unit tests that cover Gemm nodes with transB=1 to
prevent regressions.
- Around line 1343-1348: The code adds matmul outputs and second-input Variable
names to model.graph.output without deduplication, which can create duplicate
output names; update the logic (in the block handling matmul_nodes / uses of
matmul_node.outputs[0].name and matmul_node.inputs[1].name) to track
already-added output names (e.g., a set of names) and only call
model.graph.output.extend with onnx.ValueInfoProto for a name if it is not
already present in that set (and add it to the set after extending), ensuring
you still skip Constants by checking isinstance(matmul_node.inputs[1],
Variable).
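The dedup the comment asks for can be sketched in isolation. Everything below is a hedged illustration, not code from the PR: `add_output_once` is a hypothetical helper, and a plain list of strings stands in for `model.graph.output`.

```python
# Sketch of the requested dedup: track names already appended to the graph
# outputs in a set and extend only once per name. The name strings are made up.

def add_output_once(outputs: list[str], seen: set[str], name: str) -> bool:
    """Append `name` once; return True only if it was newly added."""
    if name in seen:
        return False
    outputs.append(name)
    seen.add(name)
    return True

outputs: list[str] = []
seen: set[str] = set()
# Two MatMuls sharing the same second input "B": "B" must be added only once.
for name in ["MatMul_0_out", "B", "MatMul_1_out", "B"]:
    add_output_once(outputs, seen, name)
print(outputs)  # -> ['MatMul_0_out', 'B', 'MatMul_1_out']
```

In the real code path, the set check would wrap the existing `model.graph.output.extend(...)` call while still skipping Constants via the `isinstance(..., Variable)` guard.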
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 3a5d8843-1a90-424d-a931-a88d63dc0fa0
📒 Files selected for processing (2)
- modelopt/onnx/quantization/graph_utils.py
- tests/unit/onnx/quantization/test_graph_utils.py
Force-pushed 4deee67 to 4ba5e57
♻️ Duplicate comments (1)
modelopt/onnx/quantization/graph_utils.py (1)
1236-1261: ⚠️ Potential issue | 🟠 Major

Handle Gemm transB and graph-input shapes when deriving K.

_get_inp_b_k_dim() still assumes K is always B[-2]. That is wrong for Gemm with transB=1, where K comes from B[-1], so the new small-K filter can exclude or keep Gemms incorrectly. Also, the shape-inference path only looks at value_info/output; if B is a graph input, its inferred shape lives in model.graph.input, so small_k becomes undetectable there.

Suggested fix:
```diff
 def _get_inp_b_k_dim(
     matmul_node, value_info_map: dict | None = None, output_map: dict | None = None
 ):
@@
+    trans_b = bool(matmul_node.attrs.get("transB", 0)) if matmul_node.op == "Gemm" else False
+    k_axis = -1 if trans_b else -2
     inp_b = matmul_node.inputs[1]
     if hasattr(inp_b, "values") and inp_b.values is not None:
         inp_b_shape = inp_b.values.shape
         if len(inp_b_shape) >= 2:
-            return inp_b_shape[-2]
+            return inp_b_shape[k_axis]
     if value_info_map is not None:
         inp_b_info = value_info_map.get(inp_b.name)
         if inp_b_info:
             inp_b_dims = inp_b_info.type.tensor_type.shape.dim
             if len(inp_b_dims) >= 2:
-                return inp_b_dims[-2].dim_value
+                return inp_b_dims[k_axis].dim_value
     if output_map is not None and inp_b.name in output_map:
         inp_b_out = output_map[inp_b.name]
         if len(inp_b_out.shape) >= 2:
-            return inp_b_out.shape[-2]
+            return inp_b_out.shape[k_axis]
     return None
```

```diff
-    value_info_map = {vi.name: vi for vi in model.graph.value_info}
+    value_info_map = {vi.name: vi for vi in model.graph.input}
+    value_info_map.update({vi.name: vi for vi in model.graph.value_info})
     value_info_map.update({vi.name: vi for vi in model.graph.output})
```

Verify by checking that transB is still ignored and that the shape-inference map still excludes graph inputs; expected result is no current transB handling in _get_inp_b_k_dim() and no model.graph.input entries in value_info_map.

```bash
#!/bin/bash
set -euo pipefail
echo "== _get_inp_b_k_dim implementation =="
sed -n '1236,1262p' modelopt/onnx/quantization/graph_utils.py
echo
echo "== shape-inference value_info_map construction =="
sed -n '1296,1300p' modelopt/onnx/quantization/graph_utils.py
echo
echo "== references/tests mentioning Gemm, transB, or _get_inp_b_k_dim =="
rg -n -C2 '(_get_inp_b_k_dim|transB|Gemm)' \
  modelopt/onnx/quantization/graph_utils.py \
  tests/unit/onnx/quantization/test_graph_utils.py
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/quantization/graph_utils.py` around lines 1236 - 1261, _get_inp_b_k_dim currently assumes K = B[-2]; update it to handle Gemm nodes with transB=1 by detecting matmul_node.op_type == "Gemm" and reading the transB attribute (treat missing transB as 0) and, when transB==1, return the last dimension (B[-1] / dim_value of last dim) instead of the second-last; additionally, when consulting shapes from value_info_map/output_map, ensure the function also considers graph input shapes (i.e., the model.graph.input entries are included in the shape lookup) so that an input B whose shape comes from model.graph.input is found (either by expanding the value_info_map to include graph inputs before lookup or by checking a provided graph-input map fallback).
📒 Files selected for processing (2)
- modelopt/onnx/quantization/graph_utils.py
- tests/unit/onnx/quantization/test_graph_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/onnx/quantization/test_graph_utils.py
Force-pushed 4ba5e57 to 89e6619
Review: Exclude small-k and small-n Matmul nodes from Int8 quantization
Good change overall — the motivation is clear and the implementation is clean. A few items to address before merging:
Issues
1. Missing transB handling for Gemm nodes (Medium)
find_nodes_from_matmul_to_exclude collects both MatMul and Gemm nodes (line 1116: node.op in {"MatMul", "Gemm"}). _get_inp_b_k_dim always reads K from axis [-2], which is correct for MatMul ([K, N]) but wrong for Gemm with transB=1 where B is [N, K] and K is at axis [-1]. This could cause:
- False negatives: a Gemm with small K but large N would read N as K and skip exclusion
- False positives: a Gemm with large K but small N would read N as K and exclude incorrectly
Suggested fix: detect transB attribute on the node and set k_axis = -1 if transB > 0 else -2, then use k_axis across all three fallback paths in _get_inp_b_k_dim.
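The k-axis rule described above can be illustrated standalone. Below is a hedged sketch, not the repository's code: `ToyNode` is a made-up stand-in for a gs.Node with only the fields the rule needs, and `get_k_dim` mirrors just the constant-shape path of `_get_inp_b_k_dim`.

```python
# Sketch of the transB-aware K lookup suggested in the review.
# ToyNode is hypothetical; a real gs.Node exposes op, inputs, and attrs.

class ToyNode:
    def __init__(self, op, b_shape, attrs=None):
        self.op = op
        self.b_shape = b_shape  # shape of the second input B
        self.attrs = attrs or {}

def get_k_dim(node):
    """Return K from B's shape: B is [K, N] normally, [N, K] when Gemm transB=1."""
    trans_b = bool(node.attrs.get("transB", 0)) if node.op == "Gemm" else False
    k_axis = -1 if trans_b else -2
    if len(node.b_shape) >= 2:
        return node.b_shape[k_axis]
    return None

print(get_k_dim(ToyNode("MatMul", (8, 64))))               # -> 8  (B is [K, N])
print(get_k_dim(ToyNode("Gemm", (8, 64), {"transB": 1})))  # -> 64 (B is [N, K])
print(get_k_dim(ToyNode("Gemm", (8, 64))))                 # -> 8  (transB defaults to 0)
```

With a single `k_axis` variable, all three fallback paths (constant values, value_info shape, runtime-output shape) can share the transposition handling.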
2. Small-gemm check applies unconditionally for all quantize modes (Low-Medium)
The new small-gemm check fires for all invocations of find_nodes_from_matmul_to_exclude, but the threshold _MIN_MATMUL_DIM_INT8 = 16 is specifically for INT8 (as the name implies). The calling context may invoke this for FP8 quantization too. Consider either:
- Gating the check behind a quantize_mode parameter (similar to how find_nodes_from_convs_to_exclude does it)
- Or documenting explicitly that this is intentionally applied to all modes
3. Tests only cover shape-inference path (Low)
All new tests exercise _exclude_matmuls_by_shape_inference. There are no tests for _exclude_matmuls_by_inference (the runtime inference path). The runtime path has the same logic but uses output_map — adding at least one test would increase confidence.
Signed-off-by: samcheng <samcheng@nvidia.com>
- Honor Gemm transB in _get_inp_b_k_dim so K is read from B[-1] when transB=1 (B is [N, K]) and from B[-2] otherwise.
- Include model.graph.input in value_info_map so B that comes from a graph input is visible to the small-K check.
- Apply the small-K/N (<16) exclusion to both INT8 and FP8 (Tensor Core kernels need K, N >= 16 in both modes); rename _MIN_MATMUL_DIM_INT8 to _MIN_MATMUL_DIM.
- Add tests for Gemm transB=0/1 (constant and runtime), graph-input B, the runtime-inference path (_exclude_matmuls_by_inference), and the shared-B dedup in added graph outputs.

Signed-off-by: samcheng <samcheng@nvidia.com>
Force-pushed 89e6619 to 29c5b98
🧹 Nitpick comments (1)
tests/unit/onnx/quantization/test_graph_utils.py (1)
116-119: Consider consolidating duplicated node-filter helpers.

_get_matmul_nodes and _get_nodes_by_op can be unified into one helper to reduce maintenance overhead.

♻️ Optional refactor

```diff
-def _get_matmul_nodes(model):
-    """Import an ONNX model and return its MatMul gs.Nodes."""
-    graph = gs.import_onnx(model)
-    return [n for n in graph.nodes if n.op == "MatMul"]
+def _get_nodes_by_op(model, op):
+    graph = gs.import_onnx(model)
+    return [n for n in graph.nodes if n.op == op]
```

```diff
-    nodes = _get_matmul_nodes(model)
+    nodes = _get_nodes_by_op(model, "MatMul")
```

Also applies to: 207-209
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/onnx/quantization/test_graph_utils.py` around lines 116 - 119, There are two duplicated helpers (_get_matmul_nodes and _get_nodes_by_op); replace them with a single reusable function (e.g., _get_nodes_by_op(model_or_graph, op_name)) that imports the ONNX model via gs.import_onnx when given a model and returns nodes filtered by node.op == op_name; update all call sites that used _get_matmul_nodes to call the new helper with op_name="MatMul" and update any usages at lines referenced (including the other occurrence around 207-209) to avoid duplication and keep behavior identical.
📒 Files selected for processing (3)
- CHANGELOG.rst
- modelopt/onnx/quantization/graph_utils.py
- tests/unit/onnx/quantization/test_graph_utils.py
✅ Files skipped from review due to trivial changes (1)
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/onnx/quantization/graph_utils.py
Drop _get_matmul_nodes and route all call sites through _get_nodes_by_op(model, "MatMul") to remove a duplicated helper. Signed-off-by: samcheng <samcheng@nvidia.com>
What does this PR do?
Exclude small-dimension MatMul nodes from INT8 quantization. MatMuls with N or K < 16 cannot efficiently use INT8, causing performance regressions.
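The exclusion rule boils down to a simple predicate. A minimal sketch follows; the helper name, the constant name, and the choice to treat an unknown (None) dimension as "not small" are assumptions for illustration, not the PR's exact code.

```python
# Sketch of the small-dimension exclusion: skip quantization when either
# inferred N or K of a MatMul/Gemm is below 16 (Tensor Core alignment).
# `is_small_matmul` is a hypothetical name; dims that could not be
# inferred (None) are conservatively kept quantizable here.

MIN_MATMUL_DIM = 16

def is_small_matmul(n, k):
    """True if the node should be excluded from INT8 quantization."""
    return any(d is not None and d < MIN_MATMUL_DIM for d in (n, k))

print(is_small_matmul(n=8, k=256))    # -> True  (small N)
print(is_small_matmul(n=256, k=8))    # -> True  (small K)
print(is_small_matmul(n=64, k=64))    # -> False
print(is_small_matmul(n=None, k=64))  # -> False (unknown dim: keep quantizing)
```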
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
Bug Fixes
Tests
Documentation