Commits
22 commits
3fb19fc
cudnn now returns Stats always and Max only with `return_max_logit=true`
sudhakarsingh27 Feb 12, 2026
5b40701
Merge branch 'main' of github.com:NVIDIA/TransformerEngine
sudhakarsingh27 Feb 12, 2026
5d479ad
fix a typo that caused a bug
sudhakarsingh27 Feb 12, 2026
296fb9f
update doc strings
sudhakarsingh27 Feb 12, 2026
24bfd45
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2026
fd42feb
fix more docs
sudhakarsingh27 Feb 13, 2026
2d7b51b
Merge branch 'fix_return_stats_max_cudnn' of github.com:sudhakarsingh…
sudhakarsingh27 Feb 13, 2026
260380b
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 13, 2026
7a5ab35
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 17, 2026
9710810
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 18, 2026
07db752
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 19, 2026
f8b1a68
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Feb 20, 2026
b5b2b9d
Merge branch 'fix_return_stats_max_cudnn' of github.com:sudhakarsingh…
sudhakarsingh27 Feb 20, 2026
8f40cab
fixes from the feedback
sudhakarsingh27 Feb 20, 2026
56e46fd
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Feb 20, 2026
1102738
Merge branch 'main' into fix_return_stats_max_cudnn
sudhakarsingh27 Feb 23, 2026
7363541
merge main
sudhakarsingh27 Feb 26, 2026
8c0d6a1
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Mar 10, 2026
d517a13
update cudnn-frontend to v1.19.1
sudhakarsingh27 Mar 10, 2026
e005455
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Mar 10, 2026
3ae0a34
update the cudnn frontend
sudhakarsingh27 Mar 12, 2026
ef0d7ec
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into fix_r…
sudhakarsingh27 Mar 12, 2026
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 198 files
@@ -104,7 +104,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
}

const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
bool generate_stats = !return_max_logit;
bool generate_stats = true; // Always return stats
try {
FADescriptor_v1 descriptor{
b,
@@ -335,7 +335,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options.set_sink_token(softmax_offset);
}

std::shared_ptr<fe::graph::Tensor_attributes> Max, Sum_Exp;
std::shared_ptr<fe::graph::Tensor_attributes> Max;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -349,19 +349,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_name("Max")
.set_dim({b, h, s_q, 1})
.set_data_type(fe::DataType_t::FLOAT));
Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Sum_Exp")
.set_dim({b, h, s_q, 1})
.set_data_type(fe::DataType_t::FLOAT));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Max->set_stride({h * s_q, s_q, 1, 1});
Sum_Exp->set_stride({h * s_q, s_q, 1, 1});
}
sdpa_options.set_logit_max(Max);
sdpa_options.set_score_sum_exp(Sum_Exp);
}

auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options));
@@ -379,13 +372,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
O->set_ragged_offset(offset_o);
}

if (!return_max_logit) {
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
if (is_ragged_q && cudnn_runtime_version >= 90600) {
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Stats->set_stride({h * s_q, s_q, 1, 1});
}
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
if (is_ragged_q && cudnn_runtime_version >= 90600) {
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Stats->set_stride({h * s_q, s_q, 1, 1});
}

std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
@@ -395,7 +386,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>> // O
key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
auto Stats_tuple =
generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp);
return_max_logit ? std::make_tuple(Stats, Max) : std::make_tuple(Stats, nullptr);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto softmax_offset_tuple =
is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
@@ -1125,6 +1116,16 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();

Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;

if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Max->data.dptr = nullptr;
@@ -1134,23 +1135,6 @@ void fused_attn_arbitrary_seqlen_fwd(
output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_Max->data.dtype = DType::kFloat32;
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Sum_Exp->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_Sum_Exp->data.dtype = DType::kFloat32;
} else {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
}

Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
@@ -1174,14 +1158,12 @@ void fused_attn_arbitrary_seqlen_fwd(

Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS1 = output_S->data.dptr;

if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS1 = output_Max->data.dptr;
Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS2 = output_Sum_Exp->data.dptr;
} else {
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS1 = output_S->data.dptr;
devPtrS2 = output_Max->data.dptr;
}
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
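For orientation, the auxiliary tensor pack set up in `fused_attn_arbitrary_seqlen_fwd` now always leads with `Stats`, appends `Max` only when `return_max_logit=true`, and then carries the `rng_state` (plus the optional bias and softmax-offset tensors). A minimal sketch of the resulting shapes, with a hypothetical helper name and illustrative dimensions:

```python
# Sketch only: mirrors the shape logic in fused_attn_arbitrary_seqlen_fwd above;
# the helper name and the example dimensions are hypothetical.
def expected_aux_shapes(q_format, batch, heads, max_seqlen_q, num_tokens_q,
                        return_max_logit, cudnn_version=90600):
    if q_format == "thd" and cudnn_version >= 90600:
        stats_shape = (num_tokens_q, heads, 1)          # ragged (THD) layout
    else:
        stats_shape = (batch, heads, max_seqlen_q, 1)   # bshd/sbhd layout
    shapes = [("Stats", stats_shape)]                   # always present now
    if return_max_logit:
        shapes.append(("Max", stats_shape))             # same shape as Stats
    shapes.append(("rng_state", (2,)))                  # rng_state [2], per attention.cpp comment
    return shapes

print(expected_aux_shapes("bshd", 2, 16, 512, 1024, return_max_logit=True))
# [('Stats', (2, 16, 512, 1)), ('Max', (2, 16, 512, 1)), ('rng_state', (2,))]
```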
6 changes: 3 additions & 3 deletions transformer_engine/common/fused_attn/utils.h
@@ -118,23 +118,23 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t o_tensor_type;
cudnn_frontend::DataType_t do_tensor_type;
cudnn_frontend::DataType_t dqkv_tensor_type;
bool generate_max_sum_exp;
bool return_max_logit;

bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq,
bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type,
dqkv_tensor_type, generate_max_sum_exp) <
dqkv_tensor_type, return_max_logit) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right,
rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type,
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
rhs.dqkv_tensor_type, rhs.return_max_logit);
}
};

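The renamed `return_max_logit` field stays inside the `std::tie` comparison, presumably so that descriptors differing only in this flag compare as distinct when the descriptor is used as an ordered key (for example, in a cached graph lookup). A rough Python analogy, with a hypothetical cache and a reduced key, just to show the effect:

```python
# Rough analogy only: a cache keyed on (a subset of) the fields that take
# part in FADescriptor_v1::operator<, including return_max_logit.
graph_cache = {}

def get_graph(b, h, s_q, s_kv, return_max_logit):
    key = (b, h, s_q, s_kv, return_max_logit)   # hypothetical, reduced key
    if key not in graph_cache:
        graph_cache[key] = object()              # stands in for a built graph
    return graph_cache[key]

# Configs that differ only in return_max_logit get separate cache entries.
assert get_graph(2, 16, 512, 512, False) is not get_graph(2, 16, 512, 512, True)
```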
@@ -206,7 +206,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] return_max_logit Whether to produce Max along with Stats.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] deterministic Whether determinism is required or not.
*/
@@ -269,7 +269,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] return_max_logit Whether to produce Max along with Stats.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -353,14 +353,14 @@ def fused_attn_fwd(

if return_max_logit:
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
# thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
stats = output_tensors[1] + torch.log(output_tensors[2])
# thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Collaborator:
Do we need the "there's no typo here" :)

Collaborator Author:
I deliberately added it because I didn't believe it and checked the shapes myself :P


aux_ctx_tensors = [output_tensors[1]] # "Stats"
amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors.extend(output_tensors[3:])
return output_tensors[0], aux_ctx_tensors, max_logit

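The shape bookkeeping behind the new reduction is worth spelling out: `Max` comes back as `[tq, h, 1]` for thd and `[b, h, sq, 1]` for bshd/sbhd, so reducing over every axis except the head axis yields the per-head `max_logit [h]`. A quick sketch with random data, assuming only that `torch` is available, to confirm the dims:

```python
import torch

b, h, sq, tq = 2, 16, 512, 1024

# bshd/sbhd: Max is [b, h, sq, 1] -> amax over (0, 2, 3) gives [h]
max_bshd = torch.randn(b, h, sq, 1)
print(torch.amax(max_bshd, dim=(0, 2, 3)).shape)  # torch.Size([16])

# thd: Max is [tq, h, 1] -> amax over (0, 2) gives [h]
max_thd = torch.randn(tq, h, 1)
print(torch.amax(max_thd, dim=(0, 2)).shape)      # torch.Size([16])
```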
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -259,16 +259,16 @@ std::vector<py::object> fused_attn_fwd(
// f16_max512 : S [b, h, sq, skv]
// f16_arbitrary:
// return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// return_max_logit=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// return_max_logit=true: S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
size_t i = 0;
at::Tensor output_tensor;
// intermediate softmax tensor, S or M
// intermediate softmax tensor, S or M (for fp8)
output_tensor =
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
set_tensor_param(i++, output_tensor);
// fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor
// fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor
if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
output_tensor =
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),