Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Changelog

- Fix Minitron pruning (``mcore_minitron``) for MoE models. Importance estimation hooks were incorrectly registered for MoE modules and NAS step was hanging before this.
- Fix TRT support for remote autotuning in ONNX Autotune from 10.16+ to 10.15+ and fix TRT versioning check to the ``trtexec`` version instead of the TRT Python API when using ``trtexec`` backend.
- Exclude MatMul/Gemm nodes with K or N < 16 from ONNX INT8 and FP8 quantization. Such small-dimension GEMMs cannot efficiently use INT8/FP8 Tensor Cores and the added Q/DQ layers cause perf regressions in TensorRT. Honors Gemm ``transB`` when deriving K.

**Misc**

Expand Down
107 changes: 97 additions & 10 deletions modelopt/onnx/quantization/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,11 +1089,14 @@ def find_nodes_from_matmul_to_exclude(
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
calibration_shapes: str | dict | None = None,
) -> list[str]:
"""Find MatMul nodes that meets gemv condition to exclude.
"""Find MatMul nodes that meet gemv or small-gemm conditions and should be excluded.

Either of m or n in matmul is 1, this matmul cannot utilize
TensorCores. The perf of adding Q/DQ layers is not good in
TRT. Thus, in this case, do not add Q/DQ layers to this matmul.
A MatMul is excluded if either:

- m or n in the output is 1 (GEMV): cannot utilize TensorCores; or
- K or N is smaller than ``_MIN_MATMUL_DIM`` (16): both INT8 and FP8 Tensor Core
kernels need K/N >= 16 to be efficient, and adding Q/DQ layers on such small
GEMMs causes TRT perf regressions.

Args:
onnx_path: Path to the onnx model.
Expand Down Expand Up @@ -1143,6 +1146,10 @@ def find_nodes_from_matmul_to_exclude(


_MIN_CHANNELS_FP8 = 16
# Minimum K/N dim for MatMul/Gemm under INT8 or FP8 quantization. Both INT8 and FP8
# Tensor Core kernels need K/N >= 16 to be efficient; adding Q/DQ layers on smaller
# GEMMs causes TRT perf regressions.
_MIN_MATMUL_DIM = 16


def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
Expand Down Expand Up @@ -1231,10 +1238,47 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
return unsupported_conv_nodes


def _get_inp_b_k_dim(
matmul_node, value_info_map: dict | None = None, output_map: dict | None = None
):
"""Get the K dimension from the second input of a MatMul/Gemm node.

Tries Constant shape first, then falls back to shape inference (value_info_map)
or runtime inference (output_map). For Gemm nodes, honors the ``transB`` attribute:
when ``transB=1``, B has shape ``[N, K]`` so K lives at axis -1; otherwise B is
``[..., K, N]`` and K is at axis -2.

Returns:
The K dimension value, or None if it cannot be determined.
"""
# For Gemm, transB=1 means B is [N, K] (K is last axis); default/MatMul is [K, N].
trans_b = bool(matmul_node.attrs.get("transB", 0)) if matmul_node.op == "Gemm" else False
k_axis = -1 if trans_b else -2

inp_b = matmul_node.inputs[1]
if hasattr(inp_b, "values") and inp_b.values is not None:
inp_b_shape = inp_b.values.shape
if len(inp_b_shape) >= 2:
return inp_b_shape[k_axis]
if value_info_map is not None:
inp_b_info = value_info_map.get(inp_b.name)
if inp_b_info:
inp_b_dims = inp_b_info.type.tensor_type.shape.dim
if len(inp_b_dims) >= 2:
return inp_b_dims[k_axis].dim_value
if output_map is not None and inp_b.name in output_map:
inp_b_out = output_map[inp_b.name]
if len(inp_b_out.shape) >= 2:
return inp_b_out.shape[k_axis]
return None
Comment thread
nv-samcheng marked this conversation as resolved.


def _exclude_matmuls_by_shape_inference(
model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
model: onnx.ModelProto,
matmul_nodes: list,
calibration_shapes: str | dict | None = None,
) -> list[str]:
"""Use shape inference to find MatMuls with dimension 1."""
"""Use shape inference to find MatMuls with dimension 1 or small K/N."""
# Prepare model for symbolic inference
for graph_input in model.graph.input:
for dim in graph_input.type.tensor_type.shape.dim:
Expand Down Expand Up @@ -1263,7 +1307,10 @@ def _exclude_matmuls_by_shape_inference(
dim.dim_value = new_dim_value

model = infer_shapes(model)
value_info_map = {vi.name: vi for vi in model.graph.value_info}
# Include graph inputs, value_info, and outputs so B that comes from a graph input
# is visible when deriving K.
value_info_map = {vi.name: vi for vi in model.graph.input}
value_info_map.update({vi.name: vi for vi in model.graph.value_info})
value_info_map.update({vi.name: vi for vi in model.graph.output})

nodes_to_exclude = []
Expand All @@ -1280,8 +1327,23 @@ def _exclude_matmuls_by_shape_inference(

if dims[-1].dim_value == 1 or dims[-2].dim_value == 1:
nodes_to_exclude.append(matmul_node.name)
continue
elif len(dims) < 3 and any(out.dim_value == 1 for out in dims):
nodes_to_exclude.append(matmul_node.name)
continue

# Small-gemm check: applies to both INT8 and FP8 quantization.
n_dim = dims[-1].dim_value if len(dims) >= 2 else 0
k_dim = _get_inp_b_k_dim(matmul_node, value_info_map=value_info_map)
small_n = 0 < n_dim < _MIN_MATMUL_DIM
small_k = k_dim is not None and 0 < k_dim < _MIN_MATMUL_DIM

if small_n or small_k:
logger.debug(
f"Excluding small-dim MatMul from quantization: {matmul_node.name} "
f"(N={n_dim}, K={k_dim}, threshold={_MIN_MATMUL_DIM})"
)
nodes_to_exclude.append(matmul_node.name)

return nodes_to_exclude

Expand All @@ -1295,10 +1357,20 @@ def _exclude_matmuls_by_inference(
calibration_data_reader: CalibrationDataReader,
calibration_eps: list[str],
) -> list[str]:
"""Use actual inference to find MatMuls with dimension 1."""
# Add matmul outputs to model outputs
"""Use actual inference to find MatMuls with dimension 1 or small K/N."""
# Add matmul outputs and second-input outputs to model outputs
existing_output_names = {out.name for out in model.graph.output}
for matmul_node in matmul_nodes:
model.graph.output.extend([onnx.ValueInfoProto(name=matmul_node.outputs[0].name)])
out_name = matmul_node.outputs[0].name
if out_name not in existing_output_names:
model.graph.output.extend([onnx.ValueInfoProto(name=out_name)])
existing_output_names.add(out_name)
# Also add second input for K-dimension check (only if it's a Variable, not a Constant)
if isinstance(matmul_node.inputs[1], Variable):
inp_b_name = matmul_node.inputs[1].name
if inp_b_name not in existing_output_names:
model.graph.output.extend([onnx.ValueInfoProto(name=inp_b_name)])
existing_output_names.add(inp_b_name)

output_map = get_extended_model_outputs(
onnx_path,
Expand All @@ -1319,8 +1391,23 @@ def _exclude_matmuls_by_inference(
or matmul_output.shape[-2] == 1
):
nodes_to_exclude.append(matmul_node.name)
continue
elif len(matmul_output.shape) < 3 and any(out == 1 for out in matmul_output.shape):
nodes_to_exclude.append(matmul_node.name)
continue

# Small-gemm check: applies to both INT8 and FP8 quantization.
n_dim = matmul_output.shape[-1] if len(matmul_output.shape) >= 2 else 0
k_dim = _get_inp_b_k_dim(matmul_node, output_map=output_map)
small_n = 0 < n_dim < _MIN_MATMUL_DIM
small_k = k_dim is not None and 0 < k_dim < _MIN_MATMUL_DIM

if small_n or small_k:
logger.debug(
f"Excluding small-dim MatMul from quantization: {matmul_node.name} "
f"(N={n_dim}, K={k_dim}, threshold={_MIN_MATMUL_DIM})"
)
nodes_to_exclude.append(matmul_node.name)

return nodes_to_exclude

Expand Down
Loading