Add FP8 MHA quantization support for HuggingFace ViT #1289
Conversation
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ.

- fp8_exporter: rewrite attention-scaling Mul and K Transpose to the Q-side so DQ feeds MatMul directly, pre-transpose weight constants, insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent.
- onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16 inserts around Q/DQ by rewriting scale initializers to FP16, so TRT fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native input dtype so no Cast is injected between graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose children are also "*Attention" to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).

Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8 config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries) plus accuracy + TRT-latency comparison against an FP16 baseline.

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):

- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8) — -0.20% / -0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
📝 Walkthrough

The changes introduce an end-to-end FP8 quantization workflow for Vision Transformers with ONNX export and TensorRT benchmarking. ONNX graph optimizations reduce redundant operations (weight transpose folding, cast elimination), FP8 export is simplified to remove unnecessary casting, and infrastructure updates enable LayerNorm quantization and prevent double-patching of nested attention modules.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Script as vit_mha_quantization.py
    participant Model as Torch Model
    participant ONNX as ONNX Graph
    participant TRT as TensorRT
    User->>Script: Invoke with FP8 flags
    Script->>Model: Load ViT + calibration data
    Script->>Model: Apply FP8 quantization<br/>(custom config + LayerNorm)
    Script->>ONNX: Export to ONNX<br/>(FP16 baseline + FP8)
    ONNX->>ONNX: Optimize graph<br/>(fold transposes, moves, casts)
    Script->>ONNX: Inspect QDQ nodes<br/>& MatMul consumption
    alt Optional: ONNX PTQ
        ONNX->>ONNX: Post-training quantization
    end
    alt Optional: Accuracy Eval
        Script->>Model: Evaluate on ImageNet
        Model-->>Script: Top-1/Top-5 metrics
    end
    alt Optional: TRT Benchmark
        ONNX->>TRT: Build engine (FP16/FP8)
        Script->>TRT: Run trtexec<br/>(warmup + iterations)
        TRT-->>Script: GPU latency & throughput
    end
    Script->>User: Print comparison table
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 3 passed
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/torch_onnx/vit_mha_quantization.py`:
- Around line 267-298: The current loop counts any MatMul consuming a
DequantizeLinear output (matmul_with_qdq) which falsely includes projection/MLP
matmuls; replace this with an attention-specific check: implement a helper
(e.g., is_attention_matmul(node, output_to_node, graph)) and only increment
matmul_with_qdq when that returns true. Make is_attention_matmul examine the
MatMul node name and upstream pattern (check parent ops via output_to_node for
Transpose/Reshape, Softmax, or names containing tokens like "q", "k", "v",
"attn", "score", "softmax") or detect the Q@K^T pattern by verifying one input
path comes from a Transpose of a Q-like tensor and the other from K-like tensor;
for attn@V detect the MatMul consuming Softmax output and a V-like source.
Update the loop that currently inspects node.op_type == "MatMul" and uses
inputs_from_dq to call this helper and only count/print when both QDQ and
attention pattern match.
- Around line 225-230: The export is mutating the live model because
model.float() is in-place, which alters base_model/quantized_model used later;
fix by exporting from a detached copy instead (e.g., create a deep copy of model
with copy.deepcopy(model) and call .float() or .to(torch.float16) on that copy)
so get_onnx_bytes_and_metadata receives a non-mutated model; ensure you import
copy and use the copied instance when calling get_onnx_bytes_and_metadata to
avoid changing base_model/quantized_model before accuracy evaluation.
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 100-108: The code currently uses any(c.op == "MatMul" for c in
candidate.outputs[0].outputs) and then rewires/clears all consumers which breaks
non-MatMul branches; change the logic to require all(c.op == "MatMul" for c in
candidate.outputs[0].outputs) before performing the global rewrite OR,
preferably, only rewrite the specific MatMul edges: iterate
candidate.outputs[0].outputs, for each consumer c with c.op == "MatMul" rewire
that consumer's input to use the transposed/scaled/quantized tensor and leave
other consumers untouched, and do not clear original outputs (update
transpose_to_remove only when all downstream edges have been safely redirected).
Apply the same fix pattern to the other rewrite sites that manipulate
torch_weights, perm, transpose_to_remove, and similar MatMul-aware transforms.
In `@modelopt/onnx/utils.py`:
- Around line 1422-1505: The fold helpers unconditionally convert Q/DQ scale
initializers to FLOAT16 which is invalid for opsets < BASE_MIN_OPSET; update
_scale_fp32_to_fp16, fold_dq_fp32_to_fp16_casts and fold_q_fp16_to_fp32_casts to
guard the mutation by checking get_opset_version(onnx_model) (or the model
passed in) and only perform the FP32→FP16 rewrite when
get_opset_version(onnx_model) >= BASE_MIN_OPSET; if the check fails, skip
mutating initializers and skip folding the cast nodes (i.e., return the model
unchanged or continue without calling _scale_fp32_to_fp16/_bypass_cast_node),
using the existing function names (_scale_fp32_to_fp16,
fold_dq_fp32_to_fp16_casts, fold_q_fp16_to_fp32_casts) and constants
(BASE_MIN_OPSET) to locate where to add the guard.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 005378aa-8fac-4f2d-98a1-55297415cbe3
📒 Files selected for processing (8)

- examples/torch_onnx/vit_mha_quantization.py
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/export_onnx.py
- modelopt/torch/quantization/nn/__init__.py
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
- modelopt/torch/quantization/plugins/huggingface.py
```python
onnx_bytes, _ = get_onnx_bytes_and_metadata(
    model=model.float(),
    dummy_input=dummy_input,
    weights_dtype=weights_dtype,
    model_name=model_name,
)
```
Export on a copy instead of mutating the live model.
nn.Module.float() is in-place, so this changes the same base_model / quantized_model instances that you later reuse for accuracy reporting. That makes the script’s “FP16 vs FP8” comparison misleading and can perturb the calibrated quantized model before evaluation.
💡 Suggested fix

```diff
-    onnx_bytes, _ = get_onnx_bytes_and_metadata(
-        model=model.float(),
+    export_model = copy.deepcopy(model).float()
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(
+        model=export_model,
         dummy_input=dummy_input,
         weights_dtype=weights_dtype,
         model_name=model_name,
     )
```
```python
for node in graph.node:
    if node.op_type == "QuantizeLinear":
        qdq_count += 1
    elif node.op_type == "DequantizeLinear":
        dq_count += 1
    elif node.op_type == "MatMul":
        matmul_count += 1
        # Check if inputs come from DequantizeLinear
        inputs_from_dq = []
        for inp in node.input:
            if inp in output_to_node:
                parent = output_to_node[inp]
                if parent.op_type == "DequantizeLinear":
                    inputs_from_dq.append(inp)
        if len(inputs_from_dq) > 0:
            matmul_with_qdq += 1
            print(f"  MatMul '{node.name}' has {len(inputs_from_dq)} QDQ input(s)")

print("\n  Summary:")
print(f"  QuantizeLinear nodes: {qdq_count}")
print(f"  DequantizeLinear nodes: {dq_count}")
print(f"  Total MatMul nodes: {matmul_count}")
print(f"  MatMul with QDQ inputs: {matmul_with_qdq}")

# Check attention-specific MatMuls (Q@K^T and attn@V)
# In ViT-base with 12 layers, we expect 24 attention MatMuls (2 per layer)
# plus linear projection MatMuls
print("\n  For ViT attention:")
print("  Expected attention MatMuls with QDQ: ~24 (12 layers x 2 BMMs)")
print(f"  Found MatMuls with QDQ inputs: {matmul_with_qdq}")

return qdq_count, dq_count, matmul_count, matmul_with_qdq
```
This QDQ coverage check is not attention-specific.
matmul_with_qdq increments for any MatMul that consumes a DequantizeLinear output, including plain projection / MLP matmuls after weight compression. That means the “~24 attention MatMuls” check can still look healthy even when Q/K/V or softmax QDQ placement is wrong. Please key this off the actual attention pattern instead of all MatMul + DQ pairs.
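One way to make the count attention-specific, as suggested above, is sketched below (hedged: `is_attention_matmul`, the `_ATTN_TOKENS` list, and the synthetic nodes are illustrative assumptions, not the PR's code; the check relies only on ONNX-style `op_type`/`name`/`input` attributes):

```python
from types import SimpleNamespace

_ATTN_TOKENS = ("softmax", "attn", "score", "/q", "/k", "/v")

def is_attention_matmul(node, output_to_node):
    """Heuristic: treat a MatMul as attention-related if its name or an
    upstream producer suggests Q@K^T or attn@V (illustrative sketch)."""
    if node.op_type != "MatMul":
        return False
    if any(tok in node.name.lower() for tok in _ATTN_TOKENS):
        return True
    # Walk one producer hop: Softmax feeding the MatMul marks attn@V,
    # a Transpose/Reshape with an attention-ish name marks the Q@K^T branch.
    for inp in node.input:
        parent = output_to_node.get(inp)
        if parent is None:
            continue
        if parent.op_type == "Softmax":
            return True
        if parent.op_type in ("Transpose", "Reshape") and any(
            tok in parent.name.lower() for tok in _ATTN_TOKENS
        ):
            return True
    return False

# Tiny synthetic graph: Softmax -> MatMul (attn@V) vs a plain MLP MatMul.
softmax = SimpleNamespace(
    op_type="Softmax", name="/layer0/attn/Softmax", input=["scores"], output=["probs"]
)
attn_mm = SimpleNamespace(op_type="MatMul", name="/layer0/attn/MatMul_av", input=["probs", "v"])
mlp_mm = SimpleNamespace(op_type="MatMul", name="/layer0/mlp/MatMul", input=["x", "w"])
out2node = {softmax.output[0]: softmax}
```

The reporting loop would then increment `matmul_with_qdq` only when both `inputs_from_dq` is non-empty and this helper returns true.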
```python
if any(c.op == "MatMul" for c in candidate.outputs[0].outputs):
    perm = candidate.attrs.get("perm", None)
    torch_weights = (
        torch_weights.permute(*perm).contiguous()
        if perm is not None
        else torch_weights.T.contiguous()
    )
    transpose_to_remove = candidate
    break
```
Restrict these rewrites to pure MatMul fanout.
Each transform fires when it sees any MatMul consumer, then rewires all consumers and clears the original outputs. If one of these tensors also feeds a non-MatMul branch, that branch will now observe the untransposed / unscaled / quantized value instead of the original one. Please either require all(...) consumers to match the intended pattern or only rewrite the specific MatMul edges.
Also applies to: 128-137, 196-220, 255-275, 334-339
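The per-edge alternative can be illustrated with minimal graphsurgeon-style stand-ins (the `Tensor`/`Node` classes below are a sketch, not the onnx-graphsurgeon API; only MatMul consumers are repointed, so other branches keep the original tensor):

```python
class Tensor:
    def __init__(self, name):
        self.name = name
        self.outputs = []  # consumer nodes, graphsurgeon-style

class Node:
    def __init__(self, op, inputs):
        self.op = op
        self.inputs = list(inputs)
        for t in inputs:
            t.outputs.append(self)

def rewire_matmul_edges_only(original, transformed):
    """Point only MatMul consumers of `original` at `transformed`,
    leaving every other consumer on the untouched tensor (sketch)."""
    for consumer in list(original.outputs):
        if consumer.op != "MatMul":
            continue
        consumer.inputs = [transformed if t is original else t for t in consumer.inputs]
        original.outputs.remove(consumer)
        transformed.outputs.append(consumer)
    # The old Transpose is safe to drop only when nothing still reads it.
    return len(original.outputs) == 0

w = Tensor("weight_T")        # output of the original Transpose
w_pre = Tensor("weight_pre")  # pre-transposed constant
mm = Node("MatMul", [w])
add = Node("Add", [w])        # non-MatMul branch that must keep `w`
removable = rewire_matmul_edges_only(w, w_pre)
```

Here `removable` comes back false because the `Add` branch still consumes the original tensor, which is exactly the case the global rewrite would have broken.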
```python
def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None:
    """Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16."""
    import numpy as np

    if scale_init.data_type != onnx.TensorProto.FLOAT:
        return
    scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
    if not scale_data.size:
        scale_data = np.array(scale_init.float_data, dtype=np.float32)
    scale_init.data_type = onnx.TensorProto.FLOAT16
    scale_init.raw_data = scale_data.astype(np.float16).tobytes()
    del scale_init.float_data[:]


def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Remove ``DQ → Cast(FP32→FP16)`` patterns inserted by ``convert_float_to_float16``.

    The DQ scale is rewritten to FP16 so DQ natively produces FP16 output, enabling
    TRT to fuse DQ directly into the downstream compute op.
    """
    producer_map = {out: node for node in onnx_model.graph.node for out in node.output}
    initializers = {init.name: init for init in onnx_model.graph.initializer}

    to_remove = []
    for node in onnx_model.graph.node:
        if node.op_type != "Cast":
            continue
        cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
        if cast_to != onnx.TensorProto.FLOAT16:
            continue
        producer = producer_map.get(node.input[0])
        if producer is None or producer.op_type not in _DQ_OPS:
            continue

        if len(producer.input) >= 2 and producer.input[1] in initializers:
            _scale_fp32_to_fp16(initializers[producer.input[1]])

        _bypass_cast_node(onnx_model, node)
        to_remove.append(node)

        for vi in onnx_model.graph.value_info:
            if vi.name == producer.output[0]:
                vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
                break

    logger.debug(f"Folded {len(to_remove)} DQ -> Cast(FP32->FP16) patterns")
    for node in to_remove:
        onnx_model.graph.node.remove(node)
    return onnx_model


def fold_q_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Remove ``Cast(FP16→FP32) → Q`` patterns inserted by ``convert_float_to_float16``.

    The Q scale is rewritten to FP16 so Q consumes the FP16 graph directly.
    """
    consumer_map: dict[str, list[onnx.NodeProto]] = {}
    for node in onnx_model.graph.node:
        for inp in node.input:
            consumer_map.setdefault(inp, []).append(node)
    initializers = {init.name: init for init in onnx_model.graph.initializer}

    to_remove = []
    for node in onnx_model.graph.node:
        if node.op_type != "Cast":
            continue
        cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
        if cast_to != onnx.TensorProto.FLOAT:
            continue
        consumers = consumer_map.get(node.output[0], [])
        if not consumers or not all(c.op_type in _Q_OPS for c in consumers):
            continue

        for q_node in consumers:
            if len(q_node.input) >= 2 and q_node.input[1] in initializers:
                _scale_fp32_to_fp16(initializers[q_node.input[1]])

        _bypass_cast_node(onnx_model, node)
        to_remove.append(node)

    logger.debug(f"Folded {len(to_remove)} Cast(FP16->FP32) -> Q patterns")
    for node in to_remove:
        onnx_model.graph.node.remove(node)
    return onnx_model
```
🧩 Analysis chain

Web query result: the minimum ONNX opset version that allows FLOAT16 scale tensors for both QuantizeLinear and DequantizeLinear is 19.

Citations:
- 1: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html
- 2: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html
- 3: https://github.com/onnx/onnx/releases/tag/v1.16.0

Repository scripts confirmed that `fold_dq_fp32_to_fp16_casts` and `fold_q_fp16_to_fp32_casts` are invoked from `modelopt/torch/_deploy/utils/torch_onnx.py` for FP8-quantized models, and that no `BASE_MIN_OPSET` / `get_opset_version` guard exists on that export path.
Add opset guard to FP16 scale rewrite helpers.
These helpers rewrite Q/DQ scale initializers to FLOAT16 unconditionally, but FLOAT16 scales are only valid from opset 19 onward. When FP8 quantization is enabled, a user can request a lower opset (e.g., onnx_opset=18) via the export API, which will cause these functions to silently produce an invalid model. Check get_opset_version(onnx_model) >= BASE_MIN_OPSET before mutating initializers, or skip the fold operation entirely for lower opsets.
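A minimal form of the guard might look like this (sketch: `BASE_MIN_OPSET` and `get_opset_version` exist in the codebase per the review, but this standalone version reads plain `(domain, version)` pairs instead of a real `ModelProto`):

```python
BASE_MIN_OPSET = 19  # FP16 Q/DQ scale tensors are only legal from opset 19 (onnx >= 1.16)

def get_opset_version(opset_imports):
    """Return the default-domain opset. `opset_imports` stands in for
    `onnx_model.opset_import` as (domain, version) pairs in this sketch."""
    for domain, version in opset_imports:
        if domain in ("", "ai.onnx"):
            return version
    return 0

def can_fold_fp16_scales(opset_imports):
    """Guard to add at the top of fold_dq_fp32_to_fp16_casts /
    fold_q_fp16_to_fp32_casts: skip the rewrite entirely below opset 19."""
    return get_opset_version(opset_imports) >= BASE_MIN_OPSET
```

When the guard fails, the fold functions would return the model unchanged rather than emitting FLOAT16 scales that the opset does not permit.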
```python
parent_attention_types = {
    _QuantAttention.get_attn_type(module)
    for _, module in model.named_modules()
    if type(module).__name__.endswith("Attention")
    and any(
        child is not module and type(child).__name__.endswith("Attention")
        for _, child in module.named_modules()
    )
}
```
Skip nested attention wrappers by instance, not by class.
Once any instance of an attention class lands in parent_attention_types, every module of that class is skipped here. That will miss quantization for models that reuse the same attention class both as a wrapper and as a leaf. Checking module.named_modules() in-loop avoids dropping unrelated instances.
Also applies to: 306-307
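The instance-level check can be sketched without torch (the `Module` stand-in below is an assumption that mimics `named_modules()`; note how a class used both as a wrapper and as a leaf is skipped only where it actually wraps another attention):

```python
class Module:
    """Minimal stand-in for nn.Module exposing named_modules() (sketch)."""
    def __init__(self, **children):
        self._children = children

    def named_modules(self, prefix=""):
        yield prefix, self
        for name, child in self._children.items():
            yield from child.named_modules(f"{prefix}.{name}" if prefix else name)

def is_wrapper_attention(module):
    """Per-instance check: true only when THIS module wraps another
    *Attention, so only its inner attention should be patched."""
    return type(module).__name__.endswith("Attention") and any(
        child is not module and type(child).__name__.endswith("Attention")
        for _, child in module.named_modules()
    )

class ViTSelfAttention(Module):
    pass

class ViTAttention(Module):
    pass

leaf = ViTSelfAttention()
wrapper = ViTAttention(attention=leaf)
standalone = ViTAttention()  # same class reused as a leaf elsewhere
```

Unlike the class-set approach, `standalone` is not skipped here even though `wrapper` shares its class, which is exactly the case the comment above worries about.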
cjluo-nv left a comment
This PR adds FP8 MHA quantization support for HuggingFace ViT models with ONNX export optimizations. The implementation is well-structured and addresses a real gap (NVBug 6078291). However, there are several issues to address:
Critical issues:

1. **No unit tests** — this is ~933 lines of new/changed library code across core export paths (`fp8_exporter.py`, `utils.py`, `export_onnx.py`, `huggingface.py`) with zero unit tests. The only "test" is the example script, which requires GPU, ImageNet data, and TRT. The graph rewrite functions in `fp8_exporter.py`, the cast folding helpers in `utils.py`, the attention skipping logic in `huggingface.py`, and the LayerNorm quantization registration all need unit tests.
2. **Bare `assert` for runtime validation in `fp8_exporter.py`** — the existing `assert` on QDQ pair validation will be stripped with `-O`.
3. **Silent `contextlib.suppress(Exception)` in the example** — can mask real failures during benchmark parsing.
Minor issues:

4. The `_scale_fp32_to_fp16` helper doesn't handle the case where the scale value overflows or underflows to inf/0 in FP16 — this could silently produce bad quantization results for extreme scales.
5. The `_move_mul_before_qdq` rewrite assumes a single scalar const Mul for attention scaling; if the model architecture changes, these pattern-matching rewrites could silently become no-ops without any warning.
6. The `_insert_qdq_after_softmax` hardcodes scale=1/448, which is correct for E4M3 but should at minimum document why this specific value and that it's tied to the FP8 E4M3 max representable value.
Positive aspects:
- Clean separation of graph rewrites as static methods
- Good docstrings on the new functions
- The parent_attention_types detection for avoiding double-patching is well done
- The LayerNorm registration follows existing patterns exactly
```diff
@@ -78,6 +80,33 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
             f"QDQ does not occur in pairs. You reached {dq_op.op}"
```

Bug (existing but expanded): This `assert` will be stripped with the `-O` flag. Per codebase conventions, use `raise ValueError(...)` or `RuntimeError(...)` instead for runtime validation:

```python
if dq_op.op != "TRT_FP8DequantizeLinear":
    raise RuntimeError(f"QDQ does not occur in pairs. You reached {dq_op.op}")
```

```python
if candidate.op != "Transpose":
    cast_to_remove = None
    continue
if any(c.op == "MatMul" for c in candidate.outputs[0].outputs):
```
Potential issue: When `dq_op.outputs[0]` has multiple consumers and one is a Cast followed by a Transpose, the loop breaks after the first candidate. If the first consumer is a Cast that doesn't lead to a Transpose, `cast_to_remove` is set to None and `continue` re-enters the loop at the next candidate — but the `break` at line 109 exits regardless. This means if the first candidate is a non-Cast/non-Transpose node, we skip checking the remaining consumers. Consider restructuring to check all candidates:

```python
for candidate in list(dq_op.outputs[0].outputs):
    if candidate.op == "Cast":
        ...
    if candidate is None:
        cast_to_remove = None
        continue  # <-- this continues to next candidate, good
    if candidate.op != "Transpose":
        cast_to_remove = None
        continue  # <-- this also continues, good
    ...
    break  # <-- but this break exits even if we didn't find a match
```

Actually, looking more carefully, the logic resets `cast_to_remove = None` when it's not a Transpose, then continues — so it does check the next candidate. But if `candidate.op` is a Transpose without a MatMul consumer, you break without resetting state. The logic works but is fragile. A comment would help.
```python
gs.Node(
    op="Transpose",
    name=transpose_node.name + "_moved",
    inputs=[q_input],
```
The hardcoded 1.0/448.0 is correct (FP8 E4M3 max = 448, softmax output ∈ [0,1]), and the docstring explains this well. However, consider defining this as a named constant (e.g., _FP8_E4M3_SOFTMAX_SCALE = 1.0 / 448.0) since the same value is also used in export_onnx.py (export_fp8_mha uses 1.0 for softmax amax which becomes 448.0/1.0 then 1.0/448.0). This would make the connection explicit.
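The arithmetic behind that constant, as a small sketch (the fake-quant helper is illustrative and ignores E4M3's non-uniform grid):

```python
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
SOFTMAX_AMAX = 1.0    # softmax outputs lie in [0, 1]

# Q scale maps the tensor's dynamic range onto the FP8 grid; DQ multiplies it back.
scale = SOFTMAX_AMAX / FP8_E4M3_MAX  # = 1/448, the hardcoded value

def fake_quant(x, scale, fp8_max=FP8_E4M3_MAX):
    """Illustrative Q -> DQ round trip for a scalar (uniform-grid approximation)."""
    q = max(-fp8_max, min(fp8_max, round(x / scale)))
    return q * scale
```

With this scale, the softmax peak value 1.0 maps to the E4M3 maximum and survives the round trip essentially unchanged.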
```python
    return
scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
if not scale_data.size:
    scale_data = np.array(scale_init.float_data, dtype=np.float32)
```
Potential precision issue: `_scale_fp32_to_fp16` does not check whether the FP32 scale value is outside the FP16 representable range. If a scale is very small (< ~5.96e-8, the smallest FP16 subnormal) or very large (> 65504), the conversion will silently produce 0 or inf. Consider adding a warning:

```python
fp16_val = scale_data.astype(np.float16)
overflow = np.any(np.isinf(fp16_val))
underflow = np.any(fp16_val == 0) and np.any(scale_data != 0)
if overflow or underflow:
    logger.warning("Q/DQ scale overflows/underflows when converting to FP16")
```

```python
def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None:
    """Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16."""
    import numpy as np
```
Minor: np is already imported at the top of the file. The local import numpy as np is unnecessary.
```python
)
scale = g.op(
    "Constant",
    value_t=torch.tensor(scale_inv, dtype=torch_dtype_map[otype]),  # type: ignore[index]
```
Same as above: trt_high_precision_dtype is now unused in _fp8_dequantize — the DQ output type is set to otype directly. The parameter should either be removed or have a deprecation comment explaining it's kept for API compatibility.
```python
parent_attention_types = {
    _QuantAttention.get_attn_type(module)
    for _, module in model.named_modules()
    if type(module).__name__.endswith("Attention")
```
Good approach for detecting parent attention wrappers! One edge case: named_modules() includes the module itself, so child is not module is the correct check (which you have). However, this only checks one level of nesting. If there's a 3-level hierarchy like OuterAttention > MiddleAttention > InnerAttention, both OuterAttention and MiddleAttention would be in parent_attention_types, and only InnerAttention would be registered. Is that the intended behavior? For ViT it's fine (only 2 levels), but worth a comment.
```python
from .quant_module import QuantInputBase, QuantModuleRegistry

QuantModuleRegistry.register({nn.LayerNorm: "nn.LayerNorm"})(QuantInputBase)
```
This follows the exact pattern of quant_batchnorm.py — good. However, unlike BatchNorm which has 1d/2d/3d variants, LayerNorm only has one class, so this is clean. Consider adding a brief module-level docstring noting that this enables LayerNorm output quantizers to be honored during quantization (the file docstring says "Quantized layer normalization module" which is a bit terse).
```diff
@@ -0,0 +1,538 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
Missing unit tests: This 538-line example is the only "test" for the entire PR. The core library changes (graph rewrites in `fp8_exporter.py`, cast folding in `utils.py`, attention skipping in `huggingface.py`, LayerNorm quantization) need proper unit tests that:

- Test `_move_mul_before_qdq` with a synthetic ONNX graph
- Test `_move_transpose_before_qdq` with a synthetic ONNX graph
- Test `_insert_qdq_after_softmax` with a synthetic ONNX graph
- Test `fold_dq_fp32_to_fp16_casts` and `fold_q_fp16_to_fp32_casts`
- Test `parent_attention_types` detection with a mock ViT-like model
- Test that `nn.LayerNorm` is properly registered in `QuantModuleRegistry`

These can all be CPU-only tests with small synthetic models/graphs.
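For instance, the buffer rewrite inside `_scale_fp32_to_fp16` can be isolated and exercised with numpy alone (sketch; the helper name below is hypothetical, not the PR's function, and assumes numpy is available):

```python
import numpy as np

def convert_scale_bytes_fp32_to_fp16(raw: bytes) -> bytes:
    """Mirror of the raw_data rewrite in _scale_fp32_to_fp16, isolated for testing."""
    return np.frombuffer(raw, dtype=np.float32).astype(np.float16).tobytes()

# 2**-7 is exactly representable in FP16, so the round trip is lossless here.
scale = np.array([0.0078125], dtype=np.float32)
out = np.frombuffer(convert_scale_bytes_fp32_to_fp16(scale.tobytes()), dtype=np.float16)
```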
```python
# ──────────────────────────────────────────────────────────────────────────────
# Step 4: TRT latency benchmarking
```
`contextlib.suppress(Exception)` is used in multiple places in the benchmark parsing. Per codebase conventions, prefer explicit handling — suppressing all exceptions can mask real bugs. At minimum, suppress only `(IndexError, ValueError)`, which is what you're already doing inline:

```python
with contextlib.suppress(IndexError, ValueError):
```
Codecov Report

❌ Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1289      +/-   ##
==========================================
- Coverage   75.61%   71.05%   -4.57%
==========================================
  Files         459      460       +1
  Lines       48597    49312     +715
==========================================
- Hits        36747    35039    -1708
- Misses      11850    14273    +2423
```

Flags with carried forward coverage won't be shown.
Can you update the changelog for 0.45 to mention the details of this new feature?
Summary
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ.
- `modelopt/onnx/export/fp8_exporter.py` — new post-processing passes: move the attention-scaling `Mul` and K `Transpose` to the Q-side so DQ feeds MatMul directly, pre-transpose constant weights, and insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Softmax Q/DQ scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent.
- `modelopt/onnx/utils.py` — `fold_dq_fp32_to_fp16_casts` / `fold_q_fp16_to_fp32_casts` helpers that remove the Cast nodes `convert_float_to_float16` inserts around Q/DQ (converts scale initializers to FP16) so TRT fuses DQ into the downstream GEMM.
- `modelopt/torch/_deploy/utils/torch_onnx.py` — call the fold helpers for FP8-quantized models after `convert_float_to_float16`.
- `modelopt/torch/quantization/export_onnx.py` — keep FP8 Q/DQ scale in the native input dtype so no Cast is emitted between graph and Q/DQ.
- `modelopt/torch/quantization/nn/modules/quant_layernorm.py` (new) — register `nn.LayerNorm` in `QuantModuleRegistry` so LayerNorm output quantizers are honored.
- `modelopt/torch/quantization/plugins/huggingface.py` — skip `*Attention` wrappers whose children are also `*Attention` to avoid double-patching `eager_attention_forward` (e.g. `ViTAttention` vs `ViTSelfAttention`).
- `examples/torch_onnx/vit_mha_quantization.py` (new) — end-to-end example: ViT-FP8 config (extends `FP8_DEFAULT_CFG` with LayerNorm output quantizer, disabled redundant input quantizers, and `*_bmm_quantizer` entries), ONNX export, QDQ-coverage check, and accuracy + TRT-latency comparison vs FP16 baseline.

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):

- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8)
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8), a 1.12x speedup
Test plan
- `python -m pytest tests/unit/torch/quantization tests/unit/onnx`
- `python examples/torch_onnx/vit_mha_quantization.py --skip_eval --skip_latency` (verifies ONNX export with 96 MatMul-with-QDQ inputs in ViT-base)
- `trtexec --onnx=vit_torch_fp8.onnx --fp8 --stronglyTyped`