538 changes: 538 additions & 0 deletions examples/torch_onnx/vit_mha_quantization.py
Collaborator:
Can you update the changelog for 0.45 to mention the details of this new feature?

Large diffs are not rendered by default.

257 changes: 254 additions & 3 deletions modelopt/onnx/export/fp8_exporter.py
@@ -62,6 +62,8 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
# Fold constants is required since the scale is not constant yet.
graph.cleanup().toposort().fold_constants().cleanup()

n_t_folded = 0

for node in graph.nodes:
if node.op == "TRT_FP8QuantizeLinear":
# Should not remove input QDQ (only process weight quantization)
@@ -78,6 +80,33 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
f"QDQ does not occur in pairs. You reached {dq_op.op}"
Collaborator:
Bug (existing but expanded): this assert is stripped when Python runs with the -O flag. Per codebase conventions, use raise ValueError(...) or RuntimeError(...) for runtime validation instead:

if dq_op.op != "TRT_FP8DequantizeLinear":
    raise RuntimeError(f"QDQ does not occur in pairs. You reached {dq_op.op}")

)

# Pre-transpose constant weights if DQ feeds ``Transpose → MatMul`` (or
# ``Cast → Transpose → MatMul`` after fp16 conversion) so TRT sees DQ→MatMul.
transpose_to_remove = None
cast_to_remove = None
for candidate in list(dq_op.outputs[0].outputs):
if candidate.op == "Cast":
cast_to_remove = candidate
candidate = next(
(c for c in candidate.outputs[0].outputs if c.op == "Transpose"),
None,
)
if candidate is None:
cast_to_remove = None
continue
if candidate.op != "Transpose":
cast_to_remove = None
continue
if any(c.op == "MatMul" for c in candidate.outputs[0].outputs):
Collaborator:
Potential issue: when dq_op.outputs[0] has multiple consumers, the mix of continue and break in this scan is easy to misread:

for candidate in list(dq_op.outputs[0].outputs):
    if candidate.op == "Cast":
        ...
        if candidate is None:
            cast_to_remove = None
            continue  # <-- moves on to the next candidate, good
    if candidate.op != "Transpose":
        cast_to_remove = None
        continue  # <-- also moves on, good
    ...
    break  # <-- but this break can exit before a match is confirmed

On closer inspection the logic is sound: cast_to_remove is reset and the loop continues whenever a candidate is not (or does not lead to) a Transpose, so the remaining consumers do get checked. The fragile case is a Transpose candidate with no MatMul consumer, which breaks out of the loop without resetting state. The code works today, but a comment explaining this control flow would help future readers.

perm = candidate.attrs.get("perm", None)
torch_weights = (
torch_weights.permute(*perm).contiguous()
if perm is not None
else torch_weights.T.contiguous()
)
transpose_to_remove = candidate
break
Comment on lines +100 to +108
Contributor:
⚠️ Potential issue | 🟠 Major

Restrict these rewrites to pure MatMul fanout.

Each transform fires when it sees any MatMul consumer, then rewires all consumers and clears the original outputs. If one of these tensors also feeds a non-MatMul branch, that branch will now observe the untransposed / unscaled / quantized value instead of the original one. Please either require all(...) consumers to match the intended pattern or only rewrite the specific MatMul edges.

Also applies to: 128-137, 196-220, 255-275, 334-339

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 100 - 108, The code
currently uses any(c.op == "MatMul" for c in candidate.outputs[0].outputs) and
then rewires/clears all consumers which breaks non-MatMul branches; change the
logic to require all(c.op == "MatMul" for c in candidate.outputs[0].outputs)
before performing the global rewrite OR, preferably, only rewrite the specific
MatMul edges: iterate candidate.outputs[0].outputs, for each consumer c with
c.op == "MatMul" rewire that consumer's input to use the
transposed/scaled/quantized tensor and leave other consumers untouched, and do
not clear original outputs (update transpose_to_remove only when all downstream
edges have been safely redirected). Apply the same fix pattern to the other
rewrite sites that manipulate torch_weights, perm, transpose_to_remove, and
similar MatMul-aware transforms.
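The per-edge rewiring the prompt describes can be sketched independently of graphsurgeon. Here tensors are plain dicts, and the `op`/`inputs`/`consumers` field names are illustrative, not the gs API:

```python
def rewire_matmul_edges(old_tensor: dict, new_tensor: dict) -> int:
    """Redirect only the MatMul consumers of old_tensor onto new_tensor.

    Non-MatMul branches keep reading the original (untransposed) value, which
    is the safety property the review comment asks for. Returns the number of
    edges moved.
    """
    moved = 0
    for consumer in list(old_tensor["consumers"]):
        if consumer["op"] != "MatMul":
            continue  # leave other branches untouched
        consumer["inputs"] = [
            new_tensor if t is old_tensor else t for t in consumer["inputs"]
        ]
        old_tensor["consumers"].remove(consumer)
        new_tensor["consumers"].append(consumer)
        moved += 1
    return moved
```

Only the MatMul edge is redirected; a sibling Relu (or any other consumer) continues to see the original tensor, so mixed fanout stays correct without requiring an all-or-nothing rewrite.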


# Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
numpy_weights = (
(torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
@@ -94,20 +123,232 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
dq_op.inputs[0] = onnx_weights_fp8
dq_op.op = "DequantizeLinear"
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
dq_op.outputs[0].shape = list(numpy_weights.shape)

if transpose_to_remove is not None:
t_out = transpose_to_remove.outputs[0]
for consumer in list(t_out.outputs):
for i, inp in enumerate(consumer.inputs):
if inp is t_out:
consumer.inputs[i] = dq_op.outputs[0]
transpose_to_remove.outputs.clear()
if cast_to_remove is not None:
cast_to_remove.outputs.clear()
n_t_folded += 1

graph.cleanup().toposort()
end_time = time.time()
if n_t_folded > 0:
logger.info(f"Folded {n_t_folded} weight Transpose nodes during weight compression")
print(f"FP8 QDQ replaced with DQ-only in {end_time - start_time:.2f}s.")

return gs.export_onnx(graph)

@staticmethod
def _move_mul_before_qdq(graph: gs.Graph) -> int:
"""Move attention-scaling Mul(const) from after DQ to before Q for TRT MatMul fusion.

Handles both ``DQ → Mul → MatMul`` and ``DQ → Transpose → Mul → MatMul`` (K path).
"""
count = 0
for mul_node in list(graph.nodes):
if mul_node.op != "Mul":
continue

const_input = next(
(i for i in mul_node.inputs if isinstance(i, gs.Constant) and i.values.size == 1),
None,
)
tensor_input = next(
(i for i in mul_node.inputs if not isinstance(i, gs.Constant)), None
)
if const_input is None or tensor_input is None:
continue
if not (isinstance(tensor_input, gs.Variable) and len(tensor_input.inputs) == 1):
continue

producer = tensor_input.inputs[0]
transpose_node = producer if producer.op == "Transpose" else None
dq_node = producer if producer.op == "DequantizeLinear" else None
if transpose_node is not None:
t_input = transpose_node.inputs[0]
if (
isinstance(t_input, gs.Variable)
and len(t_input.inputs) == 1
and t_input.inputs[0].op == "DequantizeLinear"
):
dq_node = t_input.inputs[0]
if dq_node is None:
continue

q_output = dq_node.inputs[0]
if (
not isinstance(q_output, gs.Variable)
or len(q_output.inputs) != 1
or q_output.inputs[0].op != "QuantizeLinear"
):
continue
q_node = q_output.inputs[0]
q_input = q_node.inputs[0]
if not isinstance(q_input, gs.Variable):
continue

mul_output = mul_node.outputs[0]
if not any(c.op == "MatMul" for c in mul_output.outputs):
continue

new_mul_output = gs.Variable(
q_input.name + "_scaled", dtype=q_input.dtype, shape=q_input.shape
)
graph.nodes.append(
gs.Node(
op="Mul",
name=mul_node.name + "_moved",
inputs=[q_input, const_input],
outputs=[new_mul_output],
)
)
q_node.inputs[0] = new_mul_output

replacement = (
transpose_node.outputs[0] if transpose_node is not None else dq_node.outputs[0]
)
for consumer in list(mul_output.outputs):
for i, inp in enumerate(consumer.inputs):
if inp is mul_output:
consumer.inputs[i] = replacement
mul_node.outputs.clear()
count += 1

graph.cleanup().toposort()
return count

@staticmethod
def _move_transpose_before_qdq(graph: gs.Graph) -> int:
"""Move Transpose from ``DQ → Transpose → MatMul`` to ``Transpose → Q → DQ → MatMul`` (K path)."""
count = 0
for transpose_node in list(graph.nodes):
if transpose_node.op != "Transpose":
continue

t_input = transpose_node.inputs[0]
if (
not isinstance(t_input, gs.Variable)
or len(t_input.inputs) != 1
or t_input.inputs[0].op != "DequantizeLinear"
):
continue
dq_node = t_input.inputs[0]

dq_input = dq_node.inputs[0]
if (
not isinstance(dq_input, gs.Variable)
or len(dq_input.inputs) != 1
or dq_input.inputs[0].op != "QuantizeLinear"
):
continue
q_node = dq_input.inputs[0]
q_input = q_node.inputs[0]
if not isinstance(q_input, gs.Variable):
continue

t_output = transpose_node.outputs[0]
if not any(c.op == "MatMul" for c in t_output.outputs):
continue

new_t_output = gs.Variable(q_input.name + "_transposed", dtype=q_input.dtype)
graph.nodes.append(
gs.Node(
op="Transpose",
name=transpose_node.name + "_moved",
inputs=[q_input],
Collaborator:
The hardcoded 1.0/448.0 is correct (FP8 E4M3 max = 448, softmax output ∈ [0,1]), and the docstring explains this well. However, consider defining this as a named constant (e.g., _FP8_E4M3_SOFTMAX_SCALE = 1.0 / 448.0) since the same value is also used in export_onnx.py (export_fp8_mha uses 1.0 for softmax amax which becomes 448.0/1.0 then 1.0/448.0). This would make the connection explicit.

outputs=[new_t_output],
attrs=transpose_node.attrs,
)
)
q_node.inputs[0] = new_t_output

for consumer in list(t_output.outputs):
for i, inp in enumerate(consumer.inputs):
if inp is t_output:
consumer.inputs[i] = dq_node.outputs[0]
transpose_node.outputs.clear()
count += 1

graph.cleanup().toposort()
return count

@staticmethod
def _insert_qdq_after_softmax(graph: gs.Graph) -> int:
"""Insert FP8 Q→DQ on Softmax outputs feeding MatMul (required by TRT MHA fusion).

Torch export does not quantize softmax output; scale=1/448 saturates exactly at 1.0
(softmax range is [0, 1]) while covering the full FP8 E4M3 representable range.
"""
import numpy as np

count = 0
for softmax_node in list(graph.nodes):
if softmax_node.op != "Softmax":
continue
softmax_output = softmax_node.outputs[0]
if not any(c.op == "MatMul" for c in softmax_output.outputs):
continue
if any(c.op == "QuantizeLinear" for c in softmax_output.outputs):
continue

# Match scale dtype to the graph's current float dtype so TRT stronglyTyped
# sees consistent Q/DQ types with the surrounding compute.
scale_dtype = softmax_output.dtype if softmax_output.dtype is not None else np.float32
scale_val = np.array(1.0 / 448.0, dtype=scale_dtype)
scale_constant = gs.Constant(softmax_node.name + "/softmax_q_scale", scale_val)
dq_scale_constant = gs.Constant(
softmax_node.name + "/softmax_dq_scale", scale_val.copy()
)

zp_tensor = onnx.TensorProto()
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
zp_tensor.dims.extend([1])
zp_tensor.raw_data = b"\x00"
zp_constant = gs.Constant(
softmax_node.name + "/softmax_q_zero_point", LazyValues(zp_tensor)
)

q_output = gs.Variable(softmax_node.name + "/q_output")
dq_output = gs.Variable(softmax_node.name + "/dq_output", dtype=softmax_output.dtype)
q_node = gs.Node(
op="QuantizeLinear",
name=softmax_node.name + "/QuantizeLinear",
inputs=[softmax_output, scale_constant, zp_constant],
outputs=[q_output],
attrs={"saturate": 1},
)
dq_node = gs.Node(
op="DequantizeLinear",
name=softmax_node.name + "/DequantizeLinear",
inputs=[q_output, dq_scale_constant],
outputs=[dq_output],
)
graph.nodes.extend([q_node, dq_node])

for consumer in list(softmax_output.outputs):
if consumer is q_node:
continue
for i, inp in enumerate(consumer.inputs):
if inp is softmax_output:
consumer.inputs[i] = dq_output
count += 1

graph.cleanup().toposort()
return count

@staticmethod
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Post-processes the ONNX model for FP8 quantization.

Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear:
- TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1
- TRT_FP8DequantizeLinear -> DequantizeLinear
It also rewrites attention scaling / K-transpose / softmax-output patterns so TRT
can fuse DQ into the attention MatMul kernels.

Args:
onnx_model: The ONNX model containing TRT_FP8 quantization nodes.
@@ -144,5 +385,15 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
)

# Attention-aware rewrites so TRT can fuse DQ into the attention MatMuls.
n_mul = FP8QuantExporter._move_mul_before_qdq(graph)
n_t = FP8QuantExporter._move_transpose_before_qdq(graph)
n_sm = FP8QuantExporter._insert_qdq_after_softmax(graph)
if n_mul or n_t or n_sm:
logger.info(
f"Attention QDQ rewrites: moved {n_mul} Mul, {n_t} Transpose; "
f"inserted QDQ on {n_sm} Softmax outputs"
)

graph.cleanup().toposort()
return gs.export_onnx(graph)
90 changes: 90 additions & 0 deletions modelopt/onnx/utils.py
@@ -1415,6 +1415,96 @@ def _bypass_cast_node(model: onnx.ModelProto, node: onnx.NodeProto) -> None:
consumer.input[i] = input_tensor


_DQ_OPS = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}
_Q_OPS = {"QuantizeLinear", "TRT_FP8QuantizeLinear"}


def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None:
"""Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16."""
import numpy as np

Collaborator:
Minor: np is already imported at the top of the file. The local import numpy as np is unnecessary.

if scale_init.data_type != onnx.TensorProto.FLOAT:
return
scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
if not scale_data.size:
scale_data = np.array(scale_init.float_data, dtype=np.float32)
Collaborator:
Potential precision issue: _scale_fp32_to_fp16 does not check whether the FP32 scale falls outside the FP16 representable range. A very small scale (below the ~5.96e-8 FP16 subnormal minimum) or a very large one (above 65504) will silently convert to 0 or inf. Consider logging a warning:

fp16_val = scale_data.astype(np.float16)
if np.any(np.isinf(fp16_val)) or np.any((fp16_val == 0) & (scale_data != 0)):
    logger.warning("Q/DQ scale overflows/underflows when converting to FP16")

scale_init.data_type = onnx.TensorProto.FLOAT16
scale_init.raw_data = scale_data.astype(np.float16).tobytes()
del scale_init.float_data[:]


def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Remove ``DQ → Cast(FP32→FP16)`` patterns inserted by ``convert_float_to_float16``.

The DQ scale is rewritten to FP16 so DQ natively produces FP16 output, enabling
TRT to fuse DQ directly into the downstream compute op.
"""
producer_map = {out: node for node in onnx_model.graph.node for out in node.output}
initializers = {init.name: init for init in onnx_model.graph.initializer}

to_remove = []
for node in onnx_model.graph.node:
if node.op_type != "Cast":
continue
cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
if cast_to != onnx.TensorProto.FLOAT16:
continue
producer = producer_map.get(node.input[0])
if producer is None or producer.op_type not in _DQ_OPS:
continue

if len(producer.input) >= 2 and producer.input[1] in initializers:
_scale_fp32_to_fp16(initializers[producer.input[1]])

_bypass_cast_node(onnx_model, node)
to_remove.append(node)

for vi in onnx_model.graph.value_info:
if vi.name == producer.output[0]:
vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
break

logger.debug(f"Folded {len(to_remove)} DQ -> Cast(FP32->FP16) patterns")
for node in to_remove:
onnx_model.graph.node.remove(node)
return onnx_model


def fold_q_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Remove ``Cast(FP16→FP32) → Q`` patterns inserted by ``convert_float_to_float16``.

The Q scale is rewritten to FP16 so Q consumes the FP16 graph directly.
"""
consumer_map: dict[str, list[onnx.NodeProto]] = {}
for node in onnx_model.graph.node:
for inp in node.input:
consumer_map.setdefault(inp, []).append(node)
initializers = {init.name: init for init in onnx_model.graph.initializer}

to_remove = []
for node in onnx_model.graph.node:
if node.op_type != "Cast":
continue
cast_to = next((a.i for a in node.attribute if a.name == "to"), None)
if cast_to != onnx.TensorProto.FLOAT:
continue
consumers = consumer_map.get(node.output[0], [])
if not consumers or not all(c.op_type in _Q_OPS for c in consumers):
continue

for q_node in consumers:
if len(q_node.input) >= 2 and q_node.input[1] in initializers:
_scale_fp32_to_fp16(initializers[q_node.input[1]])

_bypass_cast_node(onnx_model, node)
to_remove.append(node)

logger.debug(f"Folded {len(to_remove)} Cast(FP16->FP32) -> Q patterns")
for node in to_remove:
onnx_model.graph.node.remove(node)
return onnx_model
Comment on lines +1422 to +1505
Contributor:
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

What is the minimum ONNX opset version that allows FLOAT16 scale tensors for QuantizeLinear and DequantizeLinear?

💡 Result:

The minimum ONNX opset version that allows FLOAT16 scale tensors for both QuantizeLinear and DequantizeLinear is 19.

🏁 Scripts executed against NVIDIA/Model-Optimizer:

# Check for BASE_MIN_OPSET and get_opset_version in the file
rg -n "BASE_MIN_OPSET|get_opset_version" modelopt/onnx/utils.py | head -20

# Inspect fold_dq_fp32_to_fp16_casts and fold_q_fp16_to_fp32_casts for opset guards
sed -n '1422,1505p' modelopt/onnx/utils.py | cat -n

# Find all call sites of the fold functions across the codebase
rg -n "fold_dq_fp32_to_fp16_casts|fold_q_fp16_to_fp32_casts" modelopt/ --type py -B 5 -A 5

# Check convert_float_to_float16 usage and the export path for opset validation
rg -n "convert_float_to_float16" modelopt/ --type py
sed -n '450,650p' modelopt/torch/_deploy/utils/torch_onnx.py | cat -n
rg -n "BASE_MIN_OPSET" modelopt/torch/_deploy/ --type py
rg -n "DEFAULT_ONNX_OPSET" modelopt/ --type py
rg -n "def is_fp8_quantized" modelopt/ -A 10 --type py

Add opset guard to FP16 scale rewrite helpers.

These helpers rewrite Q/DQ scale initializers to FLOAT16 unconditionally, but FLOAT16 scales are only valid from opset 19 onward. When FP8 quantization is enabled, a user can request a lower opset (e.g., onnx_opset=18) via the export API, which will cause these functions to silently produce an invalid model. Check get_opset_version(onnx_model) >= BASE_MIN_OPSET before mutating initializers, or skip the fold operation entirely for lower opsets.
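A guard along those lines might look like the following sketch. The constant value and the opset_import handling are drawn from the ONNX spec (FLOAT16 Q/DQ scales require opset 19), not copied from modelopt; here opset imports are modeled as (domain, version) pairs:

```python
BASE_MIN_OPSET = 19  # FLOAT16 Q/DQ scales are only legal from ONNX opset 19


def get_opset_version(opset_imports: list[tuple[str, int]]) -> int:
    """Return the default-domain opset; entries mimic onnx_model.opset_import,
    where "" and "ai.onnx" both denote the default ONNX domain."""
    return max(v for d, v in opset_imports if d in ("", "ai.onnx"))


def can_fold_fp16_scales(opset_imports: list[tuple[str, int]]) -> bool:
    """Gate the FP32->FP16 scale rewrite: skip the fold entirely below opset 19,
    leaving the DQ -> Cast pattern in place rather than emitting an invalid model."""
    return get_opset_version(opset_imports) >= BASE_MIN_OPSET
```

With this check at the top of fold_dq_fp32_to_fp16_casts and fold_q_fp16_to_fp32_casts, a model exported at onnx_opset=18 would simply pass through unchanged instead of gaining FLOAT16 scales it cannot legally carry.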

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1422 - 1505, The fold helpers
unconditionally convert Q/DQ scale initializers to FLOAT16 which is invalid for
opsets < BASE_MIN_OPSET; update _scale_fp32_to_fp16, fold_dq_fp32_to_fp16_casts
and fold_q_fp16_to_fp32_casts to guard the mutation by checking
get_opset_version(onnx_model) (or the model passed in) and only perform the
FP32→FP16 rewrite when get_opset_version(onnx_model) >= BASE_MIN_OPSET; if the
check fails, skip mutating initializers and skip folding the cast nodes (i.e.,
return the model unchanged or continue without calling
_scale_fp32_to_fp16/_bypass_cast_node), using the existing function names
(_scale_fp32_to_fp16, fold_dq_fp32_to_fp16_casts, fold_q_fp16_to_fp32_casts) and
constants (BASE_MIN_OPSET) to locate where to add the guard.



def _is_foldable_constant_cast_pattern(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
"""Check if a Constant -> Cast pattern can be folded."""
assert node.op_type == "Cast"