1 change: 1 addition & 0 deletions backends/nxp/BUCK
@@ -55,6 +55,7 @@ fbcode_target(_kind = runtime.python_library,
     deps = [
         ":aten_passes",
         "//caffe2:torch",
+        "//executorch/backends/transforms:quantize_fused_convbn_bias_pass",
         "//pytorch/ao:torchao", # @manual
     ],
 )
13 changes: 8 additions & 5 deletions backends/nxp/quantizer/utils.py
@@ -29,6 +29,10 @@
     prepare_pt2e,
     prepare_qat_pt2e,
 )
+from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
+    QuantizeFusedConvBnBiasAtenPass,
+)
+
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY, Quantizer


@@ -162,14 +166,13 @@ def find_sequential_partitions_aten(


 def calibrate_and_quantize(
-    model: ExportedProgram | fx.GraphModule,
+    model: ExportedProgram,
     calibration_inputs: Iterable[tuple[torch.Tensor, ...]],
     quantizer: Quantizer,
     is_qat: bool = False,
 ) -> fx.GraphModule:
     """Quantize the provided model.
 
-    :param model: Aten model (or it's GraphModule representation) to quantize.
+    :param model: Aten exported model to quantize.
     :param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model
                                input. Or an iterator over such tuples.
     :param quantizer: Quantizer to use.
@@ -179,8 +182,7 @@
     :return: Quantized GraphModule.
     """
 
-    if isinstance(model, ExportedProgram):
-        model = model.module()
+    model = model.module()
 
     if is_qat:
         m = prepare_qat_pt2e(model, quantizer)
@@ -192,4 +194,5 @@
         m(*data)
     m = convert_pt2e(m)
 
+    QuantizeFusedConvBnBiasAtenPass()(m)
     return m
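
After this change, `calibrate_and_quantize` accepts only an `ExportedProgram`, unwraps it via `model.module()`, and runs `QuantizeFusedConvBnBiasAtenPass` over the converted module before returning it. A minimal usage sketch, assuming the `NeutronQuantizer` import path and a toy model (both illustrative, not part of this diff):

```python
import torch
from torch import nn
from torch.export import export

from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize
# Assumed quantizer import path; any torchao pt2e Quantizer works here.
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# `export` produces the ExportedProgram the new signature requires.
exported = export(model, example_inputs)

quantized_gm = calibrate_and_quantize(
    exported,          # ExportedProgram; a bare GraphModule is no longer accepted
    [example_inputs],  # iterable of calibration input tuples
    NeutronQuantizer(),
    is_qat=False,
)
```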
15 changes: 15 additions & 0 deletions backends/nxp/tests/BUCK
@@ -53,3 +53,18 @@ fbcode_target(_kind = python_pytest,
         ":models",
     ]
 )
+
+fbcode_target(_kind = python_pytest,
+    name = "test_batch_norm_fusion",
+    srcs = [
+        "test_batch_norm_fusion.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/nxp:neutron_backend",
+        ":executorch_pipeline",
+        ":models",
+        "fbsource//third-party/pypi/pytest:pytest",
+        "fbsource//third-party/pypi/numpy:numpy",
+    ],
+)
16 changes: 13 additions & 3 deletions backends/nxp/tests/executors.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from __future__ import annotations
+
 import warnings
 from typing import Callable, Dict, Union
 
@@ -36,7 +38,10 @@
 try:
     import tensorflow.lite as tflite
 except ModuleNotFoundError:
-    import tflite_runtime.interpreter as tflite
+    try:
+        import tflite_runtime.interpreter as tflite
+    except ModuleNotFoundError:
+        tflite = None
 
 
 class EdgeProgramExecutor:
@@ -85,7 +90,7 @@ def __init__(
         saved_model_name="model.tflite",
         delegate_path=None,
         num_threads=None,
-        op_resolver_type=tflite.experimental.OpResolverType.AUTO,
+        op_resolver_type=None,
     ):
         """
         Construct TFLiteExecutor used to quickly run inference on TFLite model.
@@ -105,6 +110,8 @@ def __init__(
                                  https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter for details. Default value is
                                  tflite.experimental.OpResolverType.AUTO.
         """
+        if op_resolver_type is None:
+            op_resolver_type = tflite.experimental.OpResolverType.AUTO
         assert model_path is not None or model_content is not None
         assert model_path is None or model_content is None
 
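The reason the default moved out of the signature: a default argument like `tflite.experimental.OpResolverType.AUTO` is evaluated once, when the class body executes at import time, so with the new `tflite = None` fallback it would raise `AttributeError` before the module could load at all. Resolving `None` inside `__init__` defers the attribute access until an interpreter is actually constructed. A minimal sketch of the pattern (the `Runner` class is illustrative, not from this codebase):

```python
tflite = None  # what the import fallback leaves behind when no TFLite package is installed


class Runner:
    # A signature default of `tflite.experimental.OpResolverType.AUTO` would be
    # evaluated right here, at class-definition time, and crash the import.
    def __init__(self, op_resolver_type=None):
        # Deferred: `tflite` is only touched when a Runner is constructed,
        # so importing this module succeeds even without TFLite installed.
        if op_resolver_type is None:
            op_resolver_type = tflite.experimental.OpResolverType.AUTO
        self.op_resolver_type = op_resolver_type
```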
@@ -310,9 +317,12 @@ def convert_run_compare(
     tflite_input_preprocess: TFLiteIOPreprocess = TFLiteIOPreprocess(),  # noqa B008
     tflite_output_preprocess: TFLiteIOPreprocess = TFLiteIOPreprocess(),  # noqa B008
     conversion_config: ConversionConfig = ConversionConfig(),  # noqa B008
-    tflite_op_resolver_type=tflite.experimental.OpResolverType.AUTO,
+    tflite_op_resolver_type=None,
 ) -> (TFLiteExecutor, EdgeProgramExecutor):
 
+    if tflite_op_resolver_type is None:
+        tflite_op_resolver_type = tflite.experimental.OpResolverType.AUTO
+
     if tfl_model is None:
         NodeFormatInference(edge_program).identify_node_formats()
         tfl_model, _ = EdgeProgramToIRConverter().convert_program(
23 changes: 23 additions & 0 deletions backends/nxp/tests/test_batch_norm_fusion.py
@@ -25,6 +25,8 @@
 from executorch.backends.nxp.tests.executors import OverrideTargetSupportCheck
 from torch import nn
 
+from executorch.backends.nxp.tests.models import ConvBNModule
+
 
 @pytest.fixture(autouse=True)
 def reseed_model_per_test_run():
@@ -231,3 +233,24 @@ def unsupported_target(*_):  # Accept all input arguments and return `False`.
         node.op == "call_function" and "batch_norm" in node.target.__name__
         for node in nodes
     )
+@pytest.mark.parametrize(
+    "conv_module",
+    ["conv2d"],
+)
+def test_biasless_convbn_fusion_qat(
+    conv_module,
+):
+    if conv_module.startswith("conv1d"):
+        input_shape = (1, 3, 32)
+    elif conv_module.startswith("conv2d"):
+        input_shape = (1, 3, 32, 32)
+    else:  # conv3d
+        input_shape = (1, 3, 32, 32, 32)
+
+    model = ConvBNModule(conv_module, conv_bias=False, bn_affine=True)
+
+    edge_program = to_quantized_edge_program(
+        model, input_shape, use_qat=True, use_neutron_for_format_conversion=False
+    ).exported_program()
+
+    assert any("lowered_module" in node.name for node in edge_program.graph.nodes)
23 changes: 23 additions & 0 deletions backends/transforms/BUCK
@@ -1,6 +1,29 @@
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
load("@fbcode_macros//build_defs:python_pytest.bzl", "python_pytest")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

fbcode_target(_kind = define_common_targets,)

fbcode_target(_kind = python_pytest,
name = "test_quantize_fused_convbn_bias_pass",
srcs = [
"test/test_quantize_fused_convbn_bias_pass.py",
],
deps = [
"//caffe2:torch",
":quantize_fused_convbn_bias_pass",
"//executorch/backends/arm/quantizer:arm_quantizer",
"//executorch/backends/arm/test:arm_tester_lib",
"//executorch/backends/arm/test:arm_tester_serialize",
"//executorch/backends/arm/test:common",
"//executorch/backends/arm/tosa:tosa",
"//executorch/backends/nxp:quantizer",
"//executorch/backends/nxp:neutron_backend",
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/exir:lib",
"//executorch/kernels/quantized:custom_ops_generated_lib",
"fbsource//third-party/pypi/pytest:pytest",
],
)