From 9bfcb7270ac573381ca3824e8ae1f391c3265331 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Fri, 17 Apr 2026 20:13:23 +0000 Subject: [PATCH] Add FP8 MHA quantization support for HuggingFace ViT MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ. - fp8_exporter: rewrite attention-scaling Mul and K Transpose to the Q-side so DQ feeds MatMul directly, pre-transpose weight constants, insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent. - onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16 inserts around Q/DQ by rewriting scale initializers to FP16, so TRT fuses DQ into the downstream GEMM/MatMul kernel. - torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native input dtype so no Cast is injected between graph and Q/DQ. - torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored. - torch/quantization/plugins/huggingface: skip attention wrappers whose children are also "*Attention" to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention). Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8 config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries) plus accuracy + TRT-latency comparison against an FP16 baseline. 
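The 1/448 softmax Q scale can be sanity-checked outside the exporter. A minimal pure-Python sketch of E4M3 round-to-nearest (an illustration of the numerics, not the exporter's implementation) shows that this scale maps the softmax range [0, 1] onto the full representable range, so 1.0 hits the top code without clipping:

```python
import math

E4M3_MAX = 448.0               # largest finite FP8 E4M3 magnitude
SOFTMAX_SCALE = 1.0 / E4M3_MAX

def round_to_e4m3(v: float) -> float:
    """Round a non-negative float to the nearest FP8 E4M3 value, saturating at 448."""
    if v <= 0.0:
        return 0.0
    v = min(v, E4M3_MAX)
    e = max(math.floor(math.log2(v)), -6)  # -6 is the smallest normal E4M3 exponent
    step = 2.0 ** (e - 3)                  # 3 mantissa bits -> 8 steps per binade
    return min(round(v / step) * step, E4M3_MAX)

def qdq(x: float, scale: float) -> float:
    """Simulate QuantizeLinear -> DequantizeLinear with a per-tensor scale."""
    return round_to_e4m3(x / scale) * scale

# 1.0 / SOFTMAX_SCALE == 448, the maximum representable code, so the top of the
# softmax output range round-trips with no clipping error.
one = qdq(1.0, SOFTMAX_SCALE)
half = qdq(0.5, SOFTMAX_SCALE)
```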
Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1): - Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8) — -0.20% / -0.06% - TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- CHANGELOG.rst | 1 + examples/torch_onnx/vit_mha_quantization.py | 591 ++++++++++++++++++ modelopt/onnx/export/fp8_exporter.py | 290 ++++++++- modelopt/onnx/utils.py | 82 ++- modelopt/torch/_deploy/utils/torch_onnx.py | 6 + modelopt/torch/quantization/export_onnx.py | 45 +- modelopt/torch/quantization/nn/__init__.py | 1 + .../nn/modules/quant_layernorm.py | 25 + .../torch/quantization/plugins/huggingface.py | 15 + 9 files changed, 1015 insertions(+), 41 deletions(-) create mode 100644 examples/torch_onnx/vit_mha_quantization.py create mode 100644 modelopt/torch/quantization/nn/modules/quant_layernorm.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 20a677d0a0..22a359e6d8 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,7 @@ Changelog - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution. - Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml `_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml `_ for usage. - Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). 
When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning. +- Add FP8 MHA quantization for HuggingFace ViT. Adds BMM input quantizers (``*_bmm_quantizer``) and LayerNorm output quantizer to FP8_DEFAULT_CFG so TRT can fuse FP8 attention kernels, an attention-aware ONNX post-processing pass (scale Mul / K-transpose move before Q, Q→DQ insertion on softmax output) in :class:`FP8QuantExporter `, per-instance nested-attention-wrapper skipping in the HF plugin, and ``nn.LayerNorm`` registration in ``QuantModuleRegistry``. See `examples/torch_onnx/vit_mha_quantization.py `_ for the end-to-end workflow. **Backward Breaking Changes** diff --git a/examples/torch_onnx/vit_mha_quantization.py b/examples/torch_onnx/vit_mha_quantization.py new file mode 100644 index 0000000000..69d47c1703 --- /dev/null +++ b/examples/torch_onnx/vit_mha_quantization.py @@ -0,0 +1,591 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ViT FP8 MHA quantization: quantize, export, verify QDQ nodes, evaluate accuracy & latency. 
+ +This script addresses NVBug 6078291: FP8 quantization config does not quantize MHA properly +for HuggingFace ViT. The fix adds q_bmm_quantizer, k_bmm_quantizer, v_bmm_quantizer to +FP8_DEFAULT_CFG so that attention BMM operations are FP8 quantized. + +The script: +1. Quantizes HF ViT with the updated FP8 config (torch path) +2. Exports to ONNX and verifies QDQ nodes are inserted before bmm1/bmm2 +3. Evaluates accuracy on ImageNet validation set +4. Quantizes with ONNX PTQ for comparison +5. Benchmarks TRT latency for all three models (FP16 baseline, torch FP8, ONNX PTQ FP8) +""" + +import argparse +import copy +import os +import subprocess +import sys +import time +from pathlib import Path + +# Add onnx_ptq to path for shared modules +sys.path.insert(0, str(Path(__file__).parent.parent / "onnx_ptq")) + +import onnx +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import ViTForImageClassification, ViTImageProcessor + +import modelopt.torch.quantization as mtq +from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata + +# ────────────────────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────────────────────── + + +def load_imagenet_val(processor, num_examples=None, batch_size=32, streaming=False): + """Load ImageNet validation set with HF ViT processor transforms. + + Set ``streaming=True`` to avoid touching the local HF datasets cache (useful when the + cache dir is not writable). Streaming returns an ``IterableDataset`` and does not + support ``len()`` / random access. 
+ """ + dataset = load_dataset( + "ILSVRC/imagenet-1k", + split="validation", + data_files={"validation": "data/validation*"}, + verification_mode="no_checks", + streaming=streaming, + ) + + def _preprocess(item): + image = item["image"] + if image.mode != "RGB": + image = image.convert("RGB") + return processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0), item[ + "label" + ] + + if streaming: + class StreamingImageNet(torch.utils.data.IterableDataset): + def __init__(self, hf_ds): + self.hf_ds = hf_ds + + def __iter__(self): + for item in self.hf_ds: + yield _preprocess(item) + + return torch.utils.data.DataLoader( + StreamingImageNet(dataset), batch_size=batch_size, num_workers=0 + ) + + class ImageNetDataset(torch.utils.data.Dataset): + def __init__(self, hf_dataset): + self.dataset = hf_dataset + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + return _preprocess(self.dataset[idx]) + + return torch.utils.data.DataLoader( + ImageNetDataset(dataset), batch_size=batch_size, shuffle=False, num_workers=4 + ) + + +def evaluate_accuracy(model, val_loader, num_examples=None, device="cuda"): + """Evaluate top-1 and top-5 accuracy.""" + model.eval() + if isinstance(model, torch.nn.Module): + model = model.to(device) + + total = 0 + correct_top1 = 0 + correct_top5 = 0 + + if num_examples is not None: + total_batches = num_examples // val_loader.batch_size + else: + try: + total_batches = len(val_loader) + except TypeError: + total_batches = None # streaming + + with torch.no_grad(): + for images, labels in tqdm(val_loader, total=total_batches, desc="Evaluating"): + if num_examples and total >= num_examples: + break + images = images.to(device) + labels = labels.to(device) + + outputs = model(images) + logits = outputs.logits if hasattr(outputs, "logits") else outputs + + _, top5_pred = torch.topk(logits, 5, dim=1) + correct_top1 += (top5_pred[:, 0] == labels).sum().item() + correct_top5 += (top5_pred == 
labels.unsqueeze(1)).any(dim=1).sum().item() + total += labels.size(0) + + top1 = 100.0 * correct_top1 / total + top5 = 100.0 * correct_top5 / total + return top1, top5 + + +def load_calibration_data(processor, num_samples=512, batch_size=32, device="cuda"): + """Load calibration data from tiny-imagenet.""" + dataset = load_dataset("zh-plus/tiny-imagenet") + images = dataset["train"][:num_samples]["image"] + + calib_tensors = [] + for img in images: + if img.mode != "RGB": + img = img.convert("RGB") + inputs = processor(images=img, return_tensors="pt") + calib_tensors.append(inputs["pixel_values"].squeeze(0)) + + return torch.utils.data.DataLoader( + calib_tensors, batch_size=batch_size, shuffle=True, num_workers=4 + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Step 1: Inspect model and quantize with torch FP8 +# ────────────────────────────────────────────────────────────────────────────── + + +def inspect_model(model): + """Print attention layer names and identify bmm1/bmm2 related modules.""" + print("\n" + "=" * 80) + print("ATTENTION LAYERS IN ViT MODEL") + print("=" * 80) + for name, module in model.named_modules(): + if type(module).__name__.endswith("Attention"): + print(f" {name}: {type(module).__name__}") + + +def _vit_fp8_config(): + """FP8 quant config tailored for HF ViT. + + Extends FP8_DEFAULT_CFG with: + - LayerNorm output quantizer (share one QDQ across all downstream Q/K/V/FC consumers + instead of repeating it on each input). + - MHA BMM-input quantizers (``*_bmm_quantizer``) so TRT can fuse attention as FP8. 
+ """ + fp8_cfg = {"num_bits": (4, 3), "axis": None} + config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + config["quant_cfg"].extend( + [ + {"parent_class": "nn.LayerNorm", "quantizer_name": "*output_quantizer", "cfg": fp8_cfg}, + {"parent_class": "nn.LayerNorm", "quantizer_name": "*input_quantizer", "enable": False}, + {"quantizer_name": "*query.input_quantizer", "enable": False}, + {"quantizer_name": "*key.input_quantizer", "enable": False}, + {"quantizer_name": "*value.input_quantizer", "enable": False}, + {"quantizer_name": "*intermediate.dense.input_quantizer", "enable": False}, + {"quantizer_name": "*_bmm_quantizer", "cfg": fp8_cfg}, + ] + ) + return config + + +def quantize_vit_fp8(model, processor, device="cuda"): + """Quantize ViT with MHA-aware FP8 config.""" + calib_loader = load_calibration_data(processor, num_samples=512, batch_size=32, device=device) + + def forward_loop(model): + for batch in calib_loader: + model(batch.to(device)) + + print("\n" + "=" * 80) + print("QUANTIZING ViT WITH ViT-FP8 config (extends FP8_DEFAULT_CFG with MHA quantizers)") + print("=" * 80) + + quantized_model = mtq.quantize(model, _vit_fp8_config(), forward_loop=forward_loop) + + # Disable quantization on patch embedding and downsample layers (unsupported by TRT ONNX export) + import re + + def filter_func(name): + pattern = re.compile(r".*(patch_embed|patch_embeddings|downsample|pos_embed).*") + return pattern.match(name) is not None + + mtq.disable_quantizer(quantized_model, filter_func) + + # Print bmm quantizer status + print("\nBMM Quantizer Status:") + for name, module in quantized_model.named_modules(): + if "bmm" in name.lower() or "softmax_quantizer" in name.lower(): + enabled = getattr(module, "is_enabled", None) + num_bits = getattr(module, "num_bits", None) + amax = getattr(module, "_amax", None) + amax_val = amax.item() if amax is not None and amax.numel() == 1 else amax + print(f" {name}: enabled={enabled}, num_bits={num_bits}, amax={amax_val}") + + 
mtq.print_quant_summary(quantized_model) + return quantized_model + + +# ────────────────────────────────────────────────────────────────────────────── +# Step 2: Export to ONNX and verify QDQ nodes +# ────────────────────────────────────────────────────────────────────────────── + + +def _export_onnx(model, onnx_path, device="cuda", weights_dtype="fp16"): + """Export model to ONNX using get_onnx_bytes_and_metadata (handles QDQ correctly).""" + print(f"\nExporting to ONNX: {onnx_path}") + dummy_input = (torch.randn(1, 3, 224, 224, dtype=torch.float32).to(device),) + model_name = os.path.basename(onnx_path).replace(".onnx", "") + + # .float() mutates in-place; deep-copy so downstream accuracy eval sees the original dtype. + export_model = copy.deepcopy(model).float() + onnx_bytes, _ = get_onnx_bytes_and_metadata( + model=export_model, + dummy_input=dummy_input, + weights_dtype=weights_dtype, + model_name=model_name, + ) + onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes) + onnx_bytes_obj.write_to_disk(os.path.dirname(onnx_path), clean_dir=False) + print(f" Exported successfully: {onnx_path}") + return onnx_path + + +def export_to_onnx(model, onnx_path, device="cuda"): + """Export quantized model to ONNX with QDQ nodes (FP16 graph + FP8 QDQ for TRT e4m3 fusion).""" + return _export_onnx(model, onnx_path, device=device, weights_dtype="fp16") + + +def export_fp16_baseline(model, onnx_path, device="cuda"): + """Export FP16 baseline model to ONNX (no quantization).""" + return _export_onnx(model, onnx_path, device=device, weights_dtype="fp16") + + +def _is_attention_matmul(node, output_to_node, depth=4): + """Heuristic: attention BMMs are MatMuls with Softmax (attn@V) or Transpose-K (Q@K^T) ancestors. + + Projection/MLP matmuls typically consume DQ directly from a weight constant, not from + Softmax or a transpose on the right-hand side, so walking a few ops back disambiguates. 
+ """ + if node.op_type != "MatMul": + return False + + def _walk(name, steps): + if steps == 0 or name not in output_to_node: + return None + parent = output_to_node[name] + if parent.op_type in {"Softmax", "Transpose"}: + return parent.op_type + # Skip through dtype/shape-preserving ops that wrap attention tensors. + if parent.op_type in {"DequantizeLinear", "QuantizeLinear", "Cast", "Reshape", "Mul"}: + return _walk(parent.input[0], steps - 1) if parent.input else None + return None + + return any(_walk(inp, depth) for inp in node.input) + + +def check_qdq_nodes(onnx_path): + """Check if QDQ nodes are inserted before bmm1 and bmm2 in attention.""" + print(f"\n{'=' * 80}") + print(f"CHECKING QDQ NODES IN: {onnx_path}") + print(f"{'=' * 80}") + + model = onnx.load(onnx_path) + graph = model.graph + + # Build output->node map + output_to_node = {} + for node in graph.node: + for output in node.output: + output_to_node[output] = node + + qdq_count = 0 + dq_count = 0 + matmul_count = 0 + attn_matmul_with_qdq = 0 + + for node in graph.node: + if node.op_type == "QuantizeLinear": + qdq_count += 1 + elif node.op_type == "DequantizeLinear": + dq_count += 1 + elif node.op_type == "MatMul": + matmul_count += 1 + inputs_from_dq = [ + inp + for inp in node.input + if inp in output_to_node and output_to_node[inp].op_type == "DequantizeLinear" + ] + if inputs_from_dq and _is_attention_matmul(node, output_to_node): + attn_matmul_with_qdq += 1 + print(f" Attn MatMul '{node.name}' has {len(inputs_from_dq)} QDQ input(s)") + + print("\n Summary:") + print(f" QuantizeLinear nodes: {qdq_count}") + print(f" DequantizeLinear nodes: {dq_count}") + print(f" Total MatMul nodes: {matmul_count}") + print(f" Attention MatMul w/ QDQ: {attn_matmul_with_qdq}") + + # Check attention-specific MatMuls (Q@K^T and attn@V) + # In ViT-base with 12 layers, we expect 24 attention MatMuls (2 per layer) + print("\n For ViT attention:") + print(" Expected attention MatMuls with QDQ: ~24 (12 layers x 2 
BMMs)") + print(f" Found attention MatMuls with QDQ: {attn_matmul_with_qdq}") + + return qdq_count, dq_count, matmul_count, attn_matmul_with_qdq + + +# ────────────────────────────────────────────────────────────────────────────── +# Step 3: ONNX PTQ quantization +# ────────────────────────────────────────────────────────────────────────────── + + +def quantize_onnx_ptq(fp16_onnx_path, output_path): + """Quantize ONNX model using ONNX PTQ FP8.""" + print(f"\n{'=' * 80}") + print("ONNX PTQ FP8 QUANTIZATION") + print(f"{'=' * 80}") + + import modelopt.onnx.quantization as moq + + moq.quantize( + onnx_path=fp16_onnx_path, + quantize_mode="fp8", + output_path=output_path, + ) + print(f" ONNX PTQ FP8 model saved to: {output_path}") + return output_path + + +# ────────────────────────────────────────────────────────────────────────────── +# Step 4: TRT latency benchmarking +# ────────────────────────────────────────────────────────────────────────────── + + +def build_trt_engine(onnx_path, engine_path, fp8=False): + """Build a TRT engine from ONNX model.""" + cmd = [ + "trtexec", + f"--onnx={onnx_path}", + f"--saveEngine={engine_path}", + ] + if fp8: + cmd.extend(["--fp8", "--stronglyTyped"]) + else: + cmd.append("--fp16") + + print(f"\n Building TRT engine: {engine_path}") + print(f" Command: {' '.join(cmd)}") + result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) + if result.returncode != 0: + print(f" ERROR building engine:\n{result.stderr[-2000:]}") + return None + print(" Engine built successfully") + return engine_path + + +def benchmark_trt(engine_path, warmup=50, iterations=200): + """Benchmark TRT engine latency using trtexec.""" + cmd = [ + "trtexec", + f"--loadEngine={engine_path}", + f"--warmUp={warmup * 10}", # trtexec warmup is in ms + f"--iterations={iterations}", + ] + print(f"\n Benchmarking: {engine_path}") + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + if result.returncode != 0: + print(f" ERROR 
benchmarking:\n{result.stderr[-2000:]}") + return None + + import contextlib + + # Parse trtexec output for latency + output = result.stdout + result.stderr + latency = None + throughput = None + for line in output.split("\n"): + if "mean" in line.lower() and "ms" in line.lower() and "GPU" in line: + # Extract mean GPU latency + parts = line.split() + for i, part in enumerate(parts): + if part == "mean": + with contextlib.suppress(IndexError, ValueError): + latency = float(parts[i + 2]) + if "Throughput" in line: + parts = line.split() + for i, part in enumerate(parts): + if "qps" in part.lower() or "infer" in part.lower(): + with contextlib.suppress(IndexError, ValueError): + throughput = float(parts[i - 1]) + + # Fallback: parse the percentile lines + for line in output.split("\n"): + if "GPU Compute Time" in line and "mean" in line: + # e.g., "GPU Compute Time: min = 0.123 ms, max = ..., mean = 0.456 ms, ..." + parts = line.split("mean = ") + if len(parts) > 1: + with contextlib.suppress(ValueError, IndexError): + latency = float(parts[1].split(" ms")[0]) + if "Throughput:" in line: + with contextlib.suppress(ValueError, IndexError): + throughput = float(line.split(":")[1].strip().split(" ")[0]) + + print(f" Mean GPU latency: {latency:.3f} ms" if latency else " Could not parse latency") + print(f" Throughput: {throughput:.1f} qps" if throughput else " Could not parse throughput") + return {"latency_ms": latency, "throughput_qps": throughput} + + +# ────────────────────────────────────────────────────────────────────────────── +# Main +# ────────────────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser(description="ViT FP8 MHA Quantization (NVBug 6078291)") + parser.add_argument("--model_name", default="google/vit-base-patch16-224", help="HF ViT model") + parser.add_argument("--output_dir", default="./vit_mha_results", help="Output directory") + parser.add_argument("--eval_samples", type=int, 
default=None, help="Eval samples (None=full)") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation") + parser.add_argument("--skip_eval", action="store_true", help="Skip accuracy evaluation") + parser.add_argument("--skip_latency", action="store_true", help="Skip TRT latency benchmark") + parser.add_argument("--skip_onnx_ptq", action="store_true", help="Skip ONNX PTQ path") + parser.add_argument( + "--streaming_eval", + action="store_true", + help="Stream ImageNet val rather than using the local HF datasets cache (useful when the cache is read-only)", + ) + args = parser.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + os.makedirs(args.output_dir, exist_ok=True) + + # Paths + fp16_onnx = os.path.join(args.output_dir, "vit_fp16_baseline.onnx") + torch_fp8_onnx = os.path.join(args.output_dir, "vit_torch_fp8.onnx") + onnx_ptq_fp8 = os.path.join(args.output_dir, "vit_onnx_ptq_fp8.onnx") + + # Load model and processor + print(f"Loading model: {args.model_name}") + processor = ViTImageProcessor.from_pretrained(args.model_name) + base_model = ViTForImageClassification.from_pretrained(args.model_name).eval().to(device) + + # ── Step 1: Inspect model ──────────────────────────────────────────────── + inspect_model(base_model) + + # ── Step 2: Export FP16 baseline ───────────────────────────────────────── + export_fp16_baseline(base_model, fp16_onnx, device=device) + + # ── Step 3: Quantize with torch FP8 (MHA-aware) ───────────────────────── + # Clone for quantization so we keep the original for baseline eval + model_for_quant = copy.deepcopy(base_model) + quantized_model = quantize_vit_fp8(model_for_quant, processor, device=device) + + # ── Step 4: Export quantized model to ONNX ─────────────────────────────── + export_to_onnx(quantized_model, torch_fp8_onnx, device=device) + + # ── Step 5: Check QDQ nodes ────────────────────────────────────────────── + print("\n--- FP16 Baseline ONNX ---") + 
check_qdq_nodes(fp16_onnx) + print("\n--- Torch FP8 ONNX ---") + check_qdq_nodes(torch_fp8_onnx) + + # ── Step 6: ONNX PTQ ──────────────────────────────────────────────────── + if not args.skip_onnx_ptq: + quantize_onnx_ptq(fp16_onnx, onnx_ptq_fp8) + print("\n--- ONNX PTQ FP8 ---") + check_qdq_nodes(onnx_ptq_fp8) + + # ── Step 7: Evaluate accuracy ──────────────────────────────────────────── + results = {} + if not args.skip_eval: + print(f"\n{'=' * 80}") + print("ACCURACY EVALUATION") + print(f"{'=' * 80}") + + val_loader = load_imagenet_val( + processor, + num_examples=args.eval_samples, + batch_size=args.batch_size, + streaming=args.streaming_eval, + ) + + print("\n--- FP16 Baseline ---") + t0 = time.time() + top1_base, top5_base = evaluate_accuracy( + base_model, val_loader, num_examples=args.eval_samples, device=device + ) + eval_time_base = time.time() - t0 + results["fp16_baseline"] = {"top1": top1_base, "top5": top5_base} + print(f" Top-1: {top1_base:.2f}%, Top-5: {top5_base:.2f}% (took {eval_time_base:.1f}s)") + + print("\n--- Torch FP8 (with MHA) ---") + t0 = time.time() + top1_fp8, top5_fp8 = evaluate_accuracy( + quantized_model, val_loader, num_examples=args.eval_samples, device=device + ) + eval_time_fp8 = time.time() - t0 + results["torch_fp8"] = {"top1": top1_fp8, "top5": top5_fp8} + print(f" Top-1: {top1_fp8:.2f}%, Top-5: {top5_fp8:.2f}% (took {eval_time_fp8:.1f}s)") + + # ── Step 8: TRT latency benchmark ──────────────────────────────────────── + if not args.skip_latency: + print(f"\n{'=' * 80}") + print("TRT LATENCY BENCHMARKING") + print(f"{'=' * 80}") + + # Build engines + fp16_engine = os.path.join(args.output_dir, "vit_fp16.engine") + torch_fp8_engine = os.path.join(args.output_dir, "vit_torch_fp8.engine") + + build_trt_engine(fp16_onnx, fp16_engine, fp8=False) + build_trt_engine(torch_fp8_onnx, torch_fp8_engine, fp8=True) + + latency_results = {} + + if os.path.exists(fp16_engine): + print("\n--- FP16 Baseline Latency ---") + 
latency_results["fp16_baseline"] = benchmark_trt(fp16_engine) + + if os.path.exists(torch_fp8_engine): + print("\n--- Torch FP8 Latency ---") + latency_results["torch_fp8"] = benchmark_trt(torch_fp8_engine) + + if not args.skip_onnx_ptq and os.path.exists(onnx_ptq_fp8): + onnx_ptq_engine = os.path.join(args.output_dir, "vit_onnx_ptq_fp8.engine") + build_trt_engine(onnx_ptq_fp8, onnx_ptq_engine, fp8=True) + if os.path.exists(onnx_ptq_engine): + print("\n--- ONNX PTQ FP8 Latency ---") + latency_results["onnx_ptq_fp8"] = benchmark_trt(onnx_ptq_engine) + + # ── Final Summary ──────────────────────────────────────────────────────── + print(f"\n{'=' * 80}") + print("FINAL COMPARISON") + print(f"{'=' * 80}") + print( + f"\n{'Model':<25} {'Top-1 %':<12} {'Top-5 %':<12} {'Latency (ms)':<15} {'Throughput':<12}" + ) + print("-" * 76) + + for model_name in ["fp16_baseline", "torch_fp8", "onnx_ptq_fp8"]: + acc = results.get(model_name, {}) + lat = latency_results.get(model_name, {}) if not args.skip_latency else {} + top1 = f"{acc['top1']:.2f}" if acc else "N/A" + top5 = f"{acc['top5']:.2f}" if acc else "N/A" + latency = f"{lat['latency_ms']:.3f}" if lat and lat.get("latency_ms") else "N/A" + throughput = f"{lat['throughput_qps']:.1f}" if lat and lat.get("throughput_qps") else "N/A" + print(f" {model_name:<23} {top1:<12} {top5:<12} {latency:<15} {throughput:<12}") + + print(f"\nResults saved to: {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py index dcae618dd0..8d402b39bd 100644 --- a/modelopt/onnx/export/fp8_exporter.py +++ b/modelopt/onnx/export/fp8_exporter.py @@ -17,6 +17,7 @@ import time +import numpy as np import onnx import onnx_graphsurgeon as gs import torch @@ -26,6 +27,11 @@ from .base_exporter import ONNXQuantExporter +# FP8 E4M3 max representable magnitude; softmax output in [0, 1] saturates exactly at 1.0 +# when using 1/448 as the Q scale. 
+_FP8_E4M3_MAX = 448.0 +_FP8_E4M3_SOFTMAX_SCALE = 1.0 / _FP8_E4M3_MAX + class FP8QuantExporter(ONNXQuantExporter): """Exporter for FP8 quantization.""" @@ -62,6 +68,8 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: # Fold constants is required since the scale is not constant yet. graph.cleanup().toposort().fold_constants().cleanup() + n_t_folded = 0 + for node in graph.nodes: if node.op == "TRT_FP8QuantizeLinear": # Should not remove input QDQ (only process weight quantization) @@ -74,9 +82,46 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: torch_scale = torch.from_numpy(scale.values) quantizer_name = scale.name.rsplit("/", 1)[0] dq_op = node.outputs[0].outputs[0] - assert dq_op.op == "TRT_FP8DequantizeLinear", ( - f"QDQ does not occur in pairs. You reached {dq_op.op}" - ) + if dq_op.op != "TRT_FP8DequantizeLinear": + raise RuntimeError( + f"QDQ does not occur in pairs. You reached {dq_op.op}" + ) + + # Pre-transpose constant weights if DQ feeds ``Transpose → MatMul`` (or + # ``Cast → Transpose → MatMul`` after fp16 conversion) so TRT sees DQ→MatMul. + # Control flow: scan candidates; a Cast-wrapped candidate is accepted only if it + # leads to a Transpose; a bare Transpose whose all consumers are MatMul wins and + # breaks the loop. Any other shape defaults `cast_to_remove` back to None and + # continues scanning. + transpose_to_remove = None + cast_to_remove = None + for candidate in list(dq_op.outputs[0].outputs): + if candidate.op == "Cast": + cast_to_remove = candidate + candidate = next( + (c for c in candidate.outputs[0].outputs if c.op == "Transpose"), + None, + ) + if candidate is None: + cast_to_remove = None + continue + if candidate.op != "Transpose": + cast_to_remove = None + continue + t_consumers = list(candidate.outputs[0].outputs) + # Only fold the transpose when every downstream consumer is MatMul; otherwise + # non-MatMul consumers would observe the un-transposed weights. 
+ if t_consumers and all(c.op == "MatMul" for c in t_consumers): + perm = candidate.attrs.get("perm", None) + torch_weights = ( + torch_weights.permute(*perm).contiguous() + if perm is not None + else torch_weights.T.contiguous() + ) + transpose_to_remove = candidate + else: + cast_to_remove = None + break # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8. numpy_weights = ( @@ -94,9 +139,23 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: dq_op.inputs[0] = onnx_weights_fp8 dq_op.op = "DequantizeLinear" dq_op.outputs[0].dtype = dq_op.inputs[1].dtype + dq_op.outputs[0].shape = list(numpy_weights.shape) + + if transpose_to_remove is not None: + t_out = transpose_to_remove.outputs[0] + for consumer in list(t_out.outputs): + for i, inp in enumerate(consumer.inputs): + if inp is t_out: + consumer.inputs[i] = dq_op.outputs[0] + transpose_to_remove.outputs.clear() + if cast_to_remove is not None: + cast_to_remove.outputs.clear() + n_t_folded += 1 graph.cleanup().toposort() end_time = time.time() + if n_t_folded > 0: + logger.info(f"Folded {n_t_folded} weight Transpose nodes during weight compression") print(f"fp8 qdq replaced with only dq completed in {end_time - start_time}s.") return gs.export_onnx(graph) @@ -121,7 +180,6 @@ def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int: Returns: Number of Conv weight DQ nodes inserted. 
""" - fp8_max = 448.0 count = 0 for node in list(graph.nodes): @@ -142,7 +200,7 @@ def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int: amax = torch_weights.abs().max().float() if amax == 0: continue - scale_val = (amax / fp8_max).item() + scale_val = (amax / _FP8_E4M3_MAX).item() # Quantize weights to FP8 (WAR: numpy doesn't support fp8) fp8_data = (torch_weights / scale_val).to(torch.float8_e4m3fn).view(torch.uint8).numpy() @@ -155,8 +213,6 @@ def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int: ) # Scale in FP16 — DQ output type matches scale dtype, must match activation type - import numpy as np - scale_constant = gs.Constant( node.name + "/weight_quantizer/scale", np.array(scale_val, dtype=np.float16), @@ -175,13 +231,219 @@ def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int: return count + @staticmethod + def _move_mul_before_qdq(graph: gs.Graph) -> int: + """Move attention-scaling Mul(const) from after DQ to before Q for TRT MatMul fusion. + + Handles both ``DQ → Mul → MatMul`` and ``DQ → Transpose → Mul → MatMul`` (K path). 
+ """ + count = 0 + for mul_node in list(graph.nodes): + if mul_node.op != "Mul": + continue + + const_input = next( + (i for i in mul_node.inputs if isinstance(i, gs.Constant) and i.values.size == 1), + None, + ) + tensor_input = next( + (i for i in mul_node.inputs if not isinstance(i, gs.Constant)), None + ) + if const_input is None or tensor_input is None: + continue + if not (isinstance(tensor_input, gs.Variable) and len(tensor_input.inputs) == 1): + continue + + producer = tensor_input.inputs[0] + transpose_node = producer if producer.op == "Transpose" else None + dq_node = producer if producer.op == "DequantizeLinear" else None + if transpose_node is not None: + t_input = transpose_node.inputs[0] + if ( + isinstance(t_input, gs.Variable) + and len(t_input.inputs) == 1 + and t_input.inputs[0].op == "DequantizeLinear" + ): + dq_node = t_input.inputs[0] + if dq_node is None: + continue + + q_output = dq_node.inputs[0] + if ( + not isinstance(q_output, gs.Variable) + or len(q_output.inputs) != 1 + or q_output.inputs[0].op != "QuantizeLinear" + ): + continue + q_node = q_output.inputs[0] + q_input = q_node.inputs[0] + if not isinstance(q_input, gs.Variable): + continue + + mul_output = mul_node.outputs[0] + mul_consumers = list(mul_output.outputs) + # Require every consumer to be MatMul: rewiring all consumers to bypass the Mul + # would silently drop the scale for any non-MatMul branch. 
+ if not mul_consumers or not all(c.op == "MatMul" for c in mul_consumers): + continue + + new_mul_output = gs.Variable( + q_input.name + "_scaled", dtype=q_input.dtype, shape=q_input.shape + ) + graph.nodes.append( + gs.Node( + op="Mul", + name=mul_node.name + "_moved", + inputs=[q_input, const_input], + outputs=[new_mul_output], + ) + ) + q_node.inputs[0] = new_mul_output + + replacement = ( + transpose_node.outputs[0] if transpose_node is not None else dq_node.outputs[0] + ) + for consumer in mul_consumers: + for i, inp in enumerate(consumer.inputs): + if inp is mul_output: + consumer.inputs[i] = replacement + mul_node.outputs.clear() + count += 1 + + graph.cleanup().toposort() + return count + + @staticmethod + def _move_transpose_before_qdq(graph: gs.Graph) -> int: + """Move Transpose from ``DQ → Transpose → MatMul`` to ``Transpose → Q → DQ → MatMul`` (K path).""" + count = 0 + for transpose_node in list(graph.nodes): + if transpose_node.op != "Transpose": + continue + + t_input = transpose_node.inputs[0] + if ( + not isinstance(t_input, gs.Variable) + or len(t_input.inputs) != 1 + or t_input.inputs[0].op != "DequantizeLinear" + ): + continue + dq_node = t_input.inputs[0] + + dq_input = dq_node.inputs[0] + if ( + not isinstance(dq_input, gs.Variable) + or len(dq_input.inputs) != 1 + or dq_input.inputs[0].op != "QuantizeLinear" + ): + continue + q_node = dq_input.inputs[0] + q_input = q_node.inputs[0] + if not isinstance(q_input, gs.Variable): + continue + + t_output = transpose_node.outputs[0] + t_consumers = list(t_output.outputs) + # Require every consumer to be MatMul: rewiring to dq_node.outputs[0] would drop + # the transpose for any non-MatMul branch, producing a wrong-shape tensor. 
+ if not t_consumers or not all(c.op == "MatMul" for c in t_consumers): + continue + + new_t_output = gs.Variable(q_input.name + "_transposed", dtype=q_input.dtype) + graph.nodes.append( + gs.Node( + op="Transpose", + name=transpose_node.name + "_moved", + inputs=[q_input], + outputs=[new_t_output], + attrs=transpose_node.attrs, + ) + ) + q_node.inputs[0] = new_t_output + + for consumer in t_consumers: + for i, inp in enumerate(consumer.inputs): + if inp is t_output: + consumer.inputs[i] = dq_node.outputs[0] + transpose_node.outputs.clear() + count += 1 + + graph.cleanup().toposort() + return count + + @staticmethod + def _insert_qdq_after_softmax(graph: gs.Graph) -> int: + """Insert FP8 Q→DQ on Softmax outputs feeding MatMul (required by TRT MHA fusion). + + Torch export does not quantize the softmax output; ``_FP8_E4M3_SOFTMAX_SCALE`` (1/448) + saturates exactly at 1.0 (softmax range is [0, 1]) while covering the full FP8 E4M3 + representable range. Only applied when every Softmax consumer is a MatMul so we do + not insert quantization error on unrelated branches. + """ + count = 0 + for softmax_node in list(graph.nodes): + if softmax_node.op != "Softmax": + continue + softmax_output = softmax_node.outputs[0] + consumers = list(softmax_output.outputs) + if not consumers or not all(c.op == "MatMul" for c in consumers): + continue + if any(c.op == "QuantizeLinear" for c in consumers): + continue + + # Match scale dtype to the graph's current float dtype so TRT's strongly + # typed mode sees consistent Q/DQ types with the surrounding compute. 
+ scale_dtype = softmax_output.dtype if softmax_output.dtype is not None else np.float32 + scale_val = np.array(_FP8_E4M3_SOFTMAX_SCALE, dtype=scale_dtype) + scale_constant = gs.Constant(softmax_node.name + "/softmax_q_scale", scale_val) + dq_scale_constant = gs.Constant( + softmax_node.name + "/softmax_dq_scale", scale_val.copy() + ) + + zp_tensor = onnx.TensorProto() + zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN + zp_tensor.dims.extend([1]) + zp_tensor.raw_data = b"\x00" + zp_constant = gs.Constant( + softmax_node.name + "/softmax_q_zero_point", LazyValues(zp_tensor) + ) + + q_output = gs.Variable(softmax_node.name + "/q_output") + dq_output = gs.Variable(softmax_node.name + "/dq_output", dtype=softmax_output.dtype) + q_node = gs.Node( + op="QuantizeLinear", + name=softmax_node.name + "/QuantizeLinear", + inputs=[softmax_output, scale_constant, zp_constant], + outputs=[q_output], + attrs={"saturate": 1}, + ) + dq_node = gs.Node( + op="DequantizeLinear", + name=softmax_node.name + "/DequantizeLinear", + inputs=[q_output, dq_scale_constant], + outputs=[dq_output], + ) + graph.nodes.extend([q_node, dq_node]) + + for consumer in consumers: + if consumer is q_node: + continue + for i, inp in enumerate(consumer.inputs): + if inp is softmax_output: + consumer.inputs[i] = dq_output + count += 1 + + graph.cleanup().toposort() + return count + @staticmethod def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Post-processes the ONNX model for FP8 quantization. - Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear and + Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear, adds FP8 weight DQ for Conv layers whose weight quantizers were disabled during - TorchScript export. + TorchScript export, and rewrites attention scaling / K-transpose / softmax-output + patterns so TRT can fuse DQ into the attention MatMul kernels. Args: onnx_model: The ONNX model containing TRT_FP8 quantization nodes. 
@@ -223,5 +485,15 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: if count > 0: logger.info(f"Inserted FP8 weight DequantizeLinear for {count} Conv nodes") + # Attention-aware rewrites so TRT can fuse DQ into the attention MatMuls. + n_mul = FP8QuantExporter._move_mul_before_qdq(graph) + n_t = FP8QuantExporter._move_transpose_before_qdq(graph) + n_sm = FP8QuantExporter._insert_qdq_after_softmax(graph) + if n_mul or n_t or n_sm: + logger.info( + f"Attention QDQ rewrites: moved {n_mul} Mul, {n_t} Transpose; " + f"inserted QDQ on {n_sm} Softmax outputs" + ) + graph.cleanup().toposort() return gs.export_onnx(graph) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index ac93bc2a26..6b38d5bbdb 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -1415,6 +1415,74 @@ def _bypass_cast_node(model: onnx.ModelProto, node: onnx.NodeProto) -> None: consumer.input[i] = input_tensor +_DQ_OPS = {"DequantizeLinear", "TRT_FP8DequantizeLinear"} +_Q_OPS = {"QuantizeLinear", "TRT_FP8QuantizeLinear"} + + +def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None: + """Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16. + + Warns if any non-zero scale overflows to inf or underflows to 0 when cast to FP16. 
+ """ + if scale_init.data_type != onnx.TensorProto.FLOAT: + return + scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32) + if not scale_data.size: + scale_data = np.array(scale_init.float_data, dtype=np.float32) + fp16_data = scale_data.astype(np.float16) + if np.any(np.isinf(fp16_data)) or ( + np.any(fp16_data == 0) and np.any(scale_data != 0) + ): + logger.warning( + f"Q/DQ scale '{scale_init.name}' overflows or underflows when cast to FP16" + ) + scale_init.data_type = onnx.TensorProto.FLOAT16 + scale_init.raw_data = fp16_data.tobytes() + del scale_init.float_data[:] + + +def fold_q_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Remove ``Cast(FP16→FP32) → Q`` patterns inserted by ``convert_float_to_float16``. + + The Q scale is rewritten to FP16 so Q consumes the FP16 graph directly. Skipped for + opsets below ``BASE_MIN_OPSET`` since FP16 Q scales require opset >= 19. + """ + if get_opset_version(onnx_model) < BASE_MIN_OPSET: + logger.debug( + f"Skipping fold_q_fp16_to_fp32_casts: opset < {BASE_MIN_OPSET} (FP16 Q scale unsupported)" + ) + return onnx_model + + consumer_map: dict[str, list[onnx.NodeProto]] = {} + for node in onnx_model.graph.node: + for inp in node.input: + consumer_map.setdefault(inp, []).append(node) + initializers = {init.name: init for init in onnx_model.graph.initializer} + + to_remove = [] + for node in onnx_model.graph.node: + if node.op_type != "Cast": + continue + cast_to = next((a.i for a in node.attribute if a.name == "to"), None) + if cast_to != onnx.TensorProto.FLOAT: + continue + consumers = consumer_map.get(node.output[0], []) + if not consumers or not all(c.op_type in _Q_OPS for c in consumers): + continue + + for q_node in consumers: + if len(q_node.input) >= 2 and q_node.input[1] in initializers: + _scale_fp32_to_fp16(initializers[q_node.input[1]]) + + _bypass_cast_node(onnx_model, node) + to_remove.append(node) + + logger.debug(f"Folded {len(to_remove)} Cast(FP16->FP32) -> Q patterns") 
+ for node in to_remove: + onnx_model.graph.node.remove(node) + return onnx_model + + def _is_foldable_constant_cast_pattern(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: """Check if a Constant -> Cast pattern can be folded.""" assert node.op_type == "Cast" @@ -1523,7 +1591,12 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: Returns: The ONNX model with Cast nodes removed and DQ outputs set to FP16. """ - import numpy as np + if get_opset_version(onnx_model) < BASE_MIN_OPSET: + logger.debug( + f"Skipping fold_dq_fp32_to_fp16_casts: opset < {BASE_MIN_OPSET} " + "(FP16 DQ scale unsupported)" + ) + return onnx_model dq_ops = {"DequantizeLinear", "TRT_FP8DequantizeLinear"} @@ -1623,6 +1696,13 @@ def fold_qdq_scale_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.Model Returns: The ONNX model with redundant scale-path casts removed. """ + if get_opset_version(onnx_model) < BASE_MIN_OPSET: + logger.debug( + f"Skipping fold_qdq_scale_fp16_to_fp32_casts: opset < {BASE_MIN_OPSET} " + "(FP16 Q/DQ scale unsupported)" + ) + return onnx_model + qdq_ops = { "QuantizeLinear", "DequantizeLinear", diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 9ec110b788..01fb754bba 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -48,6 +48,7 @@ change_casts_to_fp16, check_model_uses_external_data, fold_dq_fp32_to_fp16_casts, + fold_q_fp16_to_fp32_casts, fold_qdq_scale_fp16_to_fp32_casts, get_input_names, get_input_shapes, @@ -663,6 +664,11 @@ def get_onnx_bytes_and_metadata( onnx_opt_graph = remove_redundant_casts(onnx_opt_graph) + # Remove Cast nodes around Q/DQ for optimal TRT fusion + if is_fp8_quantized(model): + onnx_opt_graph = fold_q_fp16_to_fp32_casts(onnx_opt_graph) + onnx_opt_graph = fold_dq_fp32_to_fp16_casts(onnx_opt_graph) + # TensorRT expects all scales to be postive onnx_opt_graph = 
replace_zero_scale_with_smallest_nonzero(onnx_opt_graph) diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index 05efe48842..7b42dce578 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -216,56 +216,36 @@ def _fp8_quantize( g: "GraphContext", inputs: torch.Value, scale_inv: float, - trt_high_precision_dtype: str, ): """Helper Function for Quantization.""" + # Emit the scale in the native input dtype so no Cast is inserted between the + # graph and Q/DQ (Cast nodes block TRT from fusing DQ into the MatMul kernel). output_shape = sym_help._get_tensor_sizes(inputs) - - # TRT StronglyType only supports FP16 QDQs - # custom ops, so cast the input if needed. - input_type = inputs.type().scalarType() - assert trt_high_precision_dtype in (input_type, "Float"), ( - "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." - ) - if trt_high_precision_dtype != input_type: - inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[trt_high_precision_dtype]) - scale = g.op( "Constant", - value_t=torch.tensor(scale_inv).to(torch_dtype_map[trt_high_precision_dtype]), + value_t=torch.tensor(scale_inv).to(torch_dtype_map[inputs.type().scalarType()]), ) - q_op = g.op("trt::TRT_FP8QuantizeLinear", inputs, scale).setType( + return g.op("trt::TRT_FP8QuantizeLinear", inputs, scale).setType( inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) ) - return q_op def _fp8_dequantize( g: "GraphContext", inputs: torch.Value, scale_inv: float, - trt_high_precision_dtype: str, otype: str | None = None, ): """Helper Function for Dequantization.""" output_shape = sym_help._get_tensor_sizes(inputs) - assert trt_high_precision_dtype in (otype, "Float"), ( - "TRT StronglyType requires both weights and amax to be in the BF16/FP16, or the QDQ in Float." 
- ) scale = g.op( "Constant", value_t=torch.tensor(scale_inv, dtype=torch_dtype_map[otype]), # type: ignore[index] ) - out = g.op("trt::TRT_FP8DequantizeLinear", inputs, scale).setType( - inputs.type().with_dtype(torch_dtype_map[trt_high_precision_dtype]).with_sizes(output_shape) + return g.op("trt::TRT_FP8DequantizeLinear", inputs, scale).setType( + inputs.type().with_dtype(torch_dtype_map[otype]).with_sizes(output_shape) # type: ignore[index] ) - # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT - # custom ops, so cast the output if needed. - if trt_high_precision_dtype != otype: - out = g.op("Cast", out, to_i=onnx_dtype_map[otype]) # type: ignore[index] - return out - def export_fp8( g: "GraphContext", @@ -273,14 +253,17 @@ def export_fp8( amax: float, trt_high_precision_dtype: str | None, ): - """Export quantized model to FP8 ONNX.""" + """Export quantized model to FP8 ONNX. + + ``trt_high_precision_dtype`` is accepted for API compatibility but unused: Q/DQ now + emit scales in the native input dtype, so no intermediate Cast is required. 
+ """ + del trt_high_precision_dtype scale = 1.0 if amax is None else 448.0 / float(amax) otype = inputs.type().scalarType() - if trt_high_precision_dtype is None: - trt_high_precision_dtype = otype - q_tensor = _fp8_quantize(g, inputs, 1.0 / scale, trt_high_precision_dtype) - return _fp8_dequantize(g, q_tensor, 1.0 / scale, trt_high_precision_dtype, otype) + q_tensor = _fp8_quantize(g, inputs, 1.0 / scale) + return _fp8_dequantize(g, q_tensor, 1.0 / scale, otype) def scaled_dot_product_attention( diff --git a/modelopt/torch/quantization/nn/__init__.py b/modelopt/torch/quantization/nn/__init__.py index ca7082eb1c..af9490c831 100644 --- a/modelopt/torch/quantization/nn/__init__.py +++ b/modelopt/torch/quantization/nn/__init__.py @@ -19,6 +19,7 @@ from .modules.quant_batchnorm import * from .modules.quant_conv import * from .modules.quant_instancenorm import * +from .modules.quant_layernorm import * from .modules.quant_linear import * from .modules.quant_module import * from .modules.quant_pooling import * diff --git a/modelopt/torch/quantization/nn/modules/quant_layernorm.py b/modelopt/torch/quantization/nn/modules/quant_layernorm.py new file mode 100644 index 0000000000..bea9d892eb --- /dev/null +++ b/modelopt/torch/quantization/nn/modules/quant_layernorm.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Registers ``torch.nn.LayerNorm`` with ``QuantInputBase`` so its output quantizer is +honored during quantization. Required for FP8 attention fusion where a single LayerNorm +output QDQ is shared across all downstream Q/K/V/FC consumers (instead of repeating it +on each input), which enables TRT to fuse DQ into the attention MatMul kernels.""" + +import torch.nn as nn + +from .quant_module import QuantInputBase, QuantModuleRegistry + +QuantModuleRegistry.register({nn.LayerNorm: "nn.LayerNorm"})(QuantInputBase) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 59bcd215bb..596fe8208b 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -286,9 +286,24 @@ def register_hf_attentions_on_the_fly(model): attention_cls = set() registered_attn_module = False + + # Skip attention wrappers that contain a nested "Attention" child on this specific + # instance (e.g. ViTAttention wraps ViTSelfAttention). Patching both would + # double-quantize eager_attention_forward. Checked per-instance (not by class) so a + # class reused as both wrapper and leaf is not dropped everywhere. In a 3-level + # hierarchy (Outer → Middle → Inner), both Outer and Middle are treated as wrappers + # and only Inner is registered. + def _wraps_nested_attention(module): + return any( + child is not module and type(child).__name__.endswith("Attention") + for _, child in module.named_modules() + ) + for name, module in model.named_modules(): # Only register attention classes that are from Huggingface transformers if type(module).__name__.endswith("Attention"): + if _wraps_nested_attention(module): + continue attention_type = _QuantAttention.get_attn_type(module) # Add modules to be registered only if they arent already registered if (