Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
fa2e2cb
Enable sm120 support for fused attn if cuDNN is 9.18.1+
KshitijLakhani Feb 20, 2026
bea8bbb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2026
b2f5864
Force intermediate tensors such as S, Sum_Exp, and Max to be BHS1 sha…
KshitijLakhani Mar 2, 2026
076420d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2026
8753fc2
Add support for sm120 correct batch, seq dims
KshitijLakhani Mar 4, 2026
0336d2a
Add support for sm120 BHS1 style max logit even QKV are THD to avoid …
KshitijLakhani Mar 11, 2026
d24eb35
Disable fused and flash attn for sm120 filter:kv cache
KshitijLakhani Mar 11, 2026
e2e89d4
For CP P2P attn, set softmax_lse_in_packed_format to False if sm120+
KshitijLakhani Mar 11, 2026
5a8ecb9
Assert in TE if T3HD/TH3D layout is used on sm120 before cuDNN F16 sd…
KshitijLakhani Mar 11, 2026
7ca7564
Modify is_ragged_q && cudnn_runtime_version >= 90600 check to also in…
KshitijLakhani Mar 11, 2026
3227016
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
46c6e60
nit: Code clean up
KshitijLakhani Mar 11, 2026
b75348c
Disable fused attn for T3HD and TH3D
KshitijLakhani Mar 12, 2026
75ef6d9
Merge branch 'main' into klakhani/maint/sm120-thd-flash-support
KshitijLakhani Mar 12, 2026
bcfef90
nit: Add missed sm120 guard
KshitijLakhani Mar 13, 2026
577a352
Modify sm120 condition to be very specific to sm120 and not generaliz…
KshitijLakhani Mar 19, 2026
9c0a56b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
0d2af20
nit: Fix missing sm120 check in fwd
KshitijLakhani Mar 19, 2026
fa59e0b
Move the check for sm120 T3HD/TH3D to nvte_get_fused_attn_backend() i…
KshitijLakhani Mar 20, 2026
cb17eb3
nit: Check for matching sm120 and not sm120+
KshitijLakhani Mar 20, 2026
22f020c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,23 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
"Please upgrade your cuDNN version if possible."
<< std::endl;
}
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen && sm_arch_ == 120) {
if (cudnn_runtime_version < 91801) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: Given combination of sm_arch_ == 120 and cudnn_runtime_version < "
"91801 is not supported. "
<< " Please upgrade your cuDNN version if possible." << std::endl;
} else {
// Known missing support for T3HD/TH3D layouts on SM120
const bool is_t3hd_or_th3d =
(qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D);
if (is_t3hd_or_th3d) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: Given combination of T3HD/TH3D layouts on SM120 is not supported. "
<< " Please consider using other THD layouts if possible." << std::endl;
}
}
}
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120;

NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
Expand All @@ -96,11 +99,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t actual_b = b;
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
// On SM 120, cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3]
// as interleaved and rejects them. Use BHSD-like dimensions/strides with max_seqlen at plan build
// so the check passes; ragged offset still provides variable-length boundaries.
if (sm_arch_ != 120) {
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
}

const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
Expand Down Expand Up @@ -336,7 +344,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
}

std::shared_ptr<fe::graph::Tensor_attributes> Max, Sum_Exp;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
Expand All @@ -353,7 +361,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_name("Sum_Exp")
.set_dim({b, h, s_q, 1})
.set_data_type(fe::DataType_t::FLOAT));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Expand Down Expand Up @@ -381,7 +389,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(

if (!return_max_logit) {
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Stats->set_stride({h * s_q, s_q, 1, 1});
Expand All @@ -407,9 +415,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr);
auto offset_kv_tuple =
is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr);
auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto offset_s_tuple =
use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
: std::make_tuple(nullptr, nullptr);

Expand Down Expand Up @@ -443,7 +450,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset;
Expand Down Expand Up @@ -510,7 +517,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
}
void *devOffsetsS = nullptr;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
devOffsetsS = static_cast<int8_t *>(devOffsets) +
(static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
num_bytes_per_ragged_offset;
Expand All @@ -529,7 +536,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
}
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
variant_pack[offset_stats] = devOffsetsS;
}
}
Expand Down Expand Up @@ -587,6 +594,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120;

NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
Expand All @@ -598,13 +606,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t actual_b = b;
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
// On SM 120, cuDNN support check requires BHSD-like strides with max_seqlen (see fwd).
if (sm_arch_ != 120) {
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
}

// We choose between 32-bit and 64-bit offsets depending on need.
// This allows us to support older cuDNN runtimes gracefully.
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
Expand Down Expand Up @@ -765,7 +775,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_data_type(fe::DataType_t::FLOAT));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
Expand All @@ -791,10 +801,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);

if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
sdpa_backward_options.set_max_total_seq_len_q(s_q);
}
if (is_ragged_kv && cudnn_runtime_version >= 90600) {
if (is_ragged_kv && cudnn_runtime_version >= 90600 && sm_arch_ != 120) {
sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
}

Expand Down Expand Up @@ -914,9 +924,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr);
auto offset_kv_tuple =
is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr);
auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto offset_s_tuple =
use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
: std::make_tuple(nullptr, nullptr);

Expand Down Expand Up @@ -949,7 +958,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset;
Expand Down Expand Up @@ -1019,7 +1028,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
}
void *devOffsetsS = nullptr;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
devOffsetsS = static_cast<int8_t *>(devOffsets) +
(static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
num_bytes_per_ragged_offset;
Expand All @@ -1038,7 +1047,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
}
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
variant_pack[offset_stats] = devOffsetsS;
}
}
Expand Down Expand Up @@ -1102,6 +1111,9 @@ void fused_attn_arbitrary_seqlen_fwd(
devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
}

const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);

void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
Expand All @@ -1128,15 +1140,17 @@ void fused_attn_arbitrary_seqlen_fwd(
if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Max->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
(sm_arch_ != 120)) {
output_Max->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_Max->data.dtype = DType::kFloat32;
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Sum_Exp->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
(sm_arch_ != 120)) {
output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
Expand All @@ -1145,7 +1159,8 @@ void fused_attn_arbitrary_seqlen_fwd(
} else {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
(sm_arch_ != 120)) {
output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,11 @@ def forward(
softmax_lse_in_packed_format = False
if qkv_format == "thd":
if use_fused_attention:
softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
softmax_lse_in_packed_format = get_cudnn_version() >= (
9,
6,
0,
) and get_device_compute_capability() != (12, 0)
else:
softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3

Expand Down
33 changes: 23 additions & 10 deletions transformer_engine/pytorch/attention/dot_product_attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,11 +554,15 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None:
# Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version
# until the cuDNN bug is resolved
if device_compute_capability == (8, 9):
logger.debug("Disabling FusedAttention for KV caching for sm89")
# Temporarily disabling fused attention for kv caching for sm89/sm120 irrespective of
# cuDNN version until the cuDNN bug is resolved.
if device_compute_capability in ((8, 9), (12, 0)):
logger.debug("Disabling FusedAttention for KV caching for sm89/sm120")
use_fused_attention = False
# Temporarily disable FlashAttention for KV caching on sm120
if device_compute_capability == (12, 0):
logger.debug("Disabling FlashAttention for KV caching for sm120")
use_flash_attention = False
if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism")
use_flash_attention = False
Expand Down Expand Up @@ -691,12 +695,21 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
)
use_flash_attention = False
if device_compute_capability == (12, 0):
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as qkv_format = thd is"
" not supported for compute capability = sm120"
)
use_fused_attention = False
if cudnn_version < (9, 18, 1):
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as qkv_format = thd is"
" not supported for compute capability = sm120 and cuDNN version < 9.18.1"
)
use_fused_attention = False
elif qkv_layout in {"t3hd", "th3d"}:
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as qkv_layout = %s is not supported for"
" compute capability = sm120",
qkv_layout,
)
use_fused_attention = False

# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention_3:
Expand Down
19 changes: 14 additions & 5 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,22 @@ def fused_attn_fwd(

if return_max_logit:
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
# thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# thd (newer cuDNN runtimes, non-sm120): output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
# thd (older cuDNN runtimes or sm120): output_tensors: out [tq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
stats = output_tensors[1] + torch.log(output_tensors[2])
amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
max_tensor = output_tensors[1]
if qkv_format == "thd" and max_tensor.ndim == 4:
# For THD on older cuDNN runtimes or THD on sm120, stats can be [b, h, sq, 1] with padded
# sequence positions. Exclude those padded positions when computing max_logit.
seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device=max_tensor.device)
sq_idx = torch.arange(max_tensor.shape[2], device=max_tensor.device).view(1, 1, -1, 1)
valid = sq_idx < seqlens_q.view(-1, 1, 1, 1)
max_tensor = max_tensor.masked_fill(~valid, float("-inf"))
amax_dims = (0, 2) if max_tensor.ndim == 3 else (0, 2, 3)
# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
max_logit = torch.amax(max_tensor, dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
aux_ctx_tensors.extend(output_tensors[3:])
return output_tensors[0], aux_ctx_tensors, max_logit
Expand Down