
Add FP8 MHA quantization support for HuggingFace ViT #1289

Open
ajrasane wants to merge 1 commit into main from ajrasane/mha_quantization

Conversation

@ajrasane
Contributor

@ajrasane ajrasane commented Apr 17, 2026

Summary

Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ.

  • modelopt/onnx/export/fp8_exporter.py — new post-processing passes: move attention-scaling Mul and K Transpose to the Q-side so DQ feeds MatMul directly, pre-transpose constant weights, and insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Softmax Q/DQ scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent.
  • modelopt/onnx/utils.py — new fold_dq_fp32_to_fp16_casts / fold_q_fp16_to_fp32_casts helpers that remove the Cast nodes convert_float_to_float16 inserts around Q/DQ (converting the scale initializers to FP16) so TRT fuses DQ into the downstream GEMM.
  • modelopt/torch/_deploy/utils/torch_onnx.py — call the fold helpers for FP8-quantized models after convert_float_to_float16.
  • modelopt/torch/quantization/export_onnx.py — keep FP8 Q/DQ scale in the native input dtype so no Cast is emitted between graph and Q/DQ.
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py (new) — register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
  • modelopt/torch/quantization/plugins/huggingface.py — skip *Attention wrappers whose children are also *Attention to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).
  • examples/torch_onnx/vit_mha_quantization.py (new) — end-to-end example: ViT-FP8 config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer, disabled redundant input quantizers, and *_bmm_quantizer entries), ONNX export, QDQ-coverage check, and accuracy + TRT-latency comparison vs FP16 baseline.
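The example's custom config is described only in prose above; a minimal sketch of how such a config could be assembled follows. The FP8_DEFAULT_CFG dict and the wildcard keys here are illustrative stand-ins (the real config lives in the example script and in modelopt.torch.quantization), not the PR's actual values.

```python
from copy import deepcopy

# Stand-in for mtq.FP8_DEFAULT_CFG; the real dict in modelopt may differ.
FP8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},  # FP8 E4M3
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        "default": {"enable": False},
    },
    "algorithm": "max",
}

# Extend the base config along the lines the PR description outlines: add a
# LayerNorm output quantizer and *_bmm_quantizer entries, and disable
# redundant input quantizers. The wildcard patterns below are hypothetical.
VIT_FP8_CFG = deepcopy(FP8_DEFAULT_CFG)
VIT_FP8_CFG["quant_cfg"].update(
    {
        "*layernorm*output_quantizer": {"num_bits": (4, 3), "axis": None},
        "*bmm_quantizer": {"num_bits": (4, 3), "axis": None},
        "*layernorm*input_quantizer": {"enable": False},
    }
)
```

deepcopy keeps the base config untouched, so the same process can derive several model-specific configs from one default.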

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):

Metric                    FP16 baseline   Torch FP8   Δ
Top-1 (5k ImageNet val)   81.16%          80.96%      −0.20%
Top-5                     95.50%          95.44%      −0.06%
TRT GPU compute           0.721 ms        0.646 ms    1.12× speedup

Test plan

  • CPU unit tests: python -m pytest tests/unit/torch/quantization tests/unit/onnx
  • Run the example: python examples/torch_onnx/vit_mha_quantization.py --skip_eval --skip_latency (verifies ONNX export with 96 MatMul-with-QDQ inputs in ViT-base)
  • trtexec FP8 strongly-typed build succeeds: trtexec --onnx=vit_torch_fp8.onnx --fp8 --stronglyTyped
  • Accuracy within ~0.3% of FP16 baseline on ImageNet-1k validation subset

Summary by CodeRabbit

  • New Features

    • Added example demonstrating FP8 quantization of Vision Transformer with ONNX export and TensorRT benchmarking capabilities.
    • Enabled LayerNorm quantization support.
  • Improvements

    • Enhanced FP8 ONNX export with graph optimizations for better performance.
    • Simplified FP8 quantization logic by removing unnecessary casting overhead.
    • Improved HuggingFace attention module handling to prevent redundant patching.

Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar
transformer vision models) when exported to ONNX with FP8 Q/DQ.

- fp8_exporter: rewrite attention-scaling Mul and K Transpose to the
  Q-side so DQ feeds MatMul directly, pre-transpose weight constants,
  insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Scale dtype
  now matches the graph's float dtype to keep strongly-typed builds
  consistent.
- onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16
  inserts around Q/DQ by rewriting scale initializers to FP16, so TRT
  fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native
  input dtype so no Cast is injected between graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry
  so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose
  children are also "*Attention" to avoid double-patching
  eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).

Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8
config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer,
disabled input quantizers on LayerNorm-followed layers, and
*_bmm_quantizer entries) plus accuracy + TRT-latency comparison
against an FP16 baseline.

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):
- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs
  80.96% / 95.44% (torch FP8) — -0.20% / -0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
@ajrasane ajrasane requested review from a team as code owners April 17, 2026 20:14
@coderabbitai
Contributor

coderabbitai bot commented Apr 17, 2026

📝 Walkthrough


The changes introduce an end-to-end FP8 quantization workflow for Vision Transformers with ONNX export and TensorRT benchmarking. ONNX graph optimizations reduce redundant operations (weight transpose folding, cast elimination), FP8 export is simplified to remove unnecessary casting, and infrastructure updates enable LayerNorm quantization and prevent double-patching of nested attention modules.

Changes

Cohort / File(s) Summary
New Example Script
examples/torch_onnx/vit_mha_quantization.py
Complete executable script for ViT FP8 quantization featuring customized FP8 config, ONNX export with QDQ node inspection, optional ONNX PTQ, ImageNet accuracy evaluation, and TensorRT latency benchmarking via trtexec.
ONNX Graph Optimization
modelopt/onnx/export/fp8_exporter.py, modelopt/onnx/utils.py, modelopt/torch/_deploy/utils/torch_onnx.py
FP8-specific graph rewrites: weight transpose folding, scalar-constant multiplication movement before quantization, softmax quantization insertion, and cast folding around Q/DQ patterns. New utilities detect and convert FP32/FP16 scales and bypass redundant casts.
FP8 Export Simplification
modelopt/torch/quantization/export_onnx.py
Removed conditional Cast node insertion from FP8 quantize/dequantize helpers; scale constants and outputs now use native dtypes without intermediate casting.
LayerNorm Quantization Support
modelopt/torch/quantization/nn/modules/quant_layernorm.py, modelopt/torch/quantization/nn/__init__.py
New module registers torch.nn.LayerNorm with QuantInputBase via QuantModuleRegistry, enabling quantization support for layer normalization. Exposed via package-level wildcard import.
Attention Module Patching
modelopt/torch/quantization/plugins/huggingface.py
Added detection of parent-child attention module relationships to skip double-wrapping; nested attention modules within parent attention modules are excluded from quantization patching.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Script as vit_mha_quantization.py
    participant Model as Torch Model
    participant ONNX as ONNX Graph
    participant TRT as TensorRT

    User->>Script: Invoke with FP8 flags
    Script->>Model: Load ViT + calibration data
    Script->>Model: Apply FP8 quantization<br/>(custom config + LayerNorm)
    Script->>ONNX: Export to ONNX<br/>(FP16 baseline + FP8)
    ONNX->>ONNX: Optimize graph<br/>(fold transposes, moves, casts)
    Script->>ONNX: Inspect QDQ nodes<br/>& MatMul consumption
    alt Optional: ONNX PTQ
        ONNX->>ONNX: Post-training quantization
    end
    alt Optional: Accuracy Eval
        Script->>Model: Evaluate on ImageNet
        Model-->>Script: Top-1/Top-5 metrics
    end
    alt Optional: TRT Benchmark
        ONNX->>TRT: Build engine (FP16/FP8)
        Script->>TRT: Run trtexec<br/>(warmup + iterations)
        TRT-->>Script: GPU latency & throughput
    end
    Script->>User: Print comparison table

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Docstring Coverage — ⚠️ Warning. Docstring coverage is 79.41%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (3 passed)
Description Check — ✅ Passed. Check skipped: CodeRabbit’s high-level summary is enabled.
Title Check — ✅ Passed. The title accurately and concisely summarizes the main objective of the PR: adding FP8 quantization support for multi-head attention (MHA) in HuggingFace ViT models.
Security Anti-Patterns — ✅ Passed. The PR does not introduce any security anti-patterns defined in SECURITY.md. Subprocess calls use list-based arguments without shell=True, model loading correctly defaults trust_remote_code to False, and no dangerous functions are called on untrusted input.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Contributor

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1289/

Built to branch gh-pages at 2026-04-17 20:19 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/torch_onnx/vit_mha_quantization.py`:
- Around line 267-298: The current loop counts any MatMul consuming a
DequantizeLinear output (matmul_with_qdq) which falsely includes projection/MLP
matmuls; replace this with an attention-specific check: implement a helper
(e.g., is_attention_matmul(node, output_to_node, graph)) and only increment
matmul_with_qdq when that returns true. Make is_attention_matmul examine the
MatMul node name and upstream pattern (check parent ops via output_to_node for
Transpose/Reshape, Softmax, or names containing tokens like "q", "k", "v",
"attn", "score", "softmax") or detect the Q@K^T pattern by verifying one input
path comes from a Transpose of a Q-like tensor and the other from K-like tensor;
for attn@V detect the MatMul consuming Softmax output and a V-like source.
Update the loop that currently inspects node.op_type == "MatMul" and uses
inputs_from_dq to call this helper and only count/print when both QDQ and
attention pattern match.
- Around line 225-230: The export is mutating the live model because
model.float() is in-place, which alters base_model/quantized_model used later;
fix by exporting from a detached copy instead (e.g., create a deep copy of model
with copy.deepcopy(model) and call .float() or .to(torch.float16) on that copy)
so get_onnx_bytes_and_metadata receives a non-mutated model; ensure you import
copy and use the copied instance when calling get_onnx_bytes_and_metadata to
avoid changing base_model/quantized_model before accuracy evaluation.

In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 100-108: The code currently uses any(c.op == "MatMul" for c in
candidate.outputs[0].outputs) and then rewires/clears all consumers which breaks
non-MatMul branches; change the logic to require all(c.op == "MatMul" for c in
candidate.outputs[0].outputs) before performing the global rewrite OR,
preferably, only rewrite the specific MatMul edges: iterate
candidate.outputs[0].outputs, for each consumer c with c.op == "MatMul" rewire
that consumer's input to use the transposed/scaled/quantized tensor and leave
other consumers untouched, and do not clear original outputs (update
transpose_to_remove only when all downstream edges have been safely redirected).
Apply the same fix pattern to the other rewrite sites that manipulate
torch_weights, perm, transpose_to_remove, and similar MatMul-aware transforms.

In `@modelopt/onnx/utils.py`:
- Around line 1422-1505: The fold helpers unconditionally convert Q/DQ scale
initializers to FLOAT16 which is invalid for opsets < BASE_MIN_OPSET; update
_scale_fp32_to_fp16, fold_dq_fp32_to_fp16_casts and fold_q_fp16_to_fp32_casts to
guard the mutation by checking get_opset_version(onnx_model) (or the model
passed in) and only perform the FP32→FP16 rewrite when
get_opset_version(onnx_model) >= BASE_MIN_OPSET; if the check fails, skip
mutating initializers and skip folding the cast nodes (i.e., return the model
unchanged or continue without calling _scale_fp32_to_fp16/_bypass_cast_node),
using the existing function names (_scale_fp32_to_fp16,
fold_dq_fp32_to_fp16_casts, fold_q_fp16_to_fp32_casts) and constants
(BASE_MIN_OPSET) to locate where to add the guard.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 005378aa-8fac-4f2d-98a1-55297415cbe3

📥 Commits

Reviewing files that changed from the base of the PR and between e4b054b and d6533ac.

📒 Files selected for processing (8)
  • examples/torch_onnx/vit_mha_quantization.py
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/export_onnx.py
  • modelopt/torch/quantization/nn/__init__.py
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
  • modelopt/torch/quantization/plugins/huggingface.py

Comment on lines +225 to +230
onnx_bytes, _ = get_onnx_bytes_and_metadata(
    model=model.float(),
    dummy_input=dummy_input,
    weights_dtype=weights_dtype,
    model_name=model_name,
)
Contributor


⚠️ Potential issue | 🟠 Major

Export on a copy instead of mutating the live model.

nn.Module.float() is in-place, so this changes the same base_model / quantized_model instances that you later reuse for accuracy reporting. That makes the script’s “FP16 vs FP8” comparison misleading and can perturb the calibrated quantized model before evaluation.
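The in-place behavior is easy to reproduce in isolation (a small sketch, assuming torch is installed; the module here is a toy Linear, not the ViT from the example):

```python
import copy

import torch

# nn.Module.float() casts parameters in place and returns self.
m = torch.nn.Linear(4, 4).half()
m.float()
assert m.weight.dtype == torch.float32  # the original module was mutated

# Exporting from a deep copy leaves the live model untouched.
m2 = torch.nn.Linear(4, 4).half()
export_copy = copy.deepcopy(m2).float()
assert m2.weight.dtype == torch.float16
assert export_copy.weight.dtype == torch.float32
```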

💡 Suggested fix
-    onnx_bytes, _ = get_onnx_bytes_and_metadata(
-        model=model.float(),
+    export_model = copy.deepcopy(model).float()
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(
+        model=export_model,
         dummy_input=dummy_input,
         weights_dtype=weights_dtype,
         model_name=model_name,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/vit_mha_quantization.py` around lines 225 - 230, The
export is mutating the live model because model.float() is in-place, which
alters base_model/quantized_model used later; fix by exporting from a detached
copy instead (e.g., create a deep copy of model with copy.deepcopy(model) and
call .float() or .to(torch.float16) on that copy) so get_onnx_bytes_and_metadata
receives a non-mutated model; ensure you import copy and use the copied instance
when calling get_onnx_bytes_and_metadata to avoid changing
base_model/quantized_model before accuracy evaluation.

Comment on lines +267 to +298
for node in graph.node:
    if node.op_type == "QuantizeLinear":
        qdq_count += 1
    elif node.op_type == "DequantizeLinear":
        dq_count += 1
    elif node.op_type == "MatMul":
        matmul_count += 1
        # Check if inputs come from DequantizeLinear
        inputs_from_dq = []
        for inp in node.input:
            if inp in output_to_node:
                parent = output_to_node[inp]
                if parent.op_type == "DequantizeLinear":
                    inputs_from_dq.append(inp)
        if len(inputs_from_dq) > 0:
            matmul_with_qdq += 1
            print(f" MatMul '{node.name}' has {len(inputs_from_dq)} QDQ input(s)")

print("\n Summary:")
print(f" QuantizeLinear nodes: {qdq_count}")
print(f" DequantizeLinear nodes: {dq_count}")
print(f" Total MatMul nodes: {matmul_count}")
print(f" MatMul with QDQ inputs: {matmul_with_qdq}")

# Check attention-specific MatMuls (Q@K^T and attn@V)
# In ViT-base with 12 layers, we expect 24 attention MatMuls (2 per layer)
# plus linear projection MatMuls
print("\n For ViT attention:")
print(" Expected attention MatMuls with QDQ: ~24 (12 layers x 2 BMMs)")
print(f" Found MatMuls with QDQ inputs: {matmul_with_qdq}")

return qdq_count, dq_count, matmul_count, matmul_with_qdq
Contributor


⚠️ Potential issue | 🟠 Major

This QDQ coverage check is not attention-specific.

matmul_with_qdq increments for any MatMul that consumes a DequantizeLinear output, including plain projection / MLP matmuls after weight compression. That means the “~24 attention MatMuls” check can still look healthy even when Q/K/V or softmax QDQ placement is wrong. Please key this off the actual attention pattern instead of all MatMul + DQ pairs.
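One possible shape for such a helper, sketched with a stand-in Node dataclass rather than real onnx.NodeProto objects; the token list and parent-op checks are illustrative heuristics, not the definitive fix:

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Minimal stand-in for onnx.NodeProto (op_type, name, input, output)."""

    op_type: str
    name: str = ""
    input: list = field(default_factory=list)
    output: list = field(default_factory=list)


def is_attention_matmul(node, output_to_node):
    """Heuristic: treat a MatMul as attention-related when its name carries
    attention tokens, or an input is produced by Softmax (attn @ V) or by a
    Transpose (the Q @ K^T side)."""
    tokens = ("attn", "attention", "softmax", "score")
    if any(t in node.name.lower() for t in tokens):
        return True
    return any(
        (p := output_to_node.get(inp)) is not None
        and p.op_type in ("Softmax", "Transpose")
        for inp in node.input
    )


# A tiny fake graph: one attn@V MatMul and one plain projection MatMul.
softmax = Node("Softmax", "sm", output=["probs"])
attn_v = Node("MatMul", "layer0/mm", input=["probs", "v"])
proj = Node("MatMul", "layer0/dense", input=["x", "w"])
out_map = {"probs": softmax}
assert is_attention_matmul(attn_v, out_map)
assert not is_attention_matmul(proj, out_map)
```

In the real script, the counting loop would then gate `matmul_with_qdq += 1` on both `inputs_from_dq` being non-empty and this helper returning True.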

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/vit_mha_quantization.py` around lines 267 - 298, The
current loop counts any MatMul consuming a DequantizeLinear output
(matmul_with_qdq) which falsely includes projection/MLP matmuls; replace this
with an attention-specific check: implement a helper (e.g.,
is_attention_matmul(node, output_to_node, graph)) and only increment
matmul_with_qdq when that returns true. Make is_attention_matmul examine the
MatMul node name and upstream pattern (check parent ops via output_to_node for
Transpose/Reshape, Softmax, or names containing tokens like "q", "k", "v",
"attn", "score", "softmax") or detect the Q@K^T pattern by verifying one input
path comes from a Transpose of a Q-like tensor and the other from K-like tensor;
for attn@V detect the MatMul consuming Softmax output and a V-like source.
Update the loop that currently inspects node.op_type == "MatMul" and uses
inputs_from_dq to call this helper and only count/print when both QDQ and
attention pattern match.

Comment on lines +100 to +108
if any(c.op == "MatMul" for c in candidate.outputs[0].outputs):
    perm = candidate.attrs.get("perm", None)
    torch_weights = (
        torch_weights.permute(*perm).contiguous()
        if perm is not None
        else torch_weights.T.contiguous()
    )
    transpose_to_remove = candidate
    break
Contributor


⚠️ Potential issue | 🟠 Major

Restrict these rewrites to pure MatMul fanout.

Each transform fires when it sees any MatMul consumer, then rewires all consumers and clears the original outputs. If one of these tensors also feeds a non-MatMul branch, that branch will now observe the untransposed / unscaled / quantized value instead of the original one. Please either require all(...) consumers to match the intended pattern or only rewrite the specific MatMul edges.

Also applies to: 128-137, 196-220, 255-275, 334-339
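The edge-local variant can be sketched with duck-typed graphsurgeon-style objects (node.op, node.inputs, tensor.outputs mirror the onnx_graphsurgeon attribute names; the Tensor/Node classes here are stand-ins so the sketch is self-contained):

```python
class Tensor:
    """Stand-in for gs.Variable: tracks its consumer nodes in .outputs."""

    def __init__(self, name):
        self.name = name
        self.outputs = []


class Node:
    """Stand-in for gs.Node: registers itself as a consumer of its inputs."""

    def __init__(self, op, inputs):
        self.op = op
        self.inputs = inputs
        for t in inputs:
            t.outputs.append(self)


def rewire_matmul_edges(old_t, new_t):
    """Redirect only the MatMul consumers of old_t to new_t.

    Returns True when no consumers of old_t remain, i.e. only then is it
    safe to remove the original Transpose/Mul producer.
    """
    for consumer in [c for c in old_t.outputs if c.op == "MatMul"]:
        consumer.inputs = [new_t if t is old_t else t for t in consumer.inputs]
        old_t.outputs.remove(consumer)
        new_t.outputs.append(consumer)
    return not old_t.outputs


old = Tensor("w_t")
new = Tensor("w_pretransposed")
mm = Node("MatMul", [old])
add = Node("Add", [old])  # a non-MatMul branch that must keep the old value
assert rewire_matmul_edges(old, new) is False  # Add still consumes old
assert mm.inputs == [new] and add.inputs == [old]
```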

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 100 - 108, The code
currently uses any(c.op == "MatMul" for c in candidate.outputs[0].outputs) and
then rewires/clears all consumers which breaks non-MatMul branches; change the
logic to require all(c.op == "MatMul" for c in candidate.outputs[0].outputs)
before performing the global rewrite OR, preferably, only rewrite the specific
MatMul edges: iterate candidate.outputs[0].outputs, for each consumer c with
c.op == "MatMul" rewire that consumer's input to use the
transposed/scaled/quantized tensor and leave other consumers untouched, and do
not clear original outputs (update transpose_to_remove only when all downstream
edges have been safely redirected). Apply the same fix pattern to the other
rewrite sites that manipulate torch_weights, perm, transpose_to_remove, and
similar MatMul-aware transforms.

Comment thread modelopt/onnx/utils.py
Comment on lines +1422 to +1505
def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None:
    """Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16."""
    import numpy as np

    if scale_init.data_type != onnx.TensorProto.FLOAT:
        return
    scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
    if not scale_data.size:
        scale_data = np.array(scale_init.float_data, dtype=np.float32)
    scale_init.data_type = onnx.TensorProto.FLOAT16
    scale_init.raw_data = scale_data.astype(np.float16).tobytes()
    del scale_init.float_data[:]


def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Remove ``DQ → Cast(FP32→FP16)`` patterns inserted by ``convert_float_to_float16``.

    The DQ scale is rewritten to FP16 so DQ natively produces FP16 output, enabling
    TRT to fuse DQ directly into the downstream compute op.
    """
    producer_map = {out: node for node in onnx_model.graph.node for out in node.output}
    initializers = {init.name: init for init in onnx_model.graph.initializer}

    to_remove = []
    for node in onnx_model.graph.node:
        if node.op_type != "Cast":
            continue
        cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
        if cast_to != onnx.TensorProto.FLOAT16:
            continue
        producer = producer_map.get(node.input[0])
        if producer is None or producer.op_type not in _DQ_OPS:
            continue

        if len(producer.input) >= 2 and producer.input[1] in initializers:
            _scale_fp32_to_fp16(initializers[producer.input[1]])

        _bypass_cast_node(onnx_model, node)
        to_remove.append(node)

        for vi in onnx_model.graph.value_info:
            if vi.name == producer.output[0]:
                vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
                break

    logger.debug(f"Folded {len(to_remove)} DQ -> Cast(FP32->FP16) patterns")
    for node in to_remove:
        onnx_model.graph.node.remove(node)
    return onnx_model


def fold_q_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Remove ``Cast(FP16→FP32) → Q`` patterns inserted by ``convert_float_to_float16``.

    The Q scale is rewritten to FP16 so Q consumes the FP16 graph directly.
    """
    consumer_map: dict[str, list[onnx.NodeProto]] = {}
    for node in onnx_model.graph.node:
        for inp in node.input:
            consumer_map.setdefault(inp, []).append(node)
    initializers = {init.name: init for init in onnx_model.graph.initializer}

    to_remove = []
    for node in onnx_model.graph.node:
        if node.op_type != "Cast":
            continue
        cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
        if cast_to != onnx.TensorProto.FLOAT:
            continue
        consumers = consumer_map.get(node.output[0], [])
        if not consumers or not all(c.op_type in _Q_OPS for c in consumers):
            continue

        for q_node in consumers:
            if len(q_node.input) >= 2 and q_node.input[1] in initializers:
                _scale_fp32_to_fp16(initializers[q_node.input[1]])

        _bypass_cast_node(onnx_model, node)
        to_remove.append(node)

    logger.debug(f"Folded {len(to_remove)} Cast(FP16->FP32) -> Q patterns")
    for node in to_remove:
        onnx_model.graph.node.remove(node)
    return onnx_model
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

What is the minimum ONNX opset version that allows FLOAT16 scale tensors for QuantizeLinear and DequantizeLinear?

💡 Result:

The minimum ONNX opset version that allows FLOAT16 scale tensors for both QuantizeLinear and DequantizeLinear is 19.

Add opset guard to FP16 scale rewrite helpers.

These helpers rewrite Q/DQ scale initializers to FLOAT16 unconditionally, but FLOAT16 scales are only valid from opset 19 onward. When FP8 quantization is enabled, a user can request a lower opset (e.g., onnx_opset=18) via the export API, which will cause these functions to silently produce an invalid model. Check get_opset_version(onnx_model) >= BASE_MIN_OPSET before mutating initializers, or skip the fold operation entirely for lower opsets.
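A guard along the suggested lines could look like the sketch below. get_opset_version and BASE_MIN_OPSET are the existing names in modelopt/onnx/utils.py; here a namedtuple stands in for onnx's OperatorSetIdProto so the sketch runs without the onnx package:

```python
from collections import namedtuple

# Stand-in for onnx's OperatorSetIdProto entries in model.opset_import.
OpsetId = namedtuple("OpsetId", "domain version")

BASE_MIN_OPSET = 19  # FLOAT16 Q/DQ scales are only legal from opset 19 onward


def default_opset_version(opset_import):
    """Stand-in for get_opset_version: default-domain opset of a model."""
    return max(
        (o.version for o in opset_import if o.domain in ("", "ai.onnx")),
        default=0,
    )


def can_fold_fp16_scales(opset_import):
    """Gate the FP32→FP16 scale rewrite; below opset 19 the fold is skipped
    entirely so the model is returned unchanged."""
    return default_opset_version(opset_import) >= BASE_MIN_OPSET


assert can_fold_fp16_scales([OpsetId("", 19)])
assert not can_fold_fp16_scales([OpsetId("", 18)])
```

In fold_dq_fp32_to_fp16_casts / fold_q_fp16_to_fp32_casts this check would sit at the top of the function, returning the model untouched when it fails.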

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1422 - 1505, The fold helpers
unconditionally convert Q/DQ scale initializers to FLOAT16 which is invalid for
opsets < BASE_MIN_OPSET; update _scale_fp32_to_fp16, fold_dq_fp32_to_fp16_casts
and fold_q_fp16_to_fp32_casts to guard the mutation by checking
get_opset_version(onnx_model) (or the model passed in) and only perform the
FP32→FP16 rewrite when get_opset_version(onnx_model) >= BASE_MIN_OPSET; if the
check fails, skip mutating initializers and skip folding the cast nodes (i.e.,
return the model unchanged or continue without calling
_scale_fp32_to_fp16/_bypass_cast_node), using the existing function names
(_scale_fp32_to_fp16, fold_dq_fp32_to_fp16_casts, fold_q_fp16_to_fp32_casts) and
constants (BASE_MIN_OPSET) to locate where to add the guard.

Comment on lines +292 to +300
parent_attention_types = {
    _QuantAttention.get_attn_type(module)
    for _, module in model.named_modules()
    if type(module).__name__.endswith("Attention")
    and any(
        child is not module and type(child).__name__.endswith("Attention")
        for _, child in module.named_modules()
    )
}
Contributor


⚠️ Potential issue | 🟡 Minor

Skip nested attention wrappers by instance, not by class.

Once any instance of an attention class lands in parent_attention_types, every module of that class is skipped here. That will miss quantization for models that reuse the same attention class both as a wrapper and as a leaf. Checking module.named_modules() in-loop avoids dropping unrelated instances.

Also applies to: 306-307
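An instance-keyed variant can be sketched as follows; a minimal Module class stands in for nn.Module so the sketch is self-contained, while the real code would iterate torch's model.named_modules():

```python
class Module:
    """Minimal stand-in for nn.Module's naming/iteration (illustration only)."""

    def __init__(self, **children):
        self._children = children

    def named_modules(self, prefix=""):
        yield prefix, self
        for name, child in self._children.items():
            yield from child.named_modules(f"{prefix}.{name}" if prefix else name)


def nested_attention_parents(model):
    """Collect ids of the *instances* that wrap another attention module,
    so unrelated instances of the same class are still quantized."""
    parents = set()
    for _, module in model.named_modules():
        if not type(module).__name__.endswith("Attention"):
            continue
        if any(
            child is not module and type(child).__name__.endswith("Attention")
            for _, child in module.named_modules()
        ):
            parents.add(id(module))
    return parents


class ViTSelfAttention(Module):
    pass


class ViTAttention(Module):
    pass


self_attn = ViTSelfAttention()
wrapper = ViTAttention(attention=self_attn)
model = Module(encoder=wrapper)
parents = nested_attention_parents(model)
assert id(wrapper) in parents and id(self_attn) not in parents
```

The patching loop would then test `id(module) in parents` instead of comparing attention types.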

Collaborator

@cjluo-nv cjluo-nv left a comment


This PR adds FP8 MHA quantization support for HuggingFace ViT models with ONNX export optimizations. The implementation is well-structured and addresses a real gap (NVBug 6078291). However, there are several issues to address:

Critical issues:

  1. No unit tests — This is ~933 lines of new/changed library code across core export paths (fp8_exporter.py, utils.py, export_onnx.py, huggingface.py) with zero unit tests. The only "test" is the example script which requires GPU, ImageNet data, and TRT. The graph rewrite functions in fp8_exporter.py, the cast folding helpers in utils.py, the attention skipping logic in huggingface.py, and the LayerNorm quantization registration all need unit tests.

  2. Bare assert for runtime validation in fp8_exporter.py — the existing assert on QDQ pair validation will be stripped with -O.

  3. Silent contextlib.suppress(Exception) in the example — can mask real failures during benchmark parsing.

Minor issues:

  4. The _scale_fp32_to_fp16 helper doesn't handle the case where the scale value overflows or underflows to inf/0 in FP16 — this could silently produce bad quantization results for extreme scales.
  5. The _move_mul_before_qdq rewrite assumes a single scalar const Mul for attention scaling; if the model architecture changes, these pattern-matching rewrites could silently become no-ops without any warning.
  6. The _insert_qdq_after_softmax hardcodes scale=1/448, which is correct for E4M3, but should at minimum document why this specific value is used and that it is tied to the FP8 E4M3 max representable value.
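The range concern in item 4 can be reproduced directly with numpy (assumed available); typical FP8 scales are moderate, but nothing in _scale_fp32_to_fp16 enforces that:

```python
import numpy as np

# FP16 range: max normal ≈ 65504, min subnormal ≈ 5.96e-8.
assert np.float16(np.float32(1e5)) == np.inf      # overflow: scale becomes inf
assert np.float16(np.float32(1e-8)) == 0.0        # underflow: scale becomes 0
assert np.isfinite(np.float16(np.float32(0.01)))  # typical scales survive
```

A cheap mitigation would be to skip the rewrite (and the cast fold) when the scale falls outside the FP16 normal range, or at least log a warning.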

Positive aspects:

  • Clean separation of graph rewrites as static methods
  • Good docstrings on the new functions
  • The parent_attention_types detection for avoiding double-patching is well done
  • The LayerNorm registration follows existing patterns exactly

@@ -78,6 +80,33 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
f"QDQ does not occur in pairs. You reached {dq_op.op}"
Bug (existing but expanded): This assert will be stripped with -O flag. Per codebase conventions, use raise ValueError(...) or RuntimeError(...) instead for runtime validation:

if dq_op.op != "TRT_FP8DequantizeLinear":
    raise RuntimeError(f"QDQ does not occur in pairs. You reached {dq_op.op}")
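To illustrate the difference with toy functions (not the PR's actual code): an assert-based check disappears entirely under `python -O`, while an explicit raise always enforces the validation.

```python
def check_qdq_pair_assert(op: str) -> None:
    # Compiled out entirely when Python runs with -O: no validation happens.
    assert op == "TRT_FP8DequantizeLinear", f"QDQ does not occur in pairs. You reached {op}"


def check_qdq_pair_raise(op: str) -> None:
    # Survives -O: runtime validation is always enforced.
    if op != "TRT_FP8DequantizeLinear":
        raise RuntimeError(f"QDQ does not occur in pairs. You reached {op}")


check_qdq_pair_raise("TRT_FP8DequantizeLinear")  # passes silently
```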

if candidate.op != "Transpose":
    cast_to_remove = None
    continue
if any(c.op == "MatMul" for c in candidate.outputs[0].outputs):
Potential issue: when dq_op.outputs[0] has multiple consumers, the loop handles non-Cast/non-Transpose candidates correctly — it resets cast_to_remove = None and continues to the next candidate. But once it reaches a Transpose candidate, it breaks unconditionally, even when that Transpose has no MatMul consumer, so later candidates that would match are never examined and cast_to_remove may be left stale:

for candidate in list(dq_op.outputs[0].outputs):
    if candidate.op == "Cast":
        ...  # follows the Cast; resets and continues if it dead-ends
    if candidate.op != "Transpose":
        cast_to_remove = None
        continue  # correctly moves on to the next candidate
    ...
    break  # <-- exits even if this Transpose doesn't feed a MatMul

The logic happens to work for the current graphs, but it is fragile. Restructuring to break only on a full match — or at least a comment stating the intended invariant — would help.
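A torch/onnx-free toy of the suggested restructure (the Node class and helper name are stand-ins for graphsurgeon objects, not the PR's code): the scan only returns on a full Cast?→Transpose→MatMul match and keeps going on partial matches instead of breaking early.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    op: str
    consumers: list = field(default_factory=list)


def find_transpose_into_matmul(consumers):
    """Return (cast, transpose) for the first consumer chain matching
    [Cast ->] Transpose -> MatMul; keep scanning on partial matches."""
    for candidate in consumers:
        cast, node = None, candidate
        if node.op == "Cast":
            cast = node
            node = node.consumers[0] if node.consumers else None
        if node is None or node.op != "Transpose":
            continue  # not the pattern: move on, don't break
        if any(c.op == "MatMul" for c in node.consumers):
            return cast, node
        # a Transpose that doesn't feed a MatMul: keep scanning too
    return None, None
```

With this shape, a non-matching Transpose earlier in the consumer list no longer hides a valid Cast→Transpose→MatMul chain later in it.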

gs.Node(
    op="Transpose",
    name=transpose_node.name + "_moved",
    inputs=[q_input],
The hardcoded 1.0/448.0 is correct (FP8 E4M3 max = 448, softmax output ∈ [0,1]), and the docstring explains this well. However, consider defining this as a named constant (e.g., _FP8_E4M3_SOFTMAX_SCALE = 1.0 / 448.0) since the same value is also used in export_onnx.py (export_fp8_mha uses 1.0 for softmax amax which becomes 448.0/1.0 then 1.0/448.0). This would make the connection explicit.

Comment thread modelopt/onnx/utils.py
    return
scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
if not scale_data.size:
    scale_data = np.array(scale_init.float_data, dtype=np.float32)
Potential precision issue: _scale_fp32_to_fp16 does not check whether the FP32 scale value is inside the FP16 representable range. If a scale is very small (< ~5.96e-8, the smallest FP16 subnormal) or very large (> 65504), the conversion will silently produce 0 or inf. Consider adding a warning:

fp16_val = scale_data.astype(np.float16)
if np.any(np.isinf(fp16_val)) or np.any((fp16_val == 0) & (scale_data != 0)):
    logger.warning("Q/DQ scale overflows/underflows when converting to FP16")
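A quick NumPy illustration of the failure mode (the thresholds are FP16's max normal value, 65504, and its smallest subnormal, ≈5.96e-8):

```python
import numpy as np

# ok / overflow / underflow relative to the FP16 range
scales = np.array([1e-3, 7e4, 1e-9], dtype=np.float32)
fp16 = scales.astype(np.float16)

overflowed = np.isinf(fp16)                # 7e4 > 65504 -> inf
underflowed = (fp16 == 0) & (scales != 0)  # 1e-9 < ~5.96e-8 -> 0
```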

Comment thread modelopt/onnx/utils.py
def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None:
    """Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16."""
    import numpy as np

Minor: np is already imported at the top of the file. The local import numpy as np is unnecessary.

)
scale = g.op(
    "Constant",
    value_t=torch.tensor(scale_inv, dtype=torch_dtype_map[otype]),  # type: ignore[index]
Same as above: trt_high_precision_dtype is now unused in _fp8_dequantize — the DQ output type is set to otype directly. The parameter should either be removed or have a deprecation comment explaining it's kept for API compatibility.

parent_attention_types = {
    _QuantAttention.get_attn_type(module)
    for _, module in model.named_modules()
    if type(module).__name__.endswith("Attention")
Good approach for detecting parent attention wrappers! One edge case: named_modules() includes the module itself, so child is not module is the correct check (which you have). However, this only checks one level of nesting. If there's a 3-level hierarchy like OuterAttention > MiddleAttention > InnerAttention, both OuterAttention and MiddleAttention would be in parent_attention_types, and only InnerAttention would be registered. Is that the intended behavior? For ViT it's fine (only 2 levels), but worth a comment.
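A torch-free toy of that three-level edge case (the class names, the children_ attribute, and walk() are stand-ins for nn.Module, named_modules(), and the PR's real detection code):

```python
class InnerAttention:
    def __init__(self):
        self.children_ = []


class MiddleAttention:
    def __init__(self):
        self.children_ = [InnerAttention()]


class OuterAttention:
    def __init__(self):
        self.children_ = [MiddleAttention()]


def walk(module):
    """Yield the module and all descendants, like named_modules()."""
    yield module
    for child in module.children_:
        yield from walk(child)


root = OuterAttention()
# Any *Attention module with a direct *Attention child counts as a wrapper.
parents = {
    type(m).__name__
    for m in walk(root)
    if type(m).__name__.endswith("Attention")
    and any(type(c).__name__.endswith("Attention") for c in m.children_)
}
# With three levels, both Outer and Middle land in the parent set, so only
# InnerAttention would end up registered/patched.
```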


from .quant_module import QuantInputBase, QuantModuleRegistry

QuantModuleRegistry.register({nn.LayerNorm: "nn.LayerNorm"})(QuantInputBase)
This follows the exact pattern of quant_batchnorm.py — good. However, unlike BatchNorm which has 1d/2d/3d variants, LayerNorm only has one class, so this is clean. Consider adding a brief module-level docstring noting that this enables LayerNorm output quantizers to be honored during quantization (the file docstring says "Quantized layer normalization module" which is a bit terse).

@@ -0,0 +1,538 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Missing unit tests: This 538-line example is the only "test" for the entire PR. The core library changes (graph rewrites in fp8_exporter.py, cast folding in utils.py, attention skipping in huggingface.py, LayerNorm quantization) need proper unit tests that:

  1. Test _move_mul_before_qdq with a synthetic ONNX graph
  2. Test _move_transpose_before_qdq with a synthetic ONNX graph
  3. Test _insert_qdq_after_softmax with a synthetic ONNX graph
  4. Test fold_dq_fp32_to_fp16_casts and fold_q_fp16_to_fp32_casts
  5. Test parent_attention_types detection with a mock ViT-like model
  6. Test that nn.LayerNorm is properly registered in QuantModuleRegistry

These can all be CPU-only tests with small synthetic models/graphs.
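A sketch of what such a CPU-only test could look like, using plain tuples as a toy stand-in for an ONNX graph (the real tests would build a graphsurgeon/onnx graph and call the PR's fold_dq_fp32_to_fp16_casts; fold_dq_casts below only sketches the test's shape):

```python
def fold_dq_casts(nodes):
    """Drop Cast nodes that read a DequantizeLinear output, rewiring
    downstream consumers to the DQ output directly.

    Each node is a (op, inputs, output) triple.
    """
    dq_outputs = {out for op, _, out in nodes if op == "DequantizeLinear"}
    folded, remap = [], {}
    for op, inputs, output in nodes:
        inputs = [remap.get(i, i) for i in inputs]
        if op == "Cast" and inputs and inputs[0] in dq_outputs:
            remap[output] = inputs[0]  # remember the rewiring, drop the Cast
            continue
        folded.append((op, inputs, output))
    return folded


graph = [
    ("DequantizeLinear", ["w_q", "scale"], "w_dq"),
    ("Cast", ["w_dq"], "w_f16"),
    ("MatMul", ["x", "w_f16"], "y"),
]
# After folding, MatMul consumes w_dq directly and the Cast is gone.
```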



# ──────────────────────────────────────────────────────────────────────────────
# Step 4: TRT latency benchmarking
contextlib.suppress(Exception) is used in multiple places in the benchmark parsing. Per codebase conventions, prefer explicit handling — suppressing all exceptions can mask real bugs. At minimum, suppress only (IndexError, ValueError) which is what you're already doing inline:

with contextlib.suppress(IndexError, ValueError):

@codecov

codecov bot commented Apr 17, 2026

Codecov Report

❌ Patch coverage is 9.58904% with 198 lines in your changes missing coverage. Please review.
✅ Project coverage is 71.05%. Comparing base (07ae8e7) to head (d6533ac).
⚠️ Report is 12 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/onnx/export/fp8_exporter.py 4.13% 139 Missing ⚠️
modelopt/onnx/utils.py 8.06% 57 Missing ⚠️
modelopt/torch/_deploy/utils/torch_onnx.py 33.33% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1289      +/-   ##
==========================================
- Coverage   75.61%   71.05%   -4.57%     
==========================================
  Files         459      460       +1     
  Lines       48597    49312     +715     
==========================================
- Hits        36747    35039    -1708     
- Misses      11850    14273    +2423     
Flag Coverage Δ
examples 38.57% <9.13%> (+8.73%) ⬆️
gpu 51.53% <9.13%> (-8.98%) ⬇️
unit 52.00% <8.67%> (-0.24%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
Can you update the changelog for 0.45 to mention the details of this new feature?
