diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
index 7653296c78..06bfb6ef3c 100644
--- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py
+++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -363,13 +363,38 @@ def fused_attn_fwd(
 
         max_tensor = output_tensors[2]
         amax_dims = (0, 2) if max_tensor.ndim == 3 else (0, 2, 3)
-        if qkv_format == "thd" and max_tensor.ndim == 4:
-            # For THD on older cuDNN runtimes or THD on sm120, stats can be [b, h, sq, 1] with padded
-            # sequence positions. Exclude those padded positions when computing max_logit.
-            seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device=max_tensor.device)
-            sq_idx = torch.arange(max_tensor.shape[2], device=max_tensor.device).view(1, 1, -1, 1)
-            valid = sq_idx < seqlens_q.view(-1, 1, 1, 1)
-            max_tensor = max_tensor.masked_fill(~valid, float("-inf"))
+        if qkv_format == "thd":
+            if max_tensor.ndim == 4:
+                # For THD on cuDNN <= 9.6 or THD on sm120, Max tensor can be [b, h, sq, 1]
+                # with padded sequence positions. Exclude those padded positions when computing max_logit.
+                seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device=max_tensor.device)
+                sq_idx = torch.arange(max_tensor.shape[2], device=max_tensor.device).view(
+                    1, 1, -1, 1
+                )
+                valid = sq_idx < seqlens_q.view(-1, 1, 1, 1)
+                max_tensor = max_tensor.masked_fill(~valid, float("-inf"))
+            elif max_tensor.ndim == 3:
+                if cu_seqlens_q_padded is not None:
+                    # For THD + pad_between_seqs=True + non-sm120 + cuDNN>9.6, Max tensor is [tq, h, 1]
+                    # and padding positions could be uninitialized. Exclude those padded positions when
+                    # computing max_logit.
+                    actual_seqlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(
+                        device=max_tensor.device
+                    )
+                    padded_seqlens = (cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]).to(
+                        device=max_tensor.device
+                    )
+                    pad_lens = (padded_seqlens - actual_seqlens).to(device=max_tensor.device)
+                    b = pad_lens.shape[0]
+
+                    # Stack [actual, pad] per batch into counts: e.g. [3,1, 3,1, 2,2, 7,1]
+                    counts = torch.stack([actual_seqlens, pad_lens], dim=1).flatten()
+                    # Tile [T, F] per sequence: [T,F, T,F, T,F, T,F]
+                    values = torch.tensor([True, False], device=max_tensor.device).repeat(b)
+                    # Expand: T×3, F×1, T×3, F×1, T×2, F×2, T×7, F×1 → TTTF|TTTF|TTFF|TTTTTTTF
+                    valid = torch.repeat_interleave(values, counts)
+                    # Finally, replace invalid (F) positions with -inf
+                    max_tensor = max_tensor.masked_fill(~valid.view(-1, 1, 1), float("-inf"))
         # Max -> max_logit [h]
         max_logit = torch.amax(max_tensor, dim=amax_dims).to(dtype=output_tensors[0].dtype)
 