
[PyTorch] Fix bug with PR 2677#2819

Merged
sudhakarsingh27 merged 37 commits into NVIDIA:main from sudhakarsingh27:fix_return_stats_max_cudnn
Apr 2, 2026

Conversation

@sudhakarsingh27
Collaborator

Description

PR #2677 did not run L1 tests, and the downstream CIs failed because of a bug where padding tokens were not ignored when returning max_logit. This PR fixes that bug.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • When THD and return_max_logit=True, ignore the padding tokens correctly before calculating max_logit

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 and others added 30 commits February 12, 2026 13:12
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…27/TransformerEngine into fix_return_stats_max_cudnn
…27/TransformerEngine into fix_return_stats_max_cudnn
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…27/TransformerEngine into fix_return_stats_max_cudnn

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…27/TransformerEngine into fix_return_stats_max_cudnn
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

@greptile-apps
Contributor

greptile-apps bot commented Mar 31, 2026

Greptile Summary

This PR fixes a bug introduced in #2677 where padding tokens were not excluded from the max_logit computation when using THD format with newer cuDNN (>9.6) runtimes on non-sm120 hardware. In that configuration max_tensor has shape [tq, h, 1] — a flat layout where all sequences are concatenated — rather than the batched [b, h, sq, 1] shape handled by the pre-existing code path.

Key changes:

  • The outer condition is widened from qkv_format == "thd" and max_tensor.ndim == 4 to qkv_format == "thd" with two nested branches, preserving the existing ndim == 4 logic unchanged.
  • A new ndim == 3 branch is added: when cu_seqlens_q_padded is not None (padding between sequences exists), it builds a boolean validity mask via torch.repeat_interleave that marks each padding token as invalid, then fills those positions with -inf before torch.amax.
  • The cu_seqlens_q_padded is None case (no inter-sequence padding) correctly falls through without masking — there are no padding tokens to exclude.
  • Both actual_seqlens and padded_seqlens are properly moved to max_tensor.device, addressing the device-mismatch concern raised in the previous review.
  • The previous AssertionError on cu_seqlens_q_padded=None is replaced by a safe no-op guard.
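The masking described in the ndim == 3 bullet can be sketched as follows. This is a minimal illustration, not the actual TransformerEngine code: the helper name `mask_padded_max_logit` is hypothetical, while the tensor names (`max_tensor`, `cu_seqlens_q_padded`) follow the PR's terminology.

```python
import torch

def mask_padded_max_logit(max_tensor, cu_seqlens_q, cu_seqlens_q_padded):
    """Sketch: fill padding rows of a [tq, h, 1] max tensor with -inf.

    cu_seqlens_q / cu_seqlens_q_padded are cumulative sequence lengths
    without and with inter-sequence padding, as in THD format.
    """
    device = max_tensor.device
    actual_seqlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device)
    padded_seqlens = (cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]).to(device)
    pad_lens = padded_seqlens - actual_seqlens
    # Interleave valid/invalid run lengths per sequence: [a0, p0, a1, p1, ...]
    counts = torch.stack([actual_seqlens, pad_lens], dim=1).flatten()
    values = torch.tensor([True, False], device=device).repeat(len(actual_seqlens))
    valid = torch.repeat_interleave(values, counts)  # shape [tq]
    return max_tensor.masked_fill(~valid.view(-1, 1, 1), float("-inf"))

# Two sequences of length 2 and 1, each padded to 3 tokens (tq = 6)
cu_q = torch.tensor([0, 2, 3])
cu_q_pad = torch.tensor([0, 3, 6])
m = torch.arange(6.0).view(6, 1, 1)
masked = mask_padded_max_logit(m, cu_q, cu_q_pad)
# Rows 2, 4, and 5 (padding) are now -inf, so amax ignores them.
```

Zero-length pads produce zero-length runs, which `repeat_interleave` handles naturally; when `cu_seqlens_q_padded` is None there is nothing to mask and the code falls through, as the summary notes.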

Confidence Score: 5/5

Safe to merge — the fix is logically correct, both prior review concerns are resolved, and the only remaining note is a cosmetic no-op.

The ndim == 3 mask logic is correct: counts sum to tq (verified by construction from cu_seqlens_q_padded), repeat_interleave correctly handles zero-length pads, device placement is consistent, and the cu_seqlens_q_padded is None guard avoids the crash reported in the prior review. No P0/P1 issues remain; only a trivial redundant .to() call at P2.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/cpp_extensions/fused_attn.py Extends the padding-mask logic for return_max_logit=True in THD format to cover the newer cuDNN (>9.6) non-sm120 case where max_tensor is [tq, h, 1]; correctly guards on cu_seqlens_q_padded is not None, properly moves tensors to max_tensor.device, and uses repeat_interleave to build an accurate valid-token mask.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[return_max_logit=True] --> B{qkv_format == 'thd'?}
    B -- No --> G[amax over all dims → max_logit]
    B -- Yes --> C{max_tensor.ndim}
    C -- 4\ncuDNN ≤9.6 or sm120\nb x h x sq x 1 --> D[Build sq_idx mask from cu_seqlens_q\nmasked_fill padding positions with -inf]
    C -- 3\ncuDNN >9.6 non-sm120\ntq x h x 1 --> E{cu_seqlens_q_padded\nis not None?}
    E -- Yes\npad between seqs --> F[Compute actual & padded seqlens\nrepeat_interleave → valid mask\nmasked_fill padding with -inf]
    E -- No\nno padding --> G
    D --> G
    F --> G

Reviews (3): last reviewed commit "fixes from feedback"

sudhakarsingh27 and others added 4 commits March 31, 2026 14:56
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

elif max_tensor.ndim == 3:
    if cu_seqlens_q_padded is not None:
        # For THD on newer cuDNN runtimes (non-sm120), Max is [tq, h, 1] with
        # padded positions containing junk. Mask them out with -inf.
Collaborator

Maybe point out that this is for THD + pad_between_seqs=True + non-sm120 + cuDNN>9.6 situations?

# Expand: T×3, F×1, T×3, F×1, T×2, F×2, T×7, F×1 → TTTF|TTTF|TTFF|TTTTTTTF
valid = torch.repeat_interleave(values, counts)
# Finally, replace invalid (F) positions with -inf
max_tensor = max_tensor.masked_fill(~valid.view(-1, 1, 1), float("-inf"))
Collaborator

I feel the logic (both in ndim=3 branch and ndim=4) is a bit convoluted - it would generate quite a few small kernels. I don't know if there's a more elegant way to program this, or simply zeroing out the entire Max at initialization would be more efficient performance-wise. I'm going to approve for the functionality given that this is also kind of a niche case (THD + return_max_logit=True + pad_between_seqs=True).

Collaborator Author


Do we have control over the stats/max tensors? I thought they're returned by cuDNN.
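The expansion in the quoted `# Expand:` comment can be checked standalone. A minimal sketch: `values` and `counts` are built here by hand to match the comment's run lengths, not taken from real sequence-length tensors.

```python
import torch

# Run lengths from the inline comment:
# T×3, F×1, T×3, F×1, T×2, F×2, T×7, F×1
counts = torch.tensor([3, 1, 3, 1, 2, 2, 7, 1])
values = torch.tensor([True, False]).repeat(4)  # alternating T/F, one pair per sequence
valid = torch.repeat_interleave(values, counts)

pattern = "".join("T" if v else "F" for v in valid.tolist())
print(pattern)  # TTTFTTTFTTFFTTTTTTTF
```

The result matches the comment's expansion `TTTF|TTTF|TTFF|TTTTTTTF`, with the F positions marking the inter-sequence padding to be filled with -inf.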

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 merged commit b048869 into NVIDIA:main Apr 2, 2026
10 of 13 checks passed
KshitijLakhani pushed a commit that referenced this pull request Apr 3, 2026
* cudnn now returns Stats always and Max only with `return_max_logit=true`

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix a typo that caused a bug

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* update doc strings

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix more docs

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fixes from the feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* update cudnn-frontend to v1.19.1

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* update the cudnn frontend

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix a wrong omission

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bugfix: mask out padding tokens when THD

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes from greptile feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor nit

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fixes from feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

3 participants