diff --git a/backends/arm/test/models/test_inception_v3_arm.py b/backends/arm/test/models/test_inception_v3_arm.py index 37fef566c9d..f842ea1f265 100644 --- a/backends/arm/test/models/test_inception_v3_arm.py +++ b/backends/arm/test/models/test_inception_v3_arm.py @@ -24,6 +24,9 @@ ic3 = models.inception_v3(weights=models.Inception_V3_Weights) ic3 = ic3.eval() +ic3_fp16 = models.inception_v3(weights=models.Inception_V3_Weights).to(torch.float16) +ic3_fp16 = ic3_fp16.eval() + # Normalization values referenced from here: # https://docs.pytorch.org/vision/main/models/generated/torchvision.models.quantization.inception_v3.html normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) @@ -44,6 +47,20 @@ def test_ic3_tosa_FP(): pipeline.run() +@pytest.mark.slow +def test_ic3_tosa_FP_fp16(): + inputs_fp16 = tuple(t.to(torch.float16) for t in model_inputs) + pipeline = TosaPipelineFP[input_t]( + ic3_fp16, + inputs_fp16, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=5e-2, + ) + pipeline.run() + + @pytest.mark.slow def test_ic3_tosa_INT(): pipeline = TosaPipelineINT[input_t]( diff --git a/backends/arm/test/models/test_mobilenet_v3_arm.py b/backends/arm/test/models/test_mobilenet_v3_arm.py index c454d99befb..eccdc839e62 100644 --- a/backends/arm/test/models/test_mobilenet_v3_arm.py +++ b/backends/arm/test/models/test_mobilenet_v3_arm.py @@ -24,6 +24,11 @@ mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights) mv3 = mv3.eval() +mv3_fp16 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights).to( + torch.float16 +) +mv3_fp16 = mv3_fp16.eval() + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) input_tensor = torch.rand(1, 3, 232, 232) @@ -40,6 +45,20 @@ def test_mv3_tosa_FP(): pipeline.run() +@pytest.mark.slow +def test_mv3_tosa_FP_fp16(): + inputs_fp16 = tuple(t.to(torch.float16) for t in model_inputs) + pipeline = TosaPipelineFP[input_t]( + mv3_fp16, + inputs_fp16, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=5e-2, + ) + pipeline.run() + + @pytest.mark.slow def test_mv3_tosa_INT(): pipeline = TosaPipelineINT[input_t](