Merged
37 commits
3fb19fc
cudnn now returns Stats always and Max only with `return_max_logit=true`
sudhakarsingh27 Feb 12, 2026
5b40701
Merge branch 'main' of github.com:NVIDIA/TransformerEngine
sudhakarsingh27 Feb 12, 2026
5d479ad
fix a typo that caused a bug
sudhakarsingh27 Feb 12, 2026
296fb9f
update doc strings
sudhakarsingh27 Feb 12, 2026
24bfd45
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2026
fd42feb
fix more docs
sudhakarsingh27 Feb 13, 2026
2d7b51b
Merge branch 'fix_return_stats_max_cudnn' of github.com:sudhakarsingh…
sudhakarsingh27 Feb 13, 2026
260380b
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 13, 2026
7a5ab35
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 17, 2026
9710810
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 18, 2026
07db752
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 19, 2026
f8b1a68
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Feb 20, 2026
b5b2b9d
Merge branch 'fix_return_stats_max_cudnn' of github.com:sudhakarsingh…
sudhakarsingh27 Feb 20, 2026
8f40cab
fixes from the feedback
sudhakarsingh27 Feb 20, 2026
56e46fd
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Feb 20, 2026
1102738
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 23, 2026
7363541
merge main
sudhakarsingh27 Feb 26, 2026
8c0d6a1
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Mar 10, 2026
d517a13
update cudnn-frontend to v1.19.1
sudhakarsingh27 Mar 10, 2026
e005455
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Mar 10, 2026
3ae0a34
update the cudnn frontend
sudhakarsingh27 Mar 12, 2026
ef0d7ec
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Mar 12, 2026
c7bc0a3
resolve conflicts
sudhakarsingh27 Mar 23, 2026
116e24a
Merge branch 'fix_return_stats_max_cudnn' of github.com:sudhakarsingh…
sudhakarsingh27 Mar 23, 2026
baa96ff
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Mar 24, 2026
7122901
fix a wrong omission
sudhakarsingh27 Mar 24, 2026
aa9b311
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
799c6d9
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Mar 24, 2026
599d68f
Merge branch 'fix_return_stats_max_cudnn' of github.com:sudhakarsingh…
sudhakarsingh27 Mar 24, 2026
1696f7f
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Mar 31, 2026
2c52eb3
bugfix: mask out padding tokens when THD
sudhakarsingh27 Mar 31, 2026
58b8e82
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2026
6c6aec4
fixes from greptile feedback
sudhakarsingh27 Mar 31, 2026
0deaae2
merge from upstream pull
sudhakarsingh27 Mar 31, 2026
757c2de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2026
dc7968e
minor nit
sudhakarsingh27 Mar 31, 2026
7ef541c
fixes from feedback
sudhakarsingh27 Mar 31, 2026
39 changes: 32 additions & 7 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -363,13 +363,38 @@ def fused_attn_fwd(
         max_tensor = output_tensors[2]
         amax_dims = (0, 2) if max_tensor.ndim == 3 else (0, 2, 3)

-        if qkv_format == "thd" and max_tensor.ndim == 4:
-            # For THD on older cuDNN runtimes or THD on sm120, stats can be [b, h, sq, 1] with padded
-            # sequence positions. Exclude those padded positions when computing max_logit.
-            seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device=max_tensor.device)
-            sq_idx = torch.arange(max_tensor.shape[2], device=max_tensor.device).view(1, 1, -1, 1)
-            valid = sq_idx < seqlens_q.view(-1, 1, 1, 1)
-            max_tensor = max_tensor.masked_fill(~valid, float("-inf"))
+        if qkv_format == "thd":
+            if max_tensor.ndim == 4:
+                # For THD on cuDNN <= 9.6 or THD on sm120, Max tensor can be [b, h, sq, 1]
+                # with padded sequence positions. Exclude those padded positions when computing max_logit.
+                seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device=max_tensor.device)
+                sq_idx = torch.arange(max_tensor.shape[2], device=max_tensor.device).view(
+                    1, 1, -1, 1
+                )
+                valid = sq_idx < seqlens_q.view(-1, 1, 1, 1)
+                max_tensor = max_tensor.masked_fill(~valid, float("-inf"))
+            elif max_tensor.ndim == 3:
+                if cu_seqlens_q_padded is not None:
+                    # For THD + pad_between_seqs=True + non-sm120 + cuDNN > 9.6, Max tensor is [tq, h, 1]
+                    # and padding positions could be uninitialized. Exclude those padded positions when
+                    # computing max_logit.
+                    actual_seqlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(
+                        device=max_tensor.device
+                    )
+                    padded_seqlens = (cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]).to(
+                        device=max_tensor.device
+                    )
+                    pad_lens = (padded_seqlens - actual_seqlens).to(device=max_tensor.device)
+                    b = pad_lens.shape[0]
+
+                    # Stack [actual, pad] per batch into counts: e.g. [3,1, 3,1, 2,2, 7,1]
+                    counts = torch.stack([actual_seqlens, pad_lens], dim=1).flatten()
+                    # Tile [T, F] per sequence: [T,F, T,F, T,F, T,F]
+                    values = torch.tensor([True, False], device=max_tensor.device).repeat(b)
+                    # Expand: T×3, F×1, T×3, F×1, T×2, F×2, T×7, F×1 → TTTF|TTTF|TTFF|TTTTTTTF
+                    valid = torch.repeat_interleave(values, counts)
+                    # Finally, replace invalid (F) positions with -inf
+                    max_tensor = max_tensor.masked_fill(~valid.view(-1, 1, 1), float("-inf"))
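The interleaved-mask construction in the added `ndim == 3` branch can be checked in isolation. A minimal sketch with hypothetical sequence lengths (these values are illustrative, not taken from the PR):

```python
import torch

# Hypothetical per-sequence token counts: actual tokens vs. padded slot counts.
actual_seqlens = torch.tensor([3, 3, 2, 7])
padded_seqlens = torch.tensor([4, 4, 4, 8])
pad_lens = padded_seqlens - actual_seqlens

b = pad_lens.shape[0]
# Interleave [actual, pad] per sequence into counts: [3,1, 3,1, 2,2, 7,1]
counts = torch.stack([actual_seqlens, pad_lens], dim=1).flatten()
# Tile [True, False] once per sequence: [T,F, T,F, T,F, T,F]
values = torch.tensor([True, False]).repeat(b)
# Expand each flag by its count: T*3 F*1 | T*3 F*1 | T*2 F*2 | T*7 F*1
valid = torch.repeat_interleave(values, counts)

print(valid.int().tolist())
# → [1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0]
```

The resulting flat mask has one entry per padded token position (`padded_seqlens.sum()` in total), which is exactly the leading `tq` dimension of the `[tq, h, 1]` Max tensor the branch masks.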
Collaborator
I feel the logic (both in the ndim=3 branch and ndim=4) is a bit convoluted - it would generate quite a few small kernels. I don't know if there's a more elegant way to program this, or whether simply zeroing out the entire Max at initialization would be more efficient performance-wise. I'm going to approve for the functionality given that this is also kind of a niche case (THD + return_max_logit=True + pad_between_seqs=True).

Collaborator Author
Do we have control over the stats/max tensors? I thought they're returned by cuDNN.


# Max -> max_logit [h]
max_logit = torch.amax(max_tensor, dim=amax_dims).to(dtype=output_tensors[0].dtype)