-
Notifications
You must be signed in to change notification settings - Fork 362
Add FP8 MHA quantization support for HuggingFace ViT #1289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,8 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
| # Fold constants is required since the scale is not constant yet. | ||
| graph.cleanup().toposort().fold_constants().cleanup() | ||
|
|
||
| n_t_folded = 0 | ||
|
|
||
| for node in graph.nodes: | ||
| if node.op == "TRT_FP8QuantizeLinear": | ||
| # Should not remove input QDQ (only process weight quantization) | ||
|
|
@@ -78,6 +80,33 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
| f"QDQ does not occur in pairs. You reached {dq_op.op}" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug (existing but expanded): This if dq_op.op != "TRT_FP8DequantizeLinear":
raise RuntimeError(f"QDQ does not occur in pairs. You reached {dq_op.op}") |
||
| ) | ||
|
|
||
| # Pre-transpose constant weights if DQ feeds ``Transpose → MatMul`` (or | ||
| # ``Cast → Transpose → MatMul`` after fp16 conversion) so TRT sees DQ→MatMul. | ||
| transpose_to_remove = None | ||
| cast_to_remove = None | ||
| for candidate in list(dq_op.outputs[0].outputs): | ||
| if candidate.op == "Cast": | ||
| cast_to_remove = candidate | ||
| candidate = next( | ||
| (c for c in candidate.outputs[0].outputs if c.op == "Transpose"), | ||
| None, | ||
| ) | ||
| if candidate is None: | ||
| cast_to_remove = None | ||
| continue | ||
| if candidate.op != "Transpose": | ||
| cast_to_remove = None | ||
| continue | ||
| if any(c.op == "MatMul" for c in candidate.outputs[0].outputs): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Potential issue: When for candidate in list(dq_op.outputs[0].outputs):
if candidate.op == "Cast":
...
if candidate is None:
cast_to_remove = None
continue # <-- this continues to next candidate, good
if candidate.op != "Transpose":
cast_to_remove = None
continue # <-- this also continues, good
...
break # <-- but this break exits even if we didn't find a match. Actually looking more carefully, the logic resets |
||
| perm = candidate.attrs.get("perm", None) | ||
| torch_weights = ( | ||
| torch_weights.permute(*perm).contiguous() | ||
| if perm is not None | ||
| else torch_weights.T.contiguous() | ||
| ) | ||
| transpose_to_remove = candidate | ||
| break | ||
|
Comment on lines
+100
to
+108
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Restrict these rewrites to pure MatMul fanout. Each transform fires when it sees any Also applies to: 128-137, 196-220, 255-275, 334-339 🤖 Prompt for AI Agents |
||
|
|
||
| # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8. | ||
| numpy_weights = ( | ||
| (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy() | ||
|
|
@@ -94,20 +123,232 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
| dq_op.inputs[0] = onnx_weights_fp8 | ||
| dq_op.op = "DequantizeLinear" | ||
| dq_op.outputs[0].dtype = dq_op.inputs[1].dtype | ||
| dq_op.outputs[0].shape = list(numpy_weights.shape) | ||
|
|
||
| if transpose_to_remove is not None: | ||
| t_out = transpose_to_remove.outputs[0] | ||
| for consumer in list(t_out.outputs): | ||
| for i, inp in enumerate(consumer.inputs): | ||
| if inp is t_out: | ||
| consumer.inputs[i] = dq_op.outputs[0] | ||
| transpose_to_remove.outputs.clear() | ||
| if cast_to_remove is not None: | ||
| cast_to_remove.outputs.clear() | ||
| n_t_folded += 1 | ||
|
|
||
| graph.cleanup().toposort() | ||
| end_time = time.time() | ||
| if n_t_folded > 0: | ||
| logger.info(f"Folded {n_t_folded} weight Transpose nodes during weight compression") | ||
| print(f"fp8 qdq replaced with only dq completed in {end_time - start_time}s.") | ||
|
|
||
| return gs.export_onnx(graph) | ||
|
|
||
| @staticmethod | ||
| def _move_mul_before_qdq(graph: gs.Graph) -> int: | ||
| """Move attention-scaling Mul(const) from after DQ to before Q for TRT MatMul fusion. | ||
|
|
||
| Handles both ``DQ → Mul → MatMul`` and ``DQ → Transpose → Mul → MatMul`` (K path). | ||
| """ | ||
| count = 0 | ||
| for mul_node in list(graph.nodes): | ||
| if mul_node.op != "Mul": | ||
| continue | ||
|
|
||
| const_input = next( | ||
| (i for i in mul_node.inputs if isinstance(i, gs.Constant) and i.values.size == 1), | ||
| None, | ||
| ) | ||
| tensor_input = next( | ||
| (i for i in mul_node.inputs if not isinstance(i, gs.Constant)), None | ||
| ) | ||
| if const_input is None or tensor_input is None: | ||
| continue | ||
| if not (isinstance(tensor_input, gs.Variable) and len(tensor_input.inputs) == 1): | ||
| continue | ||
|
|
||
| producer = tensor_input.inputs[0] | ||
| transpose_node = producer if producer.op == "Transpose" else None | ||
| dq_node = producer if producer.op == "DequantizeLinear" else None | ||
| if transpose_node is not None: | ||
| t_input = transpose_node.inputs[0] | ||
| if ( | ||
| isinstance(t_input, gs.Variable) | ||
| and len(t_input.inputs) == 1 | ||
| and t_input.inputs[0].op == "DequantizeLinear" | ||
| ): | ||
| dq_node = t_input.inputs[0] | ||
| if dq_node is None: | ||
| continue | ||
|
|
||
| q_output = dq_node.inputs[0] | ||
| if ( | ||
| not isinstance(q_output, gs.Variable) | ||
| or len(q_output.inputs) != 1 | ||
| or q_output.inputs[0].op != "QuantizeLinear" | ||
| ): | ||
| continue | ||
| q_node = q_output.inputs[0] | ||
| q_input = q_node.inputs[0] | ||
| if not isinstance(q_input, gs.Variable): | ||
| continue | ||
|
|
||
| mul_output = mul_node.outputs[0] | ||
| if not any(c.op == "MatMul" for c in mul_output.outputs): | ||
| continue | ||
|
|
||
| new_mul_output = gs.Variable( | ||
| q_input.name + "_scaled", dtype=q_input.dtype, shape=q_input.shape | ||
| ) | ||
| graph.nodes.append( | ||
| gs.Node( | ||
| op="Mul", | ||
| name=mul_node.name + "_moved", | ||
| inputs=[q_input, const_input], | ||
| outputs=[new_mul_output], | ||
| ) | ||
| ) | ||
| q_node.inputs[0] = new_mul_output | ||
|
|
||
| replacement = ( | ||
| transpose_node.outputs[0] if transpose_node is not None else dq_node.outputs[0] | ||
| ) | ||
| for consumer in list(mul_output.outputs): | ||
| for i, inp in enumerate(consumer.inputs): | ||
| if inp is mul_output: | ||
| consumer.inputs[i] = replacement | ||
| mul_node.outputs.clear() | ||
| count += 1 | ||
|
|
||
| graph.cleanup().toposort() | ||
| return count | ||
|
|
||
| @staticmethod | ||
| def _move_transpose_before_qdq(graph: gs.Graph) -> int: | ||
| """Move Transpose from ``DQ → Transpose → MatMul`` to ``Transpose → Q → DQ → MatMul`` (K path).""" | ||
| count = 0 | ||
| for transpose_node in list(graph.nodes): | ||
| if transpose_node.op != "Transpose": | ||
| continue | ||
|
|
||
| t_input = transpose_node.inputs[0] | ||
| if ( | ||
| not isinstance(t_input, gs.Variable) | ||
| or len(t_input.inputs) != 1 | ||
| or t_input.inputs[0].op != "DequantizeLinear" | ||
| ): | ||
| continue | ||
| dq_node = t_input.inputs[0] | ||
|
|
||
| dq_input = dq_node.inputs[0] | ||
| if ( | ||
| not isinstance(dq_input, gs.Variable) | ||
| or len(dq_input.inputs) != 1 | ||
| or dq_input.inputs[0].op != "QuantizeLinear" | ||
| ): | ||
| continue | ||
| q_node = dq_input.inputs[0] | ||
| q_input = q_node.inputs[0] | ||
| if not isinstance(q_input, gs.Variable): | ||
| continue | ||
|
|
||
| t_output = transpose_node.outputs[0] | ||
| if not any(c.op == "MatMul" for c in t_output.outputs): | ||
| continue | ||
|
|
||
| new_t_output = gs.Variable(q_input.name + "_transposed", dtype=q_input.dtype) | ||
| graph.nodes.append( | ||
| gs.Node( | ||
| op="Transpose", | ||
| name=transpose_node.name + "_moved", | ||
| inputs=[q_input], | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The hardcoded |
||
| outputs=[new_t_output], | ||
| attrs=transpose_node.attrs, | ||
| ) | ||
| ) | ||
| q_node.inputs[0] = new_t_output | ||
|
|
||
| for consumer in list(t_output.outputs): | ||
| for i, inp in enumerate(consumer.inputs): | ||
| if inp is t_output: | ||
| consumer.inputs[i] = dq_node.outputs[0] | ||
| transpose_node.outputs.clear() | ||
| count += 1 | ||
|
|
||
| graph.cleanup().toposort() | ||
| return count | ||
|
|
||
| @staticmethod | ||
| def _insert_qdq_after_softmax(graph: gs.Graph) -> int: | ||
| """Insert FP8 Q→DQ on Softmax outputs feeding MatMul (required by TRT MHA fusion). | ||
|
|
||
| Torch export does not quantize softmax output; scale=1/448 saturates exactly at 1.0 | ||
| (softmax range is [0, 1]) while covering the full FP8 E4M3 representable range. | ||
| """ | ||
| import numpy as np | ||
|
|
||
| count = 0 | ||
| for softmax_node in list(graph.nodes): | ||
| if softmax_node.op != "Softmax": | ||
| continue | ||
| softmax_output = softmax_node.outputs[0] | ||
| if not any(c.op == "MatMul" for c in softmax_output.outputs): | ||
| continue | ||
| if any(c.op == "QuantizeLinear" for c in softmax_output.outputs): | ||
| continue | ||
|
|
||
| # Match scale dtype to the graph's current float dtype so TRT stronglyTyped | ||
| # sees consistent Q/DQ types with the surrounding compute. | ||
| scale_dtype = softmax_output.dtype if softmax_output.dtype is not None else np.float32 | ||
| scale_val = np.array(1.0 / 448.0, dtype=scale_dtype) | ||
| scale_constant = gs.Constant(softmax_node.name + "/softmax_q_scale", scale_val) | ||
| dq_scale_constant = gs.Constant( | ||
| softmax_node.name + "/softmax_dq_scale", scale_val.copy() | ||
| ) | ||
|
|
||
| zp_tensor = onnx.TensorProto() | ||
| zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN | ||
| zp_tensor.dims.extend([1]) | ||
| zp_tensor.raw_data = b"\x00" | ||
| zp_constant = gs.Constant( | ||
| softmax_node.name + "/softmax_q_zero_point", LazyValues(zp_tensor) | ||
| ) | ||
|
|
||
| q_output = gs.Variable(softmax_node.name + "/q_output") | ||
| dq_output = gs.Variable(softmax_node.name + "/dq_output", dtype=softmax_output.dtype) | ||
| q_node = gs.Node( | ||
| op="QuantizeLinear", | ||
| name=softmax_node.name + "/QuantizeLinear", | ||
| inputs=[softmax_output, scale_constant, zp_constant], | ||
| outputs=[q_output], | ||
| attrs={"saturate": 1}, | ||
| ) | ||
| dq_node = gs.Node( | ||
| op="DequantizeLinear", | ||
| name=softmax_node.name + "/DequantizeLinear", | ||
| inputs=[q_output, dq_scale_constant], | ||
| outputs=[dq_output], | ||
| ) | ||
| graph.nodes.extend([q_node, dq_node]) | ||
|
|
||
| for consumer in list(softmax_output.outputs): | ||
| if consumer is q_node: | ||
| continue | ||
| for i, inp in enumerate(consumer.inputs): | ||
| if inp is softmax_output: | ||
| consumer.inputs[i] = dq_output | ||
| count += 1 | ||
|
|
||
| graph.cleanup().toposort() | ||
| return count | ||
|
|
||
| @staticmethod | ||
| def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | ||
| """Post-processes the ONNX model for FP8 quantization. | ||
|
|
||
| Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear: | ||
| - TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1 | ||
| - TRT_FP8DequantizeLinear -> DequantizeLinear | ||
| Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear and | ||
| rewrites attention scaling / K-transpose / softmax-output patterns so TRT | ||
| can fuse DQ into the attention MatMul kernels. | ||
|
|
||
| Args: | ||
| onnx_model: The ONNX model containing TRT_FP8 quantization nodes. | ||
|
|
@@ -144,5 +385,15 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
| f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear" | ||
| ) | ||
|
|
||
| # Attention-aware rewrites so TRT can fuse DQ into the attention MatMuls. | ||
| n_mul = FP8QuantExporter._move_mul_before_qdq(graph) | ||
| n_t = FP8QuantExporter._move_transpose_before_qdq(graph) | ||
| n_sm = FP8QuantExporter._insert_qdq_after_softmax(graph) | ||
| if n_mul or n_t or n_sm: | ||
| logger.info( | ||
| f"Attention QDQ rewrites: moved {n_mul} Mul, {n_t} Transpose; " | ||
| f"inserted QDQ on {n_sm} Softmax outputs" | ||
| ) | ||
|
|
||
| graph.cleanup().toposort() | ||
| return gs.export_onnx(graph) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1415,6 +1415,96 @@ def _bypass_cast_node(model: onnx.ModelProto, node: onnx.NodeProto) -> None: | |
| consumer.input[i] = input_tensor | ||
|
|
||
|
|
||
| _DQ_OPS = {"DequantizeLinear", "TRT_FP8DequantizeLinear"} | ||
| _Q_OPS = {"QuantizeLinear", "TRT_FP8QuantizeLinear"} | ||
|
|
||
|
|
||
| def _scale_fp32_to_fp16(scale_init: onnx.TensorProto) -> None: | ||
| """Convert a scalar Q/DQ scale initializer in-place from FP32 to FP16.""" | ||
| import numpy as np | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor: |
||
| if scale_init.data_type != onnx.TensorProto.FLOAT: | ||
| return | ||
| scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32) | ||
| if not scale_data.size: | ||
| scale_data = np.array(scale_init.float_data, dtype=np.float32) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Potential precision issue: import warnings
fp16_val = scale_data.astype(np.float16)
if np.any(np.isinf(fp16_val)) or np.any(fp16_val == 0) and np.any(scale_data != 0):
logger.warning("Q/DQ scale overflows/underflows when converting to FP16") |
||
| scale_init.data_type = onnx.TensorProto.FLOAT16 | ||
| scale_init.raw_data = scale_data.astype(np.float16).tobytes() | ||
| del scale_init.float_data[:] | ||
|
|
||
|
|
||
| def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | ||
| """Remove ``DQ → Cast(FP32→FP16)`` patterns inserted by ``convert_float_to_float16``. | ||
|
|
||
| The DQ scale is rewritten to FP16 so DQ natively produces FP16 output, enabling | ||
| TRT to fuse DQ directly into the downstream compute op. | ||
| """ | ||
| producer_map = {out: node for node in onnx_model.graph.node for out in node.output} | ||
| initializers = {init.name: init for init in onnx_model.graph.initializer} | ||
|
|
||
| to_remove = [] | ||
| for node in onnx_model.graph.node: | ||
| if node.op_type != "Cast": | ||
| continue | ||
| cast_to = next((a.i for a in node.attribute if a.name == "to"), None) | ||
| if cast_to != onnx.TensorProto.FLOAT16: | ||
| continue | ||
| producer = producer_map.get(node.input[0]) | ||
| if producer is None or producer.op_type not in _DQ_OPS: | ||
| continue | ||
|
|
||
| if len(producer.input) >= 2 and producer.input[1] in initializers: | ||
| _scale_fp32_to_fp16(initializers[producer.input[1]]) | ||
|
|
||
| _bypass_cast_node(onnx_model, node) | ||
| to_remove.append(node) | ||
|
|
||
| for vi in onnx_model.graph.value_info: | ||
| if vi.name == producer.output[0]: | ||
| vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16 | ||
| break | ||
|
|
||
| logger.debug(f"Folded {len(to_remove)} DQ -> Cast(FP32->FP16) patterns") | ||
| for node in to_remove: | ||
| onnx_model.graph.node.remove(node) | ||
| return onnx_model | ||
|
|
||
|
|
||
| def fold_q_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | ||
| """Remove ``Cast(FP16→FP32) → Q`` patterns inserted by ``convert_float_to_float16``. | ||
|
|
||
| The Q scale is rewritten to FP16 so Q consumes the FP16 graph directly. | ||
| """ | ||
| consumer_map: dict[str, list[onnx.NodeProto]] = {} | ||
| for node in onnx_model.graph.node: | ||
| for inp in node.input: | ||
| consumer_map.setdefault(inp, []).append(node) | ||
| initializers = {init.name: init for init in onnx_model.graph.initializer} | ||
|
|
||
| to_remove = [] | ||
| for node in onnx_model.graph.node: | ||
| if node.op_type != "Cast": | ||
| continue | ||
| cast_to = next((a.i for a in node.attribute if a.name == "to"), None) | ||
| if cast_to != onnx.TensorProto.FLOAT: | ||
| continue | ||
| consumers = consumer_map.get(node.output[0], []) | ||
| if not consumers or not all(c.op_type in _Q_OPS for c in consumers): | ||
| continue | ||
|
|
||
| for q_node in consumers: | ||
| if len(q_node.input) >= 2 and q_node.input[1] in initializers: | ||
| _scale_fp32_to_fp16(initializers[q_node.input[1]]) | ||
|
|
||
| _bypass_cast_node(onnx_model, node) | ||
| to_remove.append(node) | ||
|
|
||
| logger.debug(f"Folded {len(to_remove)} Cast(FP16->FP32) -> Q patterns") | ||
| for node in to_remove: | ||
| onnx_model.graph.node.remove(node) | ||
| return onnx_model | ||
|
Comment on lines
+1422
to
+1505
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🌐 Web query:
💡 Result: The minimum ONNX opset version that allows FLOAT16 scale tensors for both QuantizeLinear and DequantizeLinear is 19. Citations:
🏁 Script executed: # Check for BASE_MIN_OPSET and get_opset_version in the file
rg -n "BASE_MIN_OPSET|get_opset_version" modelopt/onnx/utils.py | head -20Repository: NVIDIA/Model-Optimizer Length of output: 335 🏁 Script executed: # Check the actual implementation of fold_dq_fp32_to_fp16_casts and fold_q_fp16_to_fp32_casts for guards
sed -n '1422,1505p' modelopt/onnx/utils.py | cat -nRepository: NVIDIA/Model-Optimizer Length of output: 4025 🏁 Script executed: # Search for where these fold functions are called
rg -n "fold_dq_fp32_to_fp16_casts|fold_q_fp16_to_fp32_casts" modelopt/onnx/utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 234 🏁 Script executed: # Look for opset version checks near the fold function calls
rg -B 10 -A 10 "fold_dq_fp32_to_fp16_casts|fold_q_fp16_to_fp32_casts" modelopt/onnx/utils.py | head -50Repository: NVIDIA/Model-Optimizer Length of output: 1836 🏁 Script executed: # Search for calls to these fold functions throughout the codebase
rg -n "fold_dq_fp32_to_fp16_casts|fold_q_fp16_to_fp32_casts" modelopt/ --type py -B 5 -A 5Repository: NVIDIA/Model-Optimizer Length of output: 3546 🏁 Script executed: # Also check for the convert_float_to_float16 function mentioned in the docstring
rg -n "convert_float_to_float16" modelopt/onnx/utils.py -B 5 -A 20Repository: NVIDIA/Model-Optimizer Length of output: 2454 🏁 Script executed: # Check the context where fold functions are called in torch_onnx.py
sed -n '610,640p' modelopt/torch/_deploy/utils/torch_onnx.py | cat -nRepository: NVIDIA/Model-Optimizer Length of output: 1714 🏁 Script executed: # Look for opset validation in the export_to_onnx function or its caller
grep -n "def export_to_onnx\|opset\|BASE_MIN_OPSET" modelopt/torch/_deploy/utils/torch_onnx.py | head -30Repository: NVIDIA/Model-Optimizer Length of output: 232 🏁 Script executed: # Check what the is_fp8_quantized function does
rg -n "def is_fp8_quantized" modelopt/ -A 10 --type pyRepository: NVIDIA/Model-Optimizer Length of output: 1147 🏁 Script executed: # Check the convert_float_to_float16 function and how it's used
rg -n "def convert_float_to_float16\|convert_float_to_float16" modelopt/onnx/utils.py -B 3 -A 15 | head -60Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Check for any opset validation in export_to_onnx before the fold calls
sed -n '450,630p' modelopt/torch/_deploy/utils/torch_onnx.py | grep -n "opset\|get_opset_version\|BASE_MIN_OPSET"Repository: NVIDIA/Model-Optimizer Length of output: 230 🏁 Script executed: # Look at DEFAULT_ONNX_OPSET to understand the default opset value
rg -n "DEFAULT_ONNX_OPSET" modelopt/ --type pyRepository: NVIDIA/Model-Optimizer Length of output: 459 🏁 Script executed: # Search for convert_float_to_float16 function across the entire codebase
rg -n "convert_float_to_float16" modelopt/ --type pyRepository: NVIDIA/Model-Optimizer Length of output: 506 🏁 Script executed: # Check what the default export behavior does - look at complete export_to_onnx signature and opset handling
sed -n '450,580p' modelopt/torch/_deploy/utils/torch_onnx.py | cat -nRepository: NVIDIA/Model-Optimizer Length of output: 6658 🏁 Script executed: # Check the entire export_to_onnx function flow to see if there's any opset validation before fold calls
sed -n '460,650p' modelopt/torch/_deploy/utils/torch_onnx.py | cat -nRepository: NVIDIA/Model-Optimizer Length of output: 9684 🏁 Script executed: # Search for any opset validation or assertion related to BASE_MIN_OPSET in the export path
rg -n "BASE_MIN_OPSET" modelopt/torch/_deploy/ --type pyRepository: NVIDIA/Model-Optimizer Length of output: 48 Add opset guard to FP16 scale rewrite helpers. These helpers rewrite Q/DQ scale initializers to 🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| def _is_foldable_constant_cast_pattern(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: | ||
| """Check if a Constant -> Cast pattern can be folded.""" | ||
| assert node.op_type == "Cast" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update the changelog for 0.45 to mention this new feature's details?