
fix(MemoryFormatOpsPass): preserve input dim_order for clone/to_copy with no memory_format kwarg #17611

Open

nefainl wants to merge 1 commit into pytorch:main from nefainl:fix/16032-memory-format-ops-pass-preserve-format

Conversation


@nefainl nefainl commented Feb 21, 2026

Summary

Fixes #16032

This PR fixes MemoryFormatOpsPass to correctly handle torch.preserve_format semantics for clone() and _to_copy.default operations.

Root cause: When clone() or _to_copy is called without an explicit memory_format kwarg, the pass was defaulting to torch.contiguous_format, causing the output dim_order to be [0,1,2,3] (contiguous) even when the input was channels-last [0,2,3,1]. This caused runtime assertion failures:

Code=18 InvalidArgument: tensors_have_same_dim_order(self, out)

Fix: Change the default from torch.contiguous_format to torch.preserve_format, and derive dim_order from the input tensor's dim_order() when preserve_format is used.

This is a minimal, focused fix following the guidance from @GregoryComer in the discussion on PR #17463.
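
For illustration, here is a minimal sketch of the resolution logic described above, assuming a hypothetical helper name resolve_dim_order and illustrative FX-node access patterns; the actual change lives in exir/passes/memory_format_ops_pass.py:

import torch

# Hypothetical sketch (not the literal patch). `node` is assumed to be a
# torch.fx.Node for a clone/_to_copy call whose meta["val"] is a tensor.
def resolve_dim_order(node):
    # New default: preserve_format instead of contiguous_format.
    memory_format = node.kwargs.get("memory_format", torch.preserve_format)

    if memory_format == torch.preserve_format:
        input_val = node.args[0].meta.get("val") if node.args else None
        if input_val is not None:
            # Mirror the input's layout, e.g. (0, 2, 3, 1) for a
            # channels-last 4-D tensor.
            return list(input_val.dim_order())
        # No input tensor to inspect (e.g. aten.empty): fall back to
        # contiguous ordering.
        return list(range(node.meta["val"].dim()))

    if memory_format == torch.channels_last:
        return [0, 2, 3, 1]

    # torch.contiguous_format (and anything else handled upstream):
    # contiguous ordering.
    return list(range(node.meta["val"].dim()))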

Changes

  • exir/passes/memory_format_ops_pass.py (+29/-5 lines):

    • Default memory_format to torch.preserve_format instead of torch.contiguous_format
    • When preserve_format, derive dim_order from input_tensor.dim_order()
    • Fallback to contiguous if no input tensor available (e.g., empty())
  • exir/tests/test_passes.py (+130 lines): three new tests covering the clone and _to_copy preserve-format paths (named in the commit message below)

Standalone Reproduction

import torch
from torch.export import export
from executorch.exir import to_edge, EdgeCompileConfig

class ConvClone(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x).clone()

model = ConvClone().to(memory_format=torch.channels_last)
x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)

exported = export(model, (x,))
edge = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False))

# Before fix: clone node has dim_order=(0,1,2,3) - BUG
# After fix: clone node has dim_order=(0,2,3,1) - CORRECT
for node in edge.exported_program().graph_module.graph.nodes:
    if "_clone_dim_order" in str(node.target):
        print(f"clone dim_order: {tuple(node.meta['val'].dim_order())}")
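
With the fix applied, this prints clone dim_order: (0, 2, 3, 1), matching the channels-last input.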

Test Plan

  • All 3 new tests pass
  • Verified fix with standalone reproduction script
  • No changes to existing tests required

Related

  • Issue #16032: Dim Order Validation Inconsistency for Edge / Ambiguous Cases
  • PR #17463: prior discussion with @GregoryComer that motivated this fix

fix(MemoryFormatOpsPass): preserve input dim_order for clone/to_copy with no memory_format kwarg

Issue pytorch#16032: clone() and _to_copy operations with no explicit memory_format
kwarg were defaulting to contiguous dim_order, causing runtime assertion
failures when cloning channels-last tensors.

Changes:
- Default memory_format to torch.preserve_format instead of torch.contiguous_format
- When preserve_format, derive dim_order from input tensor's dim_order()
- Simplify type annotation: dim_order is always assigned, no Optional needed

Tests:
- test_clone_no_kwarg_preserves_channels_last_dim_order: core repro case
- test_clone_contiguous_format_kwarg_stays_contiguous: regression guard
- test_to_copy_no_kwarg_preserves_channels_last_dim_order: _to_copy path
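
For reference, a sketch of what the core repro test might look like, adapted from the standalone reproduction above; this is not the literal test code from exir/tests/test_passes.py:

import unittest
import torch
from torch.export import export
from executorch.exir import to_edge, EdgeCompileConfig

class TestClonePreservesDimOrder(unittest.TestCase):
    def test_clone_no_kwarg_preserves_channels_last_dim_order(self):
        class ConvClone(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

            def forward(self, x):
                return self.conv(x).clone()

        model = ConvClone().to(memory_format=torch.channels_last)
        x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
        edge = to_edge(
            export(model, (x,)),
            compile_config=EdgeCompileConfig(_skip_dim_order=False),
        )
        # The rewritten clone node should mirror the channels-last input.
        for node in edge.exported_program().graph_module.graph.nodes:
            if "_clone_dim_order" in str(node.target):
                self.assertEqual(
                    tuple(node.meta["val"].dim_order()), (0, 2, 3, 1)
                )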
@pytorch-bot

pytorch-bot bot commented Feb 21, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17611

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 8 Awaiting Approval

As of commit 990608e with merge base 9a58ce8:

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Feb 21, 2026
@nefainl
Author

nefainl commented Feb 21, 2026

@pytorchbot label "release notes: exir"

@pytorch-bot pytorch-bot bot added the release notes: exir label Feb 21, 2026
@nefainl
Author

nefainl commented Feb 21, 2026

Request to the reviewers: could you please add @GregoryComer as a reviewer? He gave very helpful advice on the other PR, and it makes sense to include him here as well so he can weigh in if needed. Note that some of the remaining changes will be handled in a follow-up PR, where his review would also be valuable.
