diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 34bb729b25..bcacb2f801 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -88,7 +88,6 @@ Tensor make_bf16_operand(const std::string& name, const std::vector& sha return t; } - // Creates an MXFP8 operand with the correct data layout for GEMM. // MXFP8 GEMM requirements (scales are along K dimension): // A transposed -> needs rowwise data/scales @@ -175,8 +174,8 @@ std::vector> make_shapes(ShapeCase scase) { } void run_grouped_gemm_case(const TestParams& params) { -#if CUBLAS_VERSION < 130200 - GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " +#if CUBLAS_VERSION < 130300 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else if (getDeviceComputeCapability() < blackwellComputeCapability) { @@ -349,7 +348,365 @@ void run_grouped_gemm_case(const TestParams& params) { atol, rtol); } -#endif // CUBLAS_VERSION >= 130200 +#endif // CUBLAS_VERSION >= 130300 +} + +void run_grouped_gemm_discrete_out_case(const TestParams& params) { +#if CUBLAS_VERSION < 130300 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " + << CUBLAS_VERSION << "."; +#else + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; + } + + const std::vector> shapes = make_shapes(params.shape_case); + + const size_t num_gemms = shapes.size(); + std::vector A_tensors; + std::vector B_tensors; + std::vector D_multi; + + A_tensors.reserve(num_gemms); + B_tensors.reserve(num_gemms); + D_multi.reserve(num_gemms); + + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + const std::vector a_shape = params.transa ? std::vector{N, K} + : std::vector{K, N}; + const std::vector b_shape = params.transb ? 
std::vector{K, M} + : std::vector{M, K}; + switch (params.input_case) { + case InputCase::kFP8Current: { + A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kBF16: { + A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kMXFP8: { + A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, + /*is_A=*/true, params.transa)); + B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, + /*is_A=*/false, params.transb)); + break; + } + } + D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + + std::vector A_ptrs(num_gemms); + std::vector B_ptrs(num_gemms); + std::vector D_ptrs(num_gemms); + std::vector workspaces(num_gemms); + std::vector workspace_ptrs(num_gemms, nullptr); + std::vector A_views; + std::vector B_views; + A_views.reserve(num_gemms); + B_views.reserve(num_gemms); + + // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; + + for (size_t i = 0; i < num_gemms; ++i) { + A_ptrs[i] = A_tensors[i].data(); + B_ptrs[i] = B_tensors[i].data(); + D_ptrs[i] = D_multi[i].data(); + workspaces[i] = + Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); + workspace_ptrs[i] = workspaces[i].data(); + A_views.push_back(&A_tensors[i]); + B_views.push_back(&B_tensors[i]); + } + + nvte_multi_tensor_gemm(A_ptrs.data(), + B_ptrs.data(), + D_ptrs.data(), + bias_ptrs.data(), + gelu_ptrs.data(), + static_cast(num_gemms), + params.transa, + params.transb, + false, // grad + workspace_ptrs.data(), + false, // accumulate + false, // 
use_split_accumulator + 0, // sm_count + 0); + + GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); + + std::vector C_tensors; + std::vector D_list_tensors; + C_tensors.reserve(num_gemms); + D_list_tensors.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + (void)K; + if (!params.use_null_c) { + C_tensors.emplace_back( + Tensor("C" + std::to_string(i), std::vector{M, N}, DType::kBFloat16)); + } + D_list_tensors.emplace_back( + Tensor("D_list" + std::to_string(i), std::vector{M, N}, DType::kBFloat16)); + NVTE_CHECK_CUDA(cudaMemset(D_list_tensors.back().rowwise_dptr(), 0, + bytes(D_list_tensors.back().rowwise_shape(), + D_list_tensors.back().dtype()))); + } + + std::vector C_list_ptrs; + std::vector D_list_ptrs; + if (!params.use_null_c) { + C_list_ptrs.reserve(num_gemms); + } + D_list_ptrs.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + if (!params.use_null_c) { + C_list_ptrs.push_back(C_tensors[i].data()); + } + D_list_ptrs.push_back(D_list_tensors[i].data()); + } + + // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) + Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); + std::vector alpha_vals(num_gemms, 1.f); + std::vector beta_vals(num_gemms, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + + const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); + Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + + 
nvte_grouped_gemm_with_discrete_out(grouped_A.get_handle(), + params.transa, + grouped_B.get_handle(), + params.transb, + params.use_null_c ? nullptr : C_list_ptrs.data(), + params.use_null_c ? 0 : num_gemms, + D_list_ptrs.data(), + num_gemms, + alpha_tensor.data(), + beta_tensor.data(), + setup_ws.data(), + cublas_ws.data(), + nullptr, // config (use defaults) + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // Compare results + for (size_t i = 0; i < num_gemms; ++i) { + D_list_tensors[i].to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + compareResults("grouped_list_vs_multi", + D_list_tensors[i], + D_multi[i].rowwise_cpu_dptr(), + true, + atol, + rtol); + } +#endif // CUBLAS_VERSION >= 130300 +} + +void run_grouped_gemm_discrete_in_case(const TestParams& params) { +#if CUBLAS_VERSION < 130300 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " + << CUBLAS_VERSION << "."; +#else + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; + } + + const std::vector> shapes = make_shapes(params.shape_case); + + const size_t num_gemms = shapes.size(); + std::vector A_tensors; + std::vector B_tensors; + std::vector D_multi; + + A_tensors.reserve(num_gemms); + B_tensors.reserve(num_gemms); + D_multi.reserve(num_gemms); + + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + const std::vector a_shape = params.transa ? std::vector{N, K} + : std::vector{K, N}; + const std::vector b_shape = params.transb ? 
std::vector{K, M} + : std::vector{M, K}; + switch (params.input_case) { + case InputCase::kFP8Current: { + A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kBF16: { + A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kMXFP8: { + A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, + /*is_A=*/true, params.transa)); + B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, + /*is_A=*/false, params.transb)); + break; + } + } + D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + + std::vector A_ptrs(num_gemms); + std::vector B_ptrs(num_gemms); + std::vector D_ptrs(num_gemms); + std::vector workspaces(num_gemms); + std::vector workspace_ptrs(num_gemms, nullptr); + std::vector A_views; + std::vector B_views; + A_views.reserve(num_gemms); + B_views.reserve(num_gemms); + + // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; + + for (size_t i = 0; i < num_gemms; ++i) { + A_ptrs[i] = A_tensors[i].data(); + B_ptrs[i] = B_tensors[i].data(); + D_ptrs[i] = D_multi[i].data(); + workspaces[i] = + Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); + workspace_ptrs[i] = workspaces[i].data(); + A_views.push_back(&A_tensors[i]); + B_views.push_back(&B_tensors[i]); + } + + nvte_multi_tensor_gemm(A_ptrs.data(), + B_ptrs.data(), + D_ptrs.data(), + bias_ptrs.data(), + gelu_ptrs.data(), + static_cast(num_gemms), + params.transa, + params.transb, + false, // grad + workspace_ptrs.data(), + false, // accumulate + false, // 
use_split_accumulator + 0, // sm_count + 0); + + GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); + + std::vector C_tensors; + std::vector D_group_tensors; + C_tensors.reserve(num_gemms); + D_group_tensors.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + (void)K; + if (!params.use_null_c) { + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, + bytes(D_group_tensors.back().rowwise_shape(), + D_group_tensors.back().dtype()))); + } + + std::vector C_views, D_views; + for (size_t i = 0; i < num_gemms; ++i) { + if (!params.use_null_c) { + C_views.push_back(&C_tensors[i]); + } + D_views.push_back(&D_group_tensors[i]); + } + + std::optional grouped_C; + if (!params.use_null_c) { + grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + } + GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); + + // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) + Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); + std::vector alpha_vals(num_gemms, 1.f); + std::vector beta_vals(num_gemms, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + + const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); + Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + + std::vector A_list_ptrs; + 
A_list_ptrs.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + A_list_ptrs.push_back(A_tensors[i].data()); + } + + nvte_grouped_gemm_with_discrete_inputA(A_list_ptrs.data(), + num_gemms, + params.transa, + grouped_B.get_handle(), + params.transb, + params.use_null_c ? nullptr : grouped_C->get_handle(), + grouped_D.get_handle(), + alpha_tensor.data(), + beta_tensor.data(), + setup_ws.data(), + cublas_ws.data(), + nullptr, // config (use defaults) + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // Compare results + for (size_t i = 0; i < num_gemms; ++i) { + Tensor grouped_split("grouped_D" + std::to_string(i), + std::vector{static_cast(std::get<0>(shapes[i])), + static_cast(std::get<1>(shapes[i]))}, + D_multi[i].dtype()); + const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), + static_cast(grouped_D.get_data()) + offset_bytes, + grouped_D.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + grouped_split.to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + compareResults("grouped_discrete_in_vs_multi", + grouped_split, + D_multi[i].rowwise_cpu_dptr(), + true, + atol, + rtol); + } +#endif // CUBLAS_VERSION >= 130300 } class GroupedGemmTest : public ::testing::TestWithParam {}; @@ -358,6 +715,14 @@ TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { run_grouped_gemm_case(GetParam()); } +TEST_P(GroupedGemmTest, CompareWithMultiTensorGemmDiscreteOut) { + run_grouped_gemm_discrete_out_case(GetParam()); +} + +TEST_P(GroupedGemmTest, CompareWithMultiTensorGemmDiscreteIn) { + run_grouped_gemm_discrete_in_case(GetParam()); +} + std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { constexpr const char* kInputNames[] = {"FP8Current", "BF16", "MXFP8"}; constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; diff --git a/tests/cpp/operator/test_swizzle.cu 
b/tests/cpp/operator/test_swizzle.cu index 694b348a9b..8389989efe 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -110,6 +110,115 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row } } +// Zero out padding in a scale_inv CPU buffer so that the CPU reference +// matches the kernel, which zeroes elements outside the original dims. +// The buffer is stored in leading-dim-major order (row-major for rowwise, +// column-major for colwise). `padded_rows x padded_cols` is the full +// (padded) shape; `orig_rows` / `orig_cols` are the unpadded extents. +static void zero_scale_inv_padding(uint8_t *buf, + size_t padded_rows, size_t padded_cols, + size_t orig_rows, size_t orig_cols) { + for (size_t r = 0; r < padded_rows; ++r) { + for (size_t c = 0; c < padded_cols; ++c) { + if (r >= orig_rows || c >= orig_cols) { + buf[r * padded_cols + c] = 0; + } + } + } +} + +void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const size_t K) { + using namespace transformer_engine; + using namespace test; + + std::vector> input_tensors; + std::vector> output_tensors; + std::vector input_ptrs; + std::vector output_ptrs; + input_tensors.reserve(num_tensors); + output_tensors.reserve(num_tensors); + input_ptrs.reserve(num_tensors); + output_ptrs.reserve(num_tensors); + + constexpr size_t BLOCK_SIZE = 32; + const std::vector shape{M, K}; + for (int i = 0; i < num_tensors; ++i) { + auto input = std::make_unique("input_" + std::to_string(i), shape, + DType::kFloat8E4M3, true, true, + NVTE_MXFP8_1D_SCALING); + auto output = std::make_unique("output_" + std::to_string(i), shape, + DType::kFloat8E4M3, true, true, + NVTE_MXFP8_1D_SCALING); + fillUniform(input.get()); + fillUniform(output.get()); + + // The grouped swizzle kernel zeroes scale_inv elements that fall + // outside the original (unpadded) dimensions. 
Mirror that in the + // per-tensor CPU buffers so the CPU reference produces identical output. + input->to_cpu(); + const NVTEShape rs = input->rowwise_scale_inv_shape(); + zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(), + rs.data[0], rs.data[1], + M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE); + const NVTEShape cs = input->columnwise_scale_inv_shape(); + zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(), + cs.data[0], cs.data[1], + (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K); + input->from_cpu(); + + input_ptrs.push_back(input.get()); + output_ptrs.push_back(output.get()); + input_tensors.emplace_back(std::move(input)); + output_tensors.emplace_back(std::move(output)); + } + + GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING); + GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING); + const uint8_t input_swizzled = 0; + nvte_set_grouped_tensor_param(grouped_input.get_handle(), + kNVTEGroupedWithGEMMSwizzledScales, + &input_swizzled, sizeof(input_swizzled)); + const uint8_t output_swizzled = 1; + nvte_set_grouped_tensor_param(grouped_output.get_handle(), + kNVTEGroupedWithGEMMSwizzledScales, + &output_swizzled, sizeof(output_swizzled)); + + const NVTEShape row_shape = input_tensors[0]->rowwise_scale_inv_shape(); + const NVTEShape col_shape = input_tensors[0]->columnwise_scale_inv_shape(); + const size_t row_numel = row_shape.data[0] * row_shape.data[1]; + const size_t col_numel = col_shape.data[0] * col_shape.data[1]; + + NVTE_CHECK_CUDA(cudaMemset(grouped_output.scale_inv.get(), 0, num_tensors * row_numel)); + NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0, num_tensors * col_numel)); + + nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(), + grouped_output.get_handle(), 0); + + std::vector output_row(num_tensors * row_numel); + std::vector output_col(num_tensors * col_numel); + NVTE_CHECK_CUDA(cudaMemcpy(output_row.data(), 
grouped_output.scale_inv.get(), + output_row.size(), cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(output_col.data(), grouped_output.columnwise_scale_inv.get(), + output_col.size(), cudaMemcpyDeviceToHost)); + + std::vector ref_row(num_tensors * row_numel); + std::vector ref_col(num_tensors * col_numel); + for (int i = 0; i < num_tensors; ++i) { + compute_ref_swizzle<128, 4, true>(input_tensors[i]->rowwise_cpu_scale_inv_ptr(), + ref_row.data() + i * row_numel, + row_shape.data[0], row_shape.data[1]); + compute_ref_swizzle<128, 4, false>( + input_tensors[i]->columnwise_cpu_scale_inv_ptr(), + ref_col.data() + i * col_numel, + col_shape.data[1], col_shape.data[0]); + } + + compareResults("grouped_swizzle_rowwise", output_row.data(), ref_row.data(), + num_tensors * row_numel); + compareResults("grouped_swizzle_colwise", output_col.data(), ref_col.data(), + num_tensors * col_numel); +} + class SwizzleTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; @@ -126,6 +235,41 @@ TEST_P(SwizzleTestSuite, TestSwizzle) { transa); } +class SwizzleGroupedTestSuite + : public ::testing::TestWithParam> {}; + +TEST_P(SwizzleGroupedTestSuite, TestGroupedSwizzleMXFP8) { + const auto num_tensors = std::get<0>(GetParam()); + const auto M = std::get<1>(GetParam()); + const auto K = std::get<2>(GetParam()); + performTestGroupedSwizzleMXFP8(num_tensors, M, K); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwizzleGroupedTestSuite, + ::testing::Values( + // M and K both divisible by 128 + std::make_tuple(3, 256, 256), + std::make_tuple(4, 128, 128), + // M not divisible by 128 + std::make_tuple(3, 200, 256), + std::make_tuple(2, 65, 256), + // K not divisible by 128 + std::make_tuple(3, 256, 160), + std::make_tuple(2, 256, 96), + // Neither M nor K divisible by 128 + std::make_tuple(3, 200, 160), + std::make_tuple(4, 33, 64), + std::make_tuple(2, 1, 32) + ), + [](const testing::TestParamInfo& info) { + return "n" + std::to_string(std::get<0>(info.param)) + + "_M" + 
std::to_string(std::get<1>(info.param)) + + "_K" + std::to_string(std::get<2>(info.param)); + } +); + namespace { std::vector> num_tiles = { diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index b97afbc191..75d450b46b 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -18,6 +18,7 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine.pytorch.ops as te_ops + from transformer_engine.pytorch.ops.fused import ( BackwardActivationBias, BackwardAddRMSNorm, @@ -35,6 +36,8 @@ NVFP4Quantizer, is_bf16_available, ) +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor +from transformer_engine.pytorch.cpp_extensions.gemm import general_grouped_gemm_for_grouped_tensor import transformer_engine_torch as tex # Import utility functions @@ -2008,6 +2011,7 @@ def test_dropout( @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("weight_requires_grad", (False, True)) + @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) def test_grouped_linear( self, *, @@ -2022,6 +2026,7 @@ def test_grouped_linear( quantized_weight: bool, input_requires_grad: bool, weight_requires_grad: bool, + delay_wgrad_compute: bool, ) -> None: """Grouped GEMM""" @@ -2102,6 +2107,7 @@ def test_grouped_linear( bias=bias, device=device, dtype=dtype, + delay_wgrad_compute=delay_wgrad_compute, ) with torch.no_grad(): for group_idx in range(group_size): @@ -2117,6 +2123,8 @@ def test_grouped_linear( y_test = op(x_test, split_sizes) if input_requires_grad or weight_requires_grad: y_test.backward(dy_test) + if delay_wgrad_compute and weight_requires_grad: + op.backward_dw() # Expected numerical error tols = dtype_tols(dtype) @@ -3236,7 +3244,11 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: @pytest.mark.parametrize("bias", (False, True)) 
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("single_grouped_bias", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) @pytest.mark.parametrize("glu_interleave_size", (None, 32)) + @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) def test_grouped_mlp( self, *, @@ -3245,14 +3257,18 @@ def test_grouped_mlp( hidden_size: int = 256, dtype: torch.dtype, quantization: Optional[str], + single_grouped_weight: bool, + single_grouped_bias: bool, + accumulate_into_main_grad: bool, device: torch.device = "cuda", split_alignment: int = 256, glu_interleave_size: Optional[int], + delay_wgrad_compute: bool, ) -> None: """GroupedLinear + ScaledSwiGLU + GroupedLinear""" # Split sizes - split_sizes = [split_alignment * i for i in range(group_size)] + split_sizes = [split_alignment * (i) for i in range(group_size)] random.shuffle(split_sizes) split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) @@ -3263,8 +3279,15 @@ def test_grouped_mlp( # Skip invalid configurations with_quantization = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + if single_grouped_weight and quantization != "mxfp8": + pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if quantization == "mxfp8" and bias: + # Will be supported in future CUDNN release. 
+ pytest.skip("Bias/dbias not yet supported in MXFP8 fused grouped MLP") # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -3370,6 +3393,10 @@ def test_grouped_mlp( bias=bias, device=device, dtype=dtype, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, ) fc2 = te_ops.GroupedLinear( group_size, @@ -3378,6 +3405,10 @@ def test_grouped_mlp( bias=bias, device=device, dtype=dtype, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, ) module = te_ops.Sequential( fc1, @@ -3387,18 +3418,87 @@ def test_grouped_mlp( # Copy weights with torch.no_grad(): + if single_grouped_weight: + fc1_weights = fc1.weight.quantized_tensors + if fc1_weights is None: + fc1_weights = fc1.weight.split_into_quantized_tensors() + fc2_weights = fc2.weight.quantized_tensors + if fc2_weights is None: + fc2_weights = fc2.weight.split_into_quantized_tensors() for group_idx in range(group_size): - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + if single_grouped_weight: + fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) + else: + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) if bias: - getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) - getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + if single_grouped_bias: + fc1_bparts = fc1.bias.split_into_quantized_tensors() + fc2_bparts = fc2.bias.split_into_quantized_tensors() + fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) + fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) + else: + 
getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + if accumulate_into_main_grad: + if single_grouped_weight: + fc1.weight.main_grad = torch.full( + fc1.weight.size(), + 0.5, + device=device, + dtype=torch.float32, + ) + fc2.weight.main_grad = torch.full( + fc2.weight.size(), + 0.5, + device=device, + dtype=torch.float32, + ) + else: + for group_idx in range(group_size): + getattr(fc1, f"weight{group_idx}").main_grad = torch.full( + getattr(fc1, f"weight{group_idx}").size(), + 0.5, + device=device, + dtype=torch.float32, + ) + getattr(fc2, f"weight{group_idx}").main_grad = torch.full( + getattr(fc2, f"weight{group_idx}").size(), + 0.5, + device=device, + dtype=torch.float32, + ) del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test # Fuse ops and perform forward and backward pass with te.autocast(enabled=with_quantization, recipe=recipe): y_test = module(x_test, split_sizes, probs_test, split_sizes) y_test.backward(dy_test) + if delay_wgrad_compute: + fc1.backward_dw() + fc2.backward_dw() + + # Check for expected fusions + if ( + quantization == "mxfp8" + and dtype in (torch.bfloat16, torch.float16) + and glu_interleave_size == 32 + ): + if te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + forward_ops = module._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ) + if te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): + backward_ops = module._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance( + backward_ops[0][0], + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + ) # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} @@ -3410,10 +3510,286 @@ def test_grouped_mlp( assert_close_grads(x_test, x_ref, **tols) assert_close_grads(probs_test, probs_ref, **tols) for group_idx in 
range(group_size): - assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols) - assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) - assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols) - assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) + if bias: + if single_grouped_bias: + assert_close( + fc2.bias.grad[group_idx], + fc2_bs_ref[group_idx].grad, + **tols, + ) + assert_close( + fc1.bias.grad[group_idx], + fc1_bs_ref[group_idx].grad, + **tols, + ) + else: + assert_close_grads( + getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols + ) + assert_close_grads( + getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols + ) + if not single_grouped_weight and not accumulate_into_main_grad: + assert_close_grads( + getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols + ) + assert_close_grads( + getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols + ) + fc1_w_ref_grad = torch.stack([w.grad for w in fc1_ws_ref], dim=0) + fc2_w_ref_grad = torch.stack([w.grad for w in fc2_ws_ref], dim=0) + if accumulate_into_main_grad: + if single_grouped_weight: + fc1_w_test_grad = fc1.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5 + fc2_w_test_grad = fc2.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5 + else: + fc1_w_test_grad = torch.stack( + [ + getattr(fc1, f"weight{group_idx}").main_grad.to( + dtype=torch.float64, device="cpu" + ) + - 0.5 + for group_idx in range(group_size) + ], + dim=0, + ) + fc2_w_test_grad = torch.stack( + [ + getattr(fc2, f"weight{group_idx}").main_grad.to( + dtype=torch.float64, device="cpu" + ) + - 0.5 + for group_idx in range(group_size) + ], + dim=0, + ) + assert_close(fc1_w_test_grad, fc1_w_ref_grad, **tols) + assert_close(fc2_w_test_grad, fc2_w_ref_grad, **tols) + elif single_grouped_weight: + assert_close(fc1.weight.grad, fc1_w_ref_grad, **tols) + 
assert_close(fc2.weight.grad, fc2_w_ref_grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_grouped_mlp_cuda_graph_safe_mxfp8( + self, + *, + dtype: torch.dtype, + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + glu_interleave_size: int = 32, + ) -> None: + """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" + + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + pytest.skip("MXFP8 fused grouped MLP is not supported on this system") + if dtype not in (torch.bfloat16, torch.float16): + pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") + + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + in_shape = (split_sizes.sum().item(), hidden_size) + + recipe = make_recipe("mxfp8") + with te.quantized_model_init(enabled=True, recipe=recipe): + fc1 = te_ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + fc2 = te_ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + module = te_ops.Sequential( + fc1, + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + fc2, + ) + + def _init_main_grads(value: float = 0.0) -> None: + if not accumulate_into_main_grad: + return + with torch.no_grad(): + if 
single_grouped_weight: + if getattr(fc1.weight, "main_grad", None) is None: + fc1.weight.main_grad = torch.empty( + fc1.weight.size(), + device=device, + dtype=torch.float32, + ) + if getattr(fc2.weight, "main_grad", None) is None: + fc2.weight.main_grad = torch.empty( + fc2.weight.size(), + device=device, + dtype=torch.float32, + ) + fc1.weight.main_grad.fill_(value) + fc2.weight.main_grad.fill_(value) + else: + for group_idx in range(group_size): + fc1_weight = getattr(fc1, f"weight{group_idx}") + fc2_weight = getattr(fc2, f"weight{group_idx}") + if getattr(fc1_weight, "main_grad", None) is None: + fc1_weight.main_grad = torch.empty( + fc1_weight.size(), + device=device, + dtype=torch.float32, + ) + if getattr(fc2_weight, "main_grad", None) is None: + fc2_weight.main_grad = torch.empty( + fc2_weight.size(), + device=device, + dtype=torch.float32, + ) + fc1_weight.main_grad.fill_(value) + fc2_weight.main_grad.fill_(value) + + def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]: + if single_grouped_weight: + fc1_main_grad = fc1.weight.main_grad.detach().clone() + fc2_main_grad = fc2.weight.main_grad.detach().clone() + else: + fc1_main_grad = torch.stack( + [ + getattr(fc1, f"weight{group_idx}").main_grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + fc2_main_grad = torch.stack( + [ + getattr(fc2, f"weight{group_idx}").main_grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + return fc1_main_grad, fc2_main_grad + + static_split_sizes = split_sizes.clone() + + def train_step( + x: torch.Tensor, + probs: torch.Tensor, + dy: torch.Tensor, + out_buf: torch.Tensor, + *, + use_graphed: bool, + ) -> torch.Tensor: + with te.autocast(enabled=True, recipe=recipe): + out = ( + graphed_module(x, static_split_sizes, probs, static_split_sizes) + if use_graphed + else module(x, static_split_sizes, probs, static_split_sizes) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + _init_main_grads(0.0) + + 
static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) + static_probs = torch.randn((in_shape[0],), device=device, dtype=dtype, requires_grad=True) + static_dy = torch.randn(in_shape, device=device, dtype=dtype) + static_out_buf = torch.empty((in_shape[0], hidden_size), device=device, dtype=dtype) + + graphed_module = te.make_graphed_callables( + module, + (static_x, static_split_sizes, static_probs, static_split_sizes), + num_warmup_iters=3, + enabled=True, + recipe=recipe, + ) + + forward_ops = module._module_groups[0]._forward_ops + backward_ops = module._module_groups[0]._backward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ) + assert len(backward_ops) == 1 + assert isinstance( + backward_ops[0][0], + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + ) + + fresh_x = torch.randn_like(static_x) + fresh_probs = torch.randn_like(static_probs) + fresh_dy = torch.randn_like(static_dy) + with torch.no_grad(): + static_x.copy_(fresh_x) + static_probs.copy_(fresh_probs) + static_dy.copy_(fresh_dy) + + for param in module.parameters(): + param.grad = torch.zeros_like(param) + _init_main_grads(0.5) + if static_x.grad is not None: + static_x.grad.zero_() + if static_probs.grad is not None: + static_probs.grad.zero_() + + graph_out = ( + train_step(static_x, static_probs, static_dy, static_out_buf, use_graphed=True) + .detach() + .clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + graph_dprobs = static_probs.grad.detach().clone() + if accumulate_into_main_grad: + graph_fc1_main_grad, graph_fc2_main_grad = _collect_main_grads() + else: + graph_param_grads = [param.grad.detach().clone() for param in module.parameters()] + + for param in module.parameters(): + param.grad.zero_() + _init_main_grads(0.5) + static_x.grad.zero_() + static_probs.grad.zero_() + + expected_x = fresh_x.detach().clone().requires_grad_(True) + 
expected_probs = fresh_probs.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() + with te.autocast(enabled=True, recipe=recipe): + expected_out = module( + expected_x, + static_split_sizes, + expected_probs, + static_split_sizes, + ) + expected_out.backward(expected_dy) + + tols = dtype_tols(dtype) + assert_close(graph_out, expected_out, **tols) + assert_close(graph_dx, expected_x.grad, **tols) + assert_close(graph_dprobs, expected_probs.grad, **tols) + if accumulate_into_main_grad: + expected_fc1_main_grad, expected_fc2_main_grad = _collect_main_grads() + assert_close(graph_fc1_main_grad, expected_fc1_main_grad, **tols) + assert_close(graph_fc2_main_grad, expected_fc2_main_grad, **tols) + else: + for graph_grad, param in zip(graph_param_grads, module.parameters()): + assert_close(graph_grad, param.grad, **tols) class TestCustomOps: @@ -3836,3 +4212,145 @@ def fuse_ops( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + +def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None: + if not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Requires SM100+ for grouped GEMM quant kernel.") + + try: + from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module + except ImportError as exc: + pytest.skip(f"grouped_gemm_quant_wrapper_sm100 unavailable: {exc}") + + device = torch.device("cuda") + dtype = torch.bfloat16 if is_bf16_available() else torch.float16 + num_groups = 4 + m = 256 + n = 512 + k = 512 + total_m = num_groups * m + split_sizes = torch.full((num_groups,), m, device=device, dtype=torch.int64) + + q = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=False) + q.optimize_for_gemm = False + + torch.manual_seed(0) + a_full = torch.randn(total_m, k, device=device, dtype=dtype) + weights = 
[torch.randn(n, k, device=device, dtype=dtype) for _ in range(num_groups)] + + grouped_a = tex.group_quantize(a_full, q, num_groups, split_sizes) + a_groups = grouped_a.split_into_quantized_tensors() + b_groups = [q(w) for w in weights] + + # Reference GEMM on dequantized tensors. + ref = torch.empty((total_m, n), device=device, dtype=torch.float32) + start = 0 + for group_idx in range(num_groups): + end = start + m + a_deq = a_groups[group_idx].dequantize(dtype=torch.float32) + b_deq = b_groups[group_idx].dequantize(dtype=torch.float32) + ref[start:end, :] = a_deq @ b_deq.t() + start = end + ref = ref.to(dtype=torch.bfloat16).to(torch.float32) + + # Allocate empty input tensors needed for cuTE DSL kernel + padded_offsets = torch.tensor( + [m * (i + 1) for i in range(num_groups)], + dtype=torch.int32, + device=device, + ) + inputs = { + "a_tensor": torch.empty(1, total_m, k, dtype=torch.float8_e4m3fn, device=device).permute( + 1, 2, 0 + ), + "b_tensor": torch.empty(num_groups, n, k, dtype=torch.float8_e4m3fn, device=device).permute( + 1, 2, 0 + ), + "sfa_tensor": torch.empty( + 1, + total_m // 128, + k // 128, + 32, + 4, + 4, + dtype=torch.float8_e8m0fnu, + device=device, + ).permute(3, 4, 1, 5, 2, 0), + "sfb_tensor": torch.empty( + num_groups, + n // 128, + k // 128, + 32, + 4, + 4, + dtype=torch.float8_e8m0fnu, + device=device, + ).permute(3, 4, 1, 5, 2, 0), + "alpha_tensor": torch.empty(num_groups, dtype=torch.float32, device=device), + "prob_tensor": torch.empty(total_m, 1, 1, dtype=torch.float32, device=device), + "padded_offsets_tensor": padded_offsets, + } + # Overwrite inputs with quantized data/scales from MXFP8 quantizer. 
+ a_data = grouped_a.rowwise_data.view(total_m, k).view(dtype=torch.float8_e4m3fn) + a_data = a_data.unsqueeze(0).permute(1, 2, 0).contiguous() + inputs["a_tensor"].copy_(a_data) + + a_scales = grouped_a.scale_inv.view(dtype=torch.float8_e8m0fnu) + a_scales = a_scales.view(1, total_m // 128, 4, 32, k // 128, 4) + a_scales = a_scales.permute(0, 1, 4, 3, 2, 5).contiguous() + a_scales = a_scales.permute(3, 4, 1, 5, 2, 0).contiguous() + inputs["sfa_tensor"].copy_(a_scales) + + b_data = torch.cat([w._rowwise_data.reshape(-1) for w in b_groups]) + b_data = b_data.view(dtype=torch.float8_e4m3fn) + b_data = b_data.view(num_groups, n, k).permute(1, 2, 0).contiguous() + inputs["b_tensor"].copy_(b_data) + + b_scales = torch.cat([w._rowwise_scale_inv for w in b_groups]) + b_scales = b_scales.view(dtype=torch.float8_e8m0fnu) + b_scales = b_scales.view(num_groups, n // 128, 4, 32, k // 128, 4) + b_scales = b_scales.permute(0, 1, 4, 3, 2, 5).contiguous() + b_scales = b_scales.permute(3, 4, 1, 5, 2, 0).contiguous() + inputs["sfb_tensor"].copy_(b_scales) + + inputs["alpha_tensor"].fill_(1.0) + inputs["prob_tensor"].fill_(1.0) + + cute_out = grouped_gemm_quant_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + norm_const_tensor=None, + prob_tensor=inputs["prob_tensor"], + acc_dtype=torch.float32, + c_dtype=torch.bfloat16, + d_dtype=torch.bfloat16, + cd_major="n", + sf_vec_size=32, + discrete_col_sfd=True, + current_stream=None, + ) + + if isinstance(cute_out, dict): + outputs = cute_out + else: + d_tensor, d_col_tensor, amax_tensor, sfd_row_tensor, sfd_col_tensor = cute_out + outputs = { + "d_tensor": d_tensor, + "d_col_tensor": d_col_tensor, + "amax_tensor": amax_tensor, + "sfd_row_tensor": sfd_row_tensor, + "sfd_col_tensor": sfd_col_tensor, + } + + d_cute = outputs["d_tensor"] + if d_cute.dim() == 
3: + d_cute = d_cute.squeeze(-1) + tols = dtype_tols(torch.bfloat16) + assert_close(d_cute[:total_m].float(), ref, **tols) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 225c6f6759..5de081e74d 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -356,8 +356,9 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: "shape", [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], ) + @pytest.mark.parametrize("output_dbias", [False, True]) @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) - def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: + def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]], output_dbias: bool) -> None: """Test grouped quantization for MXFP8 against per-tensor quantization.""" # Test wont pass until the grouped quantization PR from Oleg is merged. num_tensors = 2 @@ -377,12 +378,21 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: ) # Quantize using grouped API - grouped_output = tex.group_quantize( - grouped_input, - quantizer, - num_tensors, - first_dims, - ) + if output_dbias: + grouped_output, dbias = tex.group_quantize( + grouped_input, + quantizer, + num_tensors, + first_dims, + output_dbias=True, + ) + else: + grouped_output = tex.group_quantize( + grouped_input, + quantizer, + num_tensors, + first_dims, + ) # Build expected output by quantizing each tensor independently expected_data = [] expected_scale_inv = [] @@ -397,8 +407,13 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: assert torch.equal(grouped_output.rowwise_data, expected_data) assert torch.equal(grouped_output.scale_inv, expected_scale_inv) + if output_dbias: + expected_dbias = torch.stack([t.sum(dim=0) for t in input_tensors]) + assert torch.allclose(dbias, expected_dbias) + + @pytest.mark.parametrize("output_dbias", [False, True]) 
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) - def test_group_quantize_cudagraph_capturable(self) -> None: + def test_group_quantize_cudagraph_capturable(self, output_dbias: bool) -> None: """Ensure group_quantize is CUDA graph capturable.""" num_tensors = 2 shape = [(512, 1024) for _ in range(num_tensors)] @@ -418,17 +433,31 @@ def test_group_quantize_cudagraph_capturable(self) -> None: static_first_dims = first_dims.clone() # Warmup to initialize kernels and allocator state - _ = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) + if output_dbias: + _ = tex.group_quantize( + static_input, quantizer, num_tensors, static_first_dims, output_dbias=True + ) + else: + _ = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): - static_output = tex.group_quantize( - static_input, - quantizer, - num_tensors, - static_first_dims, - ) + if output_dbias: + static_output, static_dbias = tex.group_quantize( + static_input, + quantizer, + num_tensors, + static_first_dims, + output_dbias=True, + ) + else: + static_output = tex.group_quantize( + static_input, + quantizer, + num_tensors, + static_first_dims, + ) fresh_input = torch.cat( [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape], @@ -438,9 +467,22 @@ def test_group_quantize_cudagraph_capturable(self) -> None: graph.replay() torch.cuda.synchronize() - expected = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) - assert torch.equal(static_output.rowwise_data, expected.rowwise_data) - assert torch.equal(static_output.scale_inv, expected.scale_inv) + if output_dbias: + expected_out, expected_dbias = tex.group_quantize( + static_input, + quantizer, + num_tensors, + static_first_dims, + output_dbias=True, + ) + else: + expected_out = tex.group_quantize( + static_input, quantizer, num_tensors, static_first_dims + ) + assert 
torch.equal(static_output.rowwise_data, expected_out.rowwise_data) + assert torch.equal(static_output.scale_inv, expected_out.scale_inv) + if output_dbias: + assert torch.allclose(static_dbias, expected_dbias) def test_clear(self) -> None: """Test clear method""" @@ -477,7 +519,7 @@ def test_grouped_linear_load_state_dict_multi_to_single_param(self, tmp_path) -> in_features=in_features, out_features=out_features, params_dtype=dtype, - single_grouped_parameter=False, + single_grouped_weight=False, ).cuda() with torch.no_grad(): for i in range(num_gemms): @@ -489,6 +531,7 @@ def test_grouped_linear_load_state_dict_multi_to_single_param(self, tmp_path) -> torch.randn(out_features, device="cuda", dtype=dtype) ) expected_weights = [getattr(src, f"weight{i}").detach().clone() for i in range(num_gemms)] + expected_biases = [getattr(src, f"bias{i}").detach().clone() for i in range(num_gemms)] ckpt_path = tmp_path / "grouped_linear_per_gemm.pt" torch.save(src.state_dict(), ckpt_path) del src @@ -500,7 +543,8 @@ def test_grouped_linear_load_state_dict_multi_to_single_param(self, tmp_path) -> in_features=in_features, out_features=out_features, params_dtype=dtype, - single_grouped_parameter=True, + single_grouped_weight=True, + single_grouped_bias=True, ).cuda() load_result = dst.load_state_dict(src_state_dict, strict=True) assert len(load_result.missing_keys) == 0 @@ -512,6 +556,12 @@ def test_grouped_linear_load_state_dict_multi_to_single_param(self, tmp_path) -> for loaded_weight, expected_weight in zip(loaded_weights, expected_weights): assert torch.equal(loaded_weight, expected_weight) + assert getattr(dst, "bias", None) is not None + loaded_biases = dst.bias.split_into_quantized_tensors() + assert len(loaded_biases) == num_gemms + for loaded_bias, expected_bias in zip(loaded_biases, expected_biases): + assert torch.equal(loaded_bias.reshape(-1), expected_bias.reshape(-1)) + def test_grouped_linear_load_state_dict_single_to_multi_param(self, tmp_path) -> None: """Load 
grouped-parameter checkpoint from disk into per-GEMM parameter format.""" num_gemms = 3 @@ -524,7 +574,8 @@ def test_grouped_linear_load_state_dict_single_to_multi_param(self, tmp_path) -> in_features=in_features, out_features=out_features, params_dtype=dtype, - single_grouped_parameter=True, + single_grouped_weight=True, + single_grouped_bias=True, ).cuda() with torch.no_grad(): source_weights = src.weight.split_into_quantized_tensors() @@ -533,6 +584,10 @@ def test_grouped_linear_load_state_dict_single_to_multi_param(self, tmp_path) -> torch.randn(out_features, in_features, device="cuda", dtype=dtype) ) expected_weights = [weight.detach().clone() for weight in source_weights] + source_biases = src.bias.split_into_quantized_tensors() + for i in range(num_gemms): + source_biases[i].copy_(torch.randn(out_features, device="cuda", dtype=dtype)) + expected_biases = [b.detach().clone() for b in source_biases] ckpt_path = tmp_path / "grouped_linear_single_param.pt" torch.save(src.state_dict(), ckpt_path) del src @@ -544,7 +599,7 @@ def test_grouped_linear_load_state_dict_single_to_multi_param(self, tmp_path) -> in_features=in_features, out_features=out_features, params_dtype=dtype, - single_grouped_parameter=False, + single_grouped_weight=False, ).cuda() load_result = dst.load_state_dict(src_state_dict, strict=True) assert len(load_result.missing_keys) == 0 @@ -552,3 +607,5 @@ def test_grouped_linear_load_state_dict_single_to_multi_param(self, tmp_path) -> for i, expected_weight in enumerate(expected_weights): assert torch.equal(getattr(dst, f"weight{i}"), expected_weight) + for i, expected_bias in enumerate(expected_biases): + assert torch.equal(getattr(dst, f"bias{i}"), expected_bias.reshape(-1)) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 19b94d3531..a968e9f9a4 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2861,8 +2861,8 @@ def _make_grouped_tensor_uniform( @pytest.mark.parametrize("layout", 
["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate) -> None: - if tex.get_cublasLt_version() < 130200: - pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if tex.get_cublasLt_version() < 130300: + pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") if torch.cuda.get_device_capability() < (10, 0): pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") if not is_bf16_available(): @@ -3008,6 +3008,89 @@ def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate) -> No torch.testing.assert_close(o, o_ref, **tols) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False, True]) +@pytest.mark.parametrize("quant_type", ["bf16", "mxfp8"]) +def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) -> None: + """Grouped GEMM with all-zero split sizes (zero total work). + + For wgrad (NT layout) the output should be zero when not accumulating, + or unchanged when accumulating with beta=1. 
+ """ + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + if quant_type == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + z = 4 + k, n = 256, 256 + dtype = torch.bfloat16 + device = torch.device("cuda") + m_sizes = [0] * z + use_mxfp8 = quant_type == "mxfp8" + + if layout == "NT": + A = [torch.randn(0, k, dtype=dtype, device=device) for _ in range(z)] + B = [torch.randn(0, n, dtype=dtype, device=device) for _ in range(z)] + out = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] + elif layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] + B = [torch.randn(0, k, dtype=dtype, device=device) for _ in range(z)] + out = [torch.randn(0, n, dtype=dtype, device=device) for _ in range(z)] + else: # NN + A = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] + B = [torch.randn(0, n, dtype=dtype, device=device) for _ in range(z)] + out = [torch.randn(0, k, dtype=dtype, device=device) for _ in range(z)] + + out_before = [o.clone() for o in out] + + if use_mxfp8: + transa = layout[0] == "T" + transb = layout[1] == "T" + grouped_A = _make_grouped_tensor_quantized_mxfp8( + A, is_a=True, transposed=transa, device=device + ) + grouped_B = _make_grouped_tensor_quantized_mxfp8( + B, is_a=False, transposed=transb, device=device + ) + else: + grouped_B = _make_grouped_tensor_from_splits(m_sizes, B[0].shape[1], device, dtype) + if layout in ("TN", "NN"): + grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype) + _pack_grouped_tensor(grouped_A, A) + else: # NT + grouped_A = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + + if layout == "TN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + elif layout == "NN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + 
else: # NT + grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) + _pack_grouped_tensor(grouped_out, out) + + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out, + layout=layout, + accumulate=accumulate, + ) + + out_result = ( + grouped_out if isinstance(grouped_out, list) else grouped_out.split_into_quantized_tensors() + ) + for i in range(z): + if out_result[i].numel() == 0: + continue + if accumulate: + torch.testing.assert_close(out_result[i], out_before[i]) + else: + torch.testing.assert_close(out_result[i], torch.zeros_like(out_result[i])) + + def _make_grouped_tensor_quantized_mxfp8( tensors: List[torch.Tensor], *, @@ -3050,8 +3133,8 @@ def _make_grouped_tensor_quantized_mxfp8( def test_grouped_gemm_grouped_tensor_mxfp8( shape, accumulate, layout: str, case: str, dtype: torch.dtype ) -> None: - if tex.get_cublasLt_version() < 130200: - pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if tex.get_cublasLt_version() < 130300: + pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") if torch.cuda.get_device_capability() < (10, 0): pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") if dtype == torch.bfloat16 and not is_bf16_available(): diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 384b6774f6..17e26ca2f6 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -155,6 +155,18 @@ def check_grouped_weight( ) +def check_grouped_bias(module: GroupedLinear, num_gemms: int, out_features: int): + """Verify GroupedLinear exposes one grouped bias parameter with shape [num_gemms, out_features].""" + bias_params = [(name, p) for name, p in module.named_parameters() if name == "bias"] + assert len(bias_params) == 1, f"Expected 1 grouped bias parameter, got {len(bias_params)}" + name, bias = bias_params[0] + assert name == "bias", f"Expected grouped parameter name 'bias', got {name}" + assert tuple(bias.shape) == (num_gemms, out_features), ( + "Grouped bias has 
unexpected shape. " + f"Expected {(num_gemms, out_features)}, got {tuple(bias.shape)}" + ) + + def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( (config.max_seqlen_q, config.batch_size, config.hidden_size), @@ -523,13 +535,16 @@ def test_sanity_grouped_linear( ffn_hidden_size, bias=use_bias, params_dtype=dtype, - single_grouped_parameter=single_param, + single_grouped_weight=single_param, + single_grouped_bias=single_param, ).cuda() - # Verify grouped linear exposes a single grouped weight parameter. + # Verify grouped linear exposes a single grouped weight parameter (and bias when applicable). if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): if single_param: check_grouped_weight(te_grouped_linear, num_gemms, ffn_hidden_size, config.hidden_size) + if use_bias: + check_grouped_bias(te_grouped_linear, num_gemms, ffn_hidden_size) inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b9e2b907e0..7c223e6917 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -150,6 +150,7 @@ list(APPEND transformer_engine_cuda_sources normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu + util/utils.cu util/padding.cu swizzle/swizzle.cu swizzle/swizzle_block_scaling.cu diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 5031a30485..246fc684a1 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -32,7 +32,6 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { // MXFP8 support for grouped GEMM requires cuBLAS 13.3+ #define 
CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130300 // BF16 support for grouped GEMM requires cuBLAS 13.3+ -// cuBLAS 13.2 is mostly functional but contains a bug for wgrad when a group has k=0, the weight gradient will be uninitialized random data instead of zeros. #define CUBLAS_GROUPED_GEMM_VERSION 130300 #if CUBLAS_VERSION >= CUBLAS_GROUPED_GEMM_VERSION @@ -93,12 +92,29 @@ struct TensorShapeInfo { } }; -// Helper functions to compute average dimensions from logical_shape for heuristics -// These are hints for cuBLASLt algorithm selection, don't need to be exact +// Helper functions to compute average dimensions for cuBLASLt algorithm-selection heuristics. +// +// logical_shape encoding (from build_grouped_tensor): +// all_same: {num_tensors * M, N} +// varying_first: {sum_of_first_dims, common_last} +// varying_last: {common_first, sum_of_last_dims} +// varying_both: {1, total_elements} <-- lossy, can't recover per-dim averages +// +// We use all_same_first/last_dim() + get_common_first/last_dim() to get exact +// answers whenever possible, falling back to logical_shape division otherwise. +// For varying_both, per-dim averages are unrecoverable without a D2H copy, +// so we return 1 — a valid non-zero hint that won't skip work. 
inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { - // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) - // In both cases, dividing by num_tensors gives the average - return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); + if (t->all_same_first_dim()) { + return static_cast(t->get_common_first_dim()); + } + const int64_t n = static_cast(t->num_tensors); + if (t->all_same_last_dim()) { + // varying_first only: logical_shape = {sum_of_first_dims, common_last} + return static_cast(t->logical_shape.data[0]) / n; + } + // varying_both: logical_shape = {1, total_elements}, no way to recover avg first dim + return 1; } inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { @@ -228,28 +244,34 @@ inline size_t validate_grouped_gemm_inputs( dtype == transformer_engine::DType::kBFloat16 || dtype == transformer_engine::DType::kFloat16; }; - bool dtype_ok = true; for (const auto *tensor : inputs) { - dtype_ok = dtype_ok && is_supported_input_dtype(tensor->dtype()); + if (tensor->has_data() || tensor->has_columnwise_data()) { + NVTE_CHECK(is_supported_input_dtype(tensor->dtype()), + "Grouped GEMM inputs must be FP8, BF16, or FP16, got ", + transformer_engine::to_string(tensor->dtype()), "."); + } } - NVTE_CHECK(dtype_ok, "Grouped GEMM inputs must be FP8, BF16, or FP16."); + // Cross-operand consistency across all inputs (skip tensors without data). + const transformer_engine::GroupedTensor *ref = nullptr; for (const auto *tensor : inputs) { - NVTE_CHECK(tensor->has_data() || tensor->has_columnwise_data(), - "Grouped GEMM: input tensor is missing both row-wise and column-wise data"); + if (tensor->has_data() || tensor->has_columnwise_data()) { + ref = tensor; + break; + } } - - // Cross-operand consistency across all inputs. 
- const auto *ref = *inputs.begin(); - const bool ref_is_fp8 = is_fp8_dtype(ref->dtype()); - const bool ref_is_mxfp8 = transformer_engine::is_mxfp_scaling(ref->scaling_mode); - for (const auto *tensor : inputs) { - NVTE_CHECK(is_fp8_dtype(tensor->dtype()) == ref_is_fp8, - "Grouped GEMM: A and B must both be FP8 or both be non-FP8."); - NVTE_CHECK(transformer_engine::is_mxfp_scaling(tensor->scaling_mode) == ref_is_mxfp8, - "Grouped GEMM: A and B must both use MXFP8 scaling or both use tensor scaling."); - if (ref_is_mxfp8) { - NVTE_CHECK(tensor->with_gemm_swizzled_scales, - "MXFP8 grouped GEMM: scales must be swizzled for GEMM."); + if (ref != nullptr) { + const bool ref_is_fp8 = is_fp8_dtype(ref->dtype()); + const bool ref_is_mxfp8 = transformer_engine::is_mxfp_scaling(ref->scaling_mode); + for (const auto *tensor : inputs) { + if (!(tensor->has_data() || tensor->has_columnwise_data())) continue; + NVTE_CHECK(is_fp8_dtype(tensor->dtype()) == ref_is_fp8, + "Grouped GEMM: A and B must both be FP8 or both be non-FP8."); + NVTE_CHECK(transformer_engine::is_mxfp_scaling(tensor->scaling_mode) == ref_is_mxfp8, + "Grouped GEMM: A and B must both use MXFP8 scaling or both use tensor scaling."); + if (ref_is_mxfp8) { + NVTE_CHECK(tensor->with_gemm_swizzled_scales, + "MXFP8 grouped GEMM: scales must be swizzled for GEMM."); + } } } return num_tensors; @@ -554,8 +576,15 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: using namespace transformer_engine; const bool has_row = t->has_data(); const bool has_col = t->has_columnwise_data(); - NVTE_CHECK(has_row || has_col, - "Grouped GEMM operand is missing both row-wise and column-wise data"); + + if (!has_row && !has_col) { + GroupedOperandSelection sel{}; + sel.trans = trans; + sel.scaling_mode = t->scaling_mode; + sel.dtype = t->dtype(); + sel.shape = create_shape_info(t, /*swap_dims=*/false); + return sel; + } const auto sm = t->scaling_mode; const bool mxfp8 = is_mxfp_scaling(sm); @@ -758,7 
+787,7 @@ inline void execute_grouped_gemm(const GroupedGemmSetupWorkspace &setup_workspac transformer_engine::DType d_dtype, size_t num_tensors, bool use_split_accumulator, bool use_fp8, int64_t avg_m_val, int64_t avg_n_val, int64_t avg_k_val, void *cublas_workspace_ptr, - cudaStream_t stream) { + cudaStream_t stream, int math_sm_count = 0) { using cublasHandleManager = transformer_engine::detail::HandleManager; cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); @@ -779,7 +808,10 @@ inline void execute_grouped_gemm(const GroupedGemmSetupWorkspace &setup_workspac set_fp8_scale_pointers(matmulDesc, setup_workspace.a_scale_inv_ptrs, setup_workspace.b_scale_inv_ptrs); } - + if (math_sm_count != 0) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, sizeof(math_sm_count))); + } cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, descD, avg_m_val, avg_n_val, avg_k_val); @@ -824,7 +856,6 @@ __global__ void grouped_bias_add_kernel(char *d_base, const char *bias_base, Ten const int64_t m = d_meta.first_dims ? d_meta.first_dims[tensor_idx] : d_meta.uniform_first; const int64_t n = d_meta.last_dims ? 
d_meta.last_dims[tensor_idx] : d_meta.uniform_last; - if (m == 0 || n == 0) return; const int64_t d_offset = compute_grouped_tensor_offset(d_meta, tensor_idx); const int64_t bias_offset = compute_grouped_tensor_offset(bias_meta, tensor_idx); @@ -1034,7 +1065,7 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; - // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.2+ + // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.3+ check_grouped_gemm_requirements("nvte_grouped_gemm"); // Convert to internal types @@ -1082,7 +1113,7 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, outputD->dtype(), num_tensors, config_.use_split_accumulator, use_fp8, avg_m_val, avg_n_val, avg_k_val, - workspace.cublas_workspace_ptr, stream); + workspace.cublas_workspace_ptr, stream, config_.sm_count); } void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num_a_tensors, @@ -1094,7 +1125,7 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num NVTE_API_CALL(nvte_grouped_gemm_with_discrete_inputA); using namespace transformer_engine; - // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.2+ + // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.3+ check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_inputA"); NVTE_CHECK(A_list != nullptr, "Grouped GEMM: A_list is null."); @@ -1114,6 +1145,7 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num // Validate inputs and outputs. 
const size_t num_tensors = validate_grouped_gemm_inputs(num_a_tensors, {inputB}, alpha_tensor, beta_tensor); + validate_grouped_gemm_outputs(num_tensors, {inputC_raw, outputD}); // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) @@ -1200,7 +1232,7 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, outputD->dtype(), num_tensors, config_.use_split_accumulator, use_fp8, avg_m_val, avg_n_val, avg_k_val, - workspace.cublas_workspace_ptr, stream); + workspace.cublas_workspace_ptr, stream, config_.sm_count); } void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, @@ -1213,7 +1245,7 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, NVTE_API_CALL(nvte_grouped_gemm_with_discrete_out); using namespace transformer_engine; - // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.2+ + // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.3+ check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_out"); NVTE_CHECK(D_list != nullptr, "Grouped GEMM: D_list is null."); @@ -1272,7 +1304,7 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype); execute_grouped_gemm(workspace.setup_workspace, A_sel, B_sel, d_dtype, num_tensors, config_.use_split_accumulator, use_fp8, avg_m_val, avg_n_val, avg_k_val, - workspace.cublas_workspace_ptr, stream); + workspace.cublas_workspace_ptr, stream, config_.sm_count); } void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, diff --git a/transformer_engine/common/include/transformer_engine/utils.h b/transformer_engine/common/include/transformer_engine/utils.h new file mode 100644 index 0000000000..eca6f359ea --- /dev/null +++ 
b/transformer_engine/common/include/transformer_engine/utils.h @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file utils.h + * \brief Utility functions (e.g. host-to-device pointer copies). + */ + +#ifndef TRANSFORMER_ENGINE_UTILS_H_ +#define TRANSFORMER_ENGINE_UTILS_H_ + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Copy an array of device pointers (held on host) into a device tensor. + * + * \param[in] host_ptrs Host array of device pointer values cast to uint64_t. + * \param[out] output NVTETensor whose rowwise data buffer receives the pointer values. + * \param[in] count Number of pointers. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_convert_pointers_to_tensor(const uint64_t *host_ptrs, NVTETensor output, int64_t count, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_UTILS_H_ diff --git a/transformer_engine/common/util/utils.cu b/transformer_engine/common/util/utils.cu new file mode 100644 index 0000000000..a183e6ec52 --- /dev/null +++ b/transformer_engine/common/util/utils.cu @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include + +#include "../common.h" +#include "../util/logging.h" + +namespace { + +constexpr int64_t kMaxKernelAddresses = 256; + +struct HostPointersArgs { + uint64_t ptrs[kMaxKernelAddresses]; +}; + +__global__ void write_pointers_kernel(HostPointersArgs args, uint64_t *out, int64_t count, + int64_t offset) { + const int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < count) { + out[offset + idx] = args.ptrs[idx]; + } +} + +} // namespace + +void nvte_convert_pointers_to_tensor(const uint64_t *host_ptrs, NVTETensor output, int64_t count, + cudaStream_t stream) { + NVTE_API_CALL(nvte_convert_pointers_to_tensor); + using namespace transformer_engine; + Tensor *out_tensor = convertNVTETensorCheck(output); + uint64_t *out_ptr = static_cast(out_tensor->data.dptr); + NVTE_CHECK(out_ptr != nullptr, "Output tensor data pointer is null."); + + int64_t offset = 0; + while (offset < count) { + const int64_t chunk = std::min(kMaxKernelAddresses, count - offset); + HostPointersArgs args{}; + for (int64_t i = 0; i < chunk; ++i) { + args.ptrs[i] = host_ptrs[offset + i]; + } + constexpr int threads = kMaxKernelAddresses; + write_pointers_kernel<<<1, threads, 0, stream>>>(args, out_ptr, chunk, offset); + NVTE_CHECK_CUDA(cudaGetLastError()); + offset += chunk; + } +} diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 63a2e86e67..9d2513835c 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -42,6 +42,7 @@ #include #include #include +#include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1c5116a8da..09919c96c1 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -307,7 +307,7 @@ py::object quantize(const at::Tensor &tensor, 
py::handle quantizer, const py::ob py::object dequantize(const py::handle &input, DType otype); py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, - std::optional first_dims); + std::optional first_dims, bool output_dbias = false); std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); @@ -454,6 +454,12 @@ size_t get_cublasLt_version(); size_t get_cudnn_version(); +std::vector convert_host_pointers_to_tensor( + std::vector> tensor_lists); + +std::tuple get_device_pointer_for_data_and_scales( + std::vector data_tensors, std::vector scale_tensors, bool swizzle, + bool rowwise, transformer_engine::DType data_dtype); at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim); /*************************************************************************************************** @@ -561,6 +567,8 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, void inplace_swizzle_scale_for_gemm(py::handle &tensor); +void grouped_swizzle_for_gemm(py::handle &tensor, bool rowwise, bool columnwise); + /*************************************************************************************************** * NVSHMEM APIs **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index cb3434ec52..b57ae008d0 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -159,7 +159,7 @@ void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, // NOTE: Only supports varying first dim. 
py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, - std::optional first_dims) { + std::optional first_dims, const bool output_dbias) { using namespace transformer_engine::pytorch::detail; init_extension(); @@ -201,7 +201,13 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const grouped_quantization_mode = GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE; } + NVTE_CHECK(!(output_dbias && + grouped_quantization_mode != GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE), + "group_quantize: output_dbias is only supported for MXFP8 quantizer."); + if (empty_input_buffer) { + NVTE_CHECK(!output_dbias, + "group_quantize: output_dbias is not supported with an empty input tensor."); // early return for empty input buffer // just return the output tensor as is // no need to quantize @@ -217,6 +223,35 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const break; } case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE: { + if (output_dbias) { + const std::vector dbias_logical_shape = {num_tensors, logical_last_dim}; + GroupedTensorWrapper grouped_dbias(num_tensors, dbias_logical_shape, + NVTE_DELAYED_TENSOR_SCALING); + at::Tensor dbias_torch = + at::empty({static_cast(num_tensors), static_cast(logical_last_dim)}, + tensor.options()); + grouped_dbias.set_rowwise_data(dbias_torch.data_ptr(), + GetTransformerEngineDType(tensor.scalar_type()), + getTensorShape(dbias_torch)); + TensorWrapper workspace_nvte; + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_quantize_dbias(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(), + grouped_dbias.data(), workspace_nvte.data(), stream); + }); + if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) { + at::Tensor workspace_torch = + allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype()); + workspace_nvte = makeTransformerEngineTensor( + workspace_torch.data_ptr(), 
workspace_nvte.shape(), workspace_nvte.dtype()); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_quantize_dbias(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(), + grouped_dbias.data(), workspace_nvte.data(), stream); + }); + return py::make_tuple(py::reinterpret_borrow(grouped_output_py), + py::cast(std::move(dbias_torch))); + } NVTE_SCOPED_GIL_RELEASE({ nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 1431ebdfb4..08470962f9 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -9,9 +9,7 @@ #include #include -#include "../common.h" #include "../extensions.h" -#include "common.h" #include "common/util/cuda_runtime.h" #include "common/util/system.h" #include "pybind.h" @@ -637,8 +635,10 @@ py::object te_general_grouped_gemm_for_grouped_tensor( auto gemm_config = prepare_grouped_gemm_config(alpha, beta, workspace_setup, workspace_cublas, num_tensors, math_sm_count, use_split_accumulator); - [[maybe_unused]] auto swizzled_scales_A = maybe_swizzle_grouped_tensor_for_gemm(grouped_A); - [[maybe_unused]] auto swizzled_scales_B = maybe_swizzle_grouped_tensor_for_gemm(grouped_B); + [[maybe_unused]] auto swizzled_scales_A = + maybe_swizzle_grouped_tensor(grouped_A, transa, !transa); + [[maybe_unused]] auto swizzled_scales_B = + maybe_swizzle_grouped_tensor(grouped_B, transb, !transb); NVTE_SCOPED_GIL_RELEASE({ nvte_grouped_gemm(grouped_A.data(), transa, grouped_B.data(), transb, grouped_D.data(), @@ -704,7 +704,8 @@ py::object te_general_grouped_gemm_for_discrete_in(py::handle A, bool transa, py swizzled_scale_inverses_list.emplace_back( multi_tensor_swizzle_scales_for_gemm(te_A_wrappers, transa, !transa)); - [[maybe_unused]] auto swizzled_scales_B = maybe_swizzle_grouped_tensor_for_gemm(grouped_B); + 
[[maybe_unused]] auto swizzled_scales_B = + maybe_swizzle_grouped_tensor(grouped_B, transb, !transb); NVTE_SCOPED_GIL_RELEASE({ nvte_grouped_gemm_with_discrete_inputA( @@ -769,8 +770,10 @@ py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, p te_D_vector.emplace_back(te_D_wrappers.back().data()); } - [[maybe_unused]] auto swizzled_scales_A = maybe_swizzle_grouped_tensor_for_gemm(grouped_A); - [[maybe_unused]] auto swizzled_scales_B = maybe_swizzle_grouped_tensor_for_gemm(grouped_B); + [[maybe_unused]] auto swizzled_scales_A = + maybe_swizzle_grouped_tensor(grouped_A, transa, !transa); + [[maybe_unused]] auto swizzled_scales_B = + maybe_swizzle_grouped_tensor(grouped_B, transb, !transb); NVTE_SCOPED_GIL_RELEASE({ nvte_grouped_gemm_with_discrete_out( diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c590a3c9e2..50b00e8b5a 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -140,7 +140,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), - py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); + py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims"), + py::arg("output_dbias") = false); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", @@ -387,6 +388,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Multi-tensor unpadding", py::call_guard()); m.def("swizzle_scales_for_gemm_", &transformer_engine::pytorch::inplace_swizzle_scale_for_gemm, "Convert tensor block scales into 
GEMM swizzled format"); + m.def("grouped_swizzle_for_gemm", &transformer_engine::pytorch::grouped_swizzle_for_gemm, + "In-place swizzle of grouped tensor scales for GEMM", py::arg("tensor"), py::arg("rowwise"), + py::arg("columnwise")); // attention kernels m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd, @@ -454,6 +458,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Get cublasLt version", py::call_guard()); m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version", py::call_guard()); + m.def("convert_host_pointers_to_tensor", + &transformer_engine::pytorch::convert_host_pointers_to_tensor, + "Copy host-side device pointers into device tensors", py::arg("tensor_lists"), + py::call_guard()); + m.def("get_device_pointer_for_data_and_scales", + &transformer_engine::pytorch::get_device_pointer_for_data_and_scales, + "Swizzle scales and collect data/scale device pointers into device tensors", + py::arg("data_tensors"), py::arg("scale_tensors"), py::arg("swizzle") = false, + py::arg("rowwise"), py::arg("data_dtype"), py::call_guard()); m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets, "Compute grouped tensor offsets from split sizes", py::arg("first_dims"), py::arg("logical_last_dim"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index 7ff35d6b68..a6b4e7569d 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -338,8 +338,9 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp return swizzled_scale_inv; } -std::optional maybe_swizzle_grouped_tensor_for_gemm( - GroupedTensorWrapper &input) { +std::optional maybe_swizzle_grouped_tensor(GroupedTensorWrapper &input, + bool rowwise_usage, + bool columnwise_usage) { if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { return 
std::nullopt; } @@ -349,9 +350,9 @@ std::optional maybe_swizzle_grouped_tensor_for_gemm( const auto row_scales = input.get_rowwise_scale_inv(); const auto col_scales = input.get_columnwise_scale_inv(); - const bool has_rowwise_scales = !is_empty_grouped_tensor_param(row_scales); - const bool has_columnwise_scales = !is_empty_grouped_tensor_param(col_scales); - if (!has_rowwise_scales && !has_columnwise_scales) { + const bool swizzle_rowwise = rowwise_usage && !is_empty_grouped_tensor_param(row_scales); + const bool swizzle_columnwise = columnwise_usage && !is_empty_grouped_tensor_param(col_scales); + if (!swizzle_rowwise && !swizzle_columnwise) { return std::nullopt; } const auto first_dims = input.get_first_dims(); @@ -364,57 +365,84 @@ std::optional maybe_swizzle_grouped_tensor_for_gemm( std::optional rowwise_scales_pyt; std::optional columnwise_scales_pyt; - GroupedTensorWrapper output(input.num_tensors(), input.logical_shape(), input.scaling_mode()); - const auto rowwise_data = input.get_rowwise_data(); - if (rowwise_data.data_ptr != nullptr) { - output.set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - } - const auto columnwise_data = input.get_columnwise_data(); - if (columnwise_data.data_ptr != nullptr) { - output.set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); - } + GroupedTensorWrapper swizzle_input(input.num_tensors(), input.logical_shape(), + input.scaling_mode()); + GroupedTensorWrapper swizzle_output(input.num_tensors(), input.logical_shape(), + input.scaling_mode()); + const auto tensor_offsets = input.get_tensor_offsets(); if (tensor_offsets.data_ptr != nullptr) { - output.set_tensor_offsets(tensor_offsets.data_ptr, static_cast(tensor_offsets.dtype), - tensor_offsets.shape); + swizzle_input.set_tensor_offsets( + tensor_offsets.data_ptr, static_cast(tensor_offsets.dtype), tensor_offsets.shape); + swizzle_output.set_tensor_offsets( + 
tensor_offsets.data_ptr, static_cast(tensor_offsets.dtype), tensor_offsets.shape); } - if (has_rowwise_scales) { + if (swizzle_rowwise) { + const auto data = input.get_rowwise_data(); + const auto data_dtype = static_cast(data.dtype); const auto scales_dtype = static_cast(row_scales.dtype); + swizzle_input.set_rowwise_data(nullptr, data_dtype, data.shape); + swizzle_input.set_rowwise_scale_inv(row_scales.data_ptr, scales_dtype, row_scales.shape); rowwise_scales_pyt = allocateSpace(row_scales.shape, scales_dtype, false); - void *output_scales_dptr = getDataPtr(*rowwise_scales_pyt); - output.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, row_scales.shape); + swizzle_output.set_rowwise_data(nullptr, data_dtype, data.shape); + swizzle_output.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, + row_scales.shape); } - if (has_columnwise_scales) { + if (swizzle_columnwise) { + const auto data = input.get_columnwise_data(); + const auto data_dtype = static_cast(data.dtype); const auto scales_dtype = static_cast(col_scales.dtype); + swizzle_input.set_columnwise_data(nullptr, data_dtype, data.shape); + swizzle_input.set_columnwise_scale_inv(col_scales.data_ptr, scales_dtype, col_scales.shape); columnwise_scales_pyt = allocateSpace(col_scales.shape, scales_dtype, false); - void *output_scales_dptr = getDataPtr(*columnwise_scales_pyt); - output.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, col_scales.shape); + swizzle_output.set_columnwise_data(nullptr, data_dtype, data.shape); + swizzle_output.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype, + col_scales.shape); } - output.set_with_gemm_swizzled_scales(true); + swizzle_output.set_with_gemm_swizzled_scales(true); NVTE_SCOPED_GIL_RELEASE({ - nvte_swizzle_grouped_scaling_factors(input.data(), output.data(), + nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(), at::cuda::getCurrentCUDAStream()); }); - if (has_rowwise_scales) { + if 
(swizzle_rowwise) { const auto scales_dtype = static_cast(row_scales.dtype); input.set_rowwise_scale_inv(getDataPtr(*rowwise_scales_pyt), scales_dtype, row_scales.shape); } - if (has_columnwise_scales) { + if (swizzle_columnwise) { const auto scales_dtype = static_cast(col_scales.dtype); input.set_columnwise_scale_inv(getDataPtr(*columnwise_scales_pyt), scales_dtype, col_scales.shape); } input.set_with_gemm_swizzled_scales(true); - return SwizzledGroupedScales{std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)}; } +void grouped_swizzle_for_gemm(py::handle &tensor, bool rowwise, bool columnwise) { + using namespace transformer_engine::pytorch::detail; + + auto tensor_nvte = GroupedTensorFromPyTorchGroupedTensor(tensor); + + auto result = maybe_swizzle_grouped_tensor(tensor_nvte, rowwise, columnwise); + + if (result.has_value()) { + if (result->first.has_value()) { + tensor.attr("scale_inv") = py::cast(*result->first); + } else { + tensor.attr("scale_inv") = py::none(); + } + if (result->second.has_value()) { + tensor.attr("columnwise_scale_inv") = py::cast(*result->second); + } else { + tensor.attr("columnwise_scale_inv") = py::none(); + } + tensor.attr("_with_gemm_swizzled_scales") = py::cast(true); + } +} + void inplace_swizzle_scale_for_gemm(py::handle &tensor) { // Convert Python tensor to C++ tensor auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none()); diff --git a/transformer_engine/pytorch/csrc/extensions/utils.cpp b/transformer_engine/pytorch/csrc/extensions/utils.cpp new file mode 100644 index 0000000000..9a093608d4 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/utils.cpp @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include + +#include + +#include "common/common.h" +#include "extensions.h" + +namespace transformer_engine::pytorch { + +namespace { + +at::Tensor collect_pointers_in_device_tensor(const std::vector& host_ptrs, + const at::Device& device, cudaStream_t stream) { + const int64_t count = static_cast(host_ptrs.size()); + auto out = at::empty({count}, at::TensorOptions().dtype(at::kLong).device(device)); + auto out_nvte = makeTransformerEngineTensor(out); + nvte_convert_pointers_to_tensor(host_ptrs.data(), out_nvte.data(), count, stream); + return out; +} + +} // namespace + +std::vector convert_host_pointers_to_tensor( + std::vector> tensor_lists) { + std::vector outputs; + outputs.reserve(tensor_lists.size()); + auto stream = at::cuda::getCurrentCUDAStream(); + + for (const auto& tensor_list : tensor_lists) { + NVTE_CHECK(!tensor_list.empty(), "Tensor list is empty."); + const auto& first_tensor = tensor_list[0]; + NVTE_CHECK(first_tensor.is_cuda(), "Tensor list must be on CUDA."); + const auto device = first_tensor.device(); + const int64_t count = static_cast(tensor_list.size()); + std::vector host_ptrs(count); + for (int64_t i = 0; i < count; ++i) { + host_ptrs[i] = reinterpret_cast(tensor_list[static_cast(i)].data_ptr()); + } + outputs.push_back(collect_pointers_in_device_tensor(host_ptrs, device, stream)); + } + + return outputs; +} + +std::tuple get_device_pointer_for_data_and_scales( + std::vector data_tensors, std::vector scale_tensors, bool swizzle, + bool rowwise, transformer_engine::DType data_dtype) { + const size_t num_tensors = data_tensors.size(); + NVTE_CHECK(num_tensors > 0, "data_tensors must not be empty."); + NVTE_CHECK(num_tensors == scale_tensors.size(), + "data_tensors and scale_tensors must have the same size."); + NVTE_CHECK(data_tensors[0].is_cuda(), "data_tensors must be on CUDA."); + const auto device = data_tensors[0].device(); + auto stream = 
at::cuda::getCurrentCUDAStream(); + + // Infer data shape from the first data tensor (expected 2D: n x k) + NVTE_CHECK(data_tensors[0].dim() == 2, + "data_tensors elements must be 2D, got dim=", data_tensors[0].dim()); + NVTEShape data_shape{}; + data_shape.ndim = 2; + data_shape.data[0] = static_cast(data_tensors[0].size(0)); + data_shape.data[1] = static_cast(data_tensors[0].size(1)); + + // Collect data device pointers + std::vector data_host_ptrs(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + data_host_ptrs[i] = reinterpret_cast(data_tensors[i].data_ptr()); + } + + // Swizzle scales and collect scale pointers + at::Tensor swizzled_scales_keepalive; + std::vector scale_host_ptrs(num_tensors); + + if (swizzle) { + NVTEScalingMode scaling_mode; + transformer_engine::DType scale_dtype; + if (is_fp8_dtype(data_dtype)) { + scaling_mode = NVTE_MXFP8_1D_SCALING; + scale_dtype = transformer_engine::DType::kFloat8E8M0; + } else if (is_fp4_dtype(data_dtype)) { + scaling_mode = NVTE_NVFP4_1D_SCALING; + scale_dtype = transformer_engine::DType::kFloat8E4M3; + } else { + NVTE_ERROR("data_dtype must be an FP8 or FP4 type for swizzling."); + } + + // Compute output buffer size for swizzled scales (16B aligned per tensor) + std::vector output_offsets; + size_t output_bytes = 0; + for (size_t i = 0; i < num_tensors; ++i) { + const size_t scale_numel = static_cast(scale_tensors[i].numel()); + const size_t dtype_bits = transformer_engine::pytorch::typeToNumBits(scale_dtype); + output_bytes = roundup(output_bytes, 16); + output_offsets.push_back(output_bytes); + output_bytes += ceildiv(scale_numel * dtype_bits, 8); + } + + // Allocate single buffer for all swizzled scales + swizzled_scales_keepalive = + allocateSpace(std::vector{output_bytes}, transformer_engine::DType::kByte, false); + uint8_t* output_dptr = reinterpret_cast(getDataPtr(swizzled_scales_keepalive)); + + // Build TensorWrapper input/output pairs and get scale shapes + std::vector inputs_nvte, 
outputs_nvte; + inputs_nvte.reserve(num_tensors); + outputs_nvte.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + inputs_nvte.emplace_back(scaling_mode); + outputs_nvte.emplace_back(scaling_mode); + auto& input_nvte = inputs_nvte.back(); + auto& output_nvte = outputs_nvte.back(); + output_nvte.set_with_gemm_swizzled_scales(true); + + NVTEShape scale_shape = convertTorchShape(scale_tensors[i].sizes()); + void* scale_ptr = scale_tensors[i].data_ptr(); + uint8_t* out_scale_ptr = output_dptr + output_offsets[i]; + + if (rowwise) { + input_nvte.set_rowwise_data(nullptr, data_dtype, data_shape); + input_nvte.set_rowwise_scale_inv(scale_ptr, scale_dtype, scale_shape); + output_nvte.set_rowwise_data(nullptr, data_dtype, data_shape); + output_nvte.set_rowwise_scale_inv(out_scale_ptr, scale_dtype, scale_shape); + } else { + input_nvte.set_columnwise_data(nullptr, data_dtype, data_shape); + input_nvte.set_columnwise_scale_inv(scale_ptr, scale_dtype, scale_shape); + output_nvte.set_columnwise_data(nullptr, data_dtype, data_shape); + output_nvte.set_columnwise_scale_inv(out_scale_ptr, scale_dtype, scale_shape); + } + } + + // Pack raw NVTETensors and launch swizzle kernel + std::vector inputs_raw, outputs_raw; + inputs_raw.reserve(num_tensors); + outputs_raw.reserve(num_tensors); + for (auto& t : inputs_nvte) inputs_raw.push_back(t.data()); + for (auto& t : outputs_nvte) outputs_raw.push_back(t.data()); + + nvte_multi_tensor_swizzle_scaling_factors(inputs_raw.data(), outputs_raw.data(), num_tensors, + stream); + + // Collect swizzled scale pointers + for (size_t i = 0; i < num_tensors; ++i) { + scale_host_ptrs[i] = reinterpret_cast(output_dptr + output_offsets[i]); + } + } else { + swizzled_scales_keepalive = at::empty({0}, at::TensorOptions().dtype(at::kByte).device(device)); + for (size_t i = 0; i < num_tensors; ++i) { + scale_host_ptrs[i] = reinterpret_cast(scale_tensors[i].data_ptr()); + } + } + + // Convert pointer arrays to device tensors + auto 
data_ptrs = collect_pointers_in_device_tensor(data_host_ptrs, device, stream); + auto scale_ptrs = collect_pointers_in_device_tensor(scale_host_ptrs, device, stream); + + return {std::move(data_ptrs), std::move(scale_ptrs), std::move(swizzled_scales_keepalive)}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 587ec289a4..88f76a7cb1 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -38,10 +38,15 @@ using SwizzledGroupedScales = std::pair, std::optional /*! \brief Swizzle grouped tensor scales for GEMM if needed. * Currently only works for MXFP8 1D scaling with uniform shapes. * + * \param[in,out] input Grouped tensor whose scales to swizzle. + * \param[in] rowwise_usage Whether rowwise scales are needed. + * \param[in] columnwise_usage Whether columnwise scales are needed. + * * The returned swizzled scales should be kept alive during the GEMM. */ -std::optional maybe_swizzle_grouped_tensor_for_gemm( - GroupedTensorWrapper& input); +std::optional maybe_swizzle_grouped_tensor(GroupedTensorWrapper& input, + bool rowwise_usage, + bool columnwise_usage); /*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. 
* diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 28da4873f0..a96a87bf89 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -80,19 +80,19 @@ class UserBufferQuantizationMode(Enum): def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: """Returns a dummy tensor of given shape.""" - if len(shape) != 2: - raise ValueError(f"Expected 2D shape, got {len(shape)}D: {shape}") + + key = (*shape, dtype) global _dummy_wgrads - if (shape[0], shape[1], dtype) not in _dummy_wgrads: - _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( + if key not in _dummy_wgrads: + _dummy_wgrads[key] = torch.empty( shape, dtype=dtype, device="cuda", requires_grad=False, ) if zero: - _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) - return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() + _dummy_wgrads[key].fill_(0) + return _dummy_wgrads[key].detach() def initialize_ub( diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 0adda48e36..ba6becb9f9 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -594,10 +594,14 @@ class GroupedLinear(TransformerEngineBaseModule): cast tensor. In some scenarios, the input tensor is used by multiple modules, and saving the original input tensor may reduce the memory usage. Cannot work with FP8 DelayedScaling recipe. - single_grouped_parameter : bool, default = False + single_grouped_weight : bool, default = False If set to ``True``, grouped weights are stored as a single grouped parameter instead of one parameter per GEMM. EXPERIMENTAL and subject to change. + single_grouped_bias : bool, default = False + If set to ``True``, grouped biases are stored as a single grouped bias + instead of one bias per GEMM. + EXPERIMENTAL and subject to change. 
Notes ----- @@ -628,7 +632,8 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, save_original_input: bool = False, - single_grouped_parameter: bool = False, + single_grouped_weight: bool = False, + single_grouped_bias: bool = False, name: Optional[str] = None, ) -> None: super().__init__(name) @@ -645,7 +650,8 @@ def __init__( self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input - self.single_grouped_parameter = single_grouped_parameter + self.single_grouped_weight = single_grouped_weight + self.single_grouped_bias = single_grouped_bias if ub_overlap_rs or ub_overlap_ag: raise ValueError("GroupedLinear doesn't support Userbuffer overlap.") self.init_method = init_method @@ -737,6 +743,9 @@ def __init__( if self.wgrad_store.delay_wgrad_compute(): for name, param in self.named_parameters(): + if name in ("weight", "bias"): + param.skip_backward_post_hook = True + continue for i in range(self.num_gemms): if name in (f"weight{i}", f"bias{i}"): param.skip_backward_post_hook = True @@ -787,13 +796,12 @@ def make_grouped_weights(self, defer_init=False) -> None: else: grouped_weights.quantized_tensors[i].copy_(weights[i]) - # Re-register as a single grouped weight parameter. # Re-register as a single grouped weight parameter. 
if not ( isinstance(grouped_weights, torch.Tensor) and (weight_quantizers[0] is None or not weight_quantizers[0].internal) ): - raise RuntimeError("Found internal quantizer with `single_grouped_parameter=True`.") + raise RuntimeError("Found internal quantizer with `single_grouped_weight=True`.") self.register_parameter( "weight", torch.nn.Parameter(grouped_weights), @@ -804,13 +812,33 @@ def make_grouped_weights(self, defer_init=False) -> None: for i in range(self.num_gemms): self.register_parameter(f"weight{i}", None) + if self.use_bias and self.single_grouped_bias: + self._make_grouped_biases() + self.set_tensor_parallel_attributes(defer_init=defer_init) + def _make_grouped_biases(self) -> None: + """Pack per-GEMM biases into one ``GroupedTensor`` (``single_grouped_bias``).""" + biases = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + packed = torch.stack([b.detach().clone() for b in biases], dim=0).contiguous() + grouped_bias = GroupedTensor.make_grouped_tensor_from_rowwise_data( + num_tensors=self.num_gemms, + tensor_shape=(self.out_features,), + rowwise_data=packed, + dtype=packed.dtype, + ) + grouped_bias.requires_grad_(True) + self.register_parameter("bias", torch.nn.Parameter(grouped_bias)) + for i in range(self.num_gemms): + self.register_parameter(f"bias{i}", None) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) - # Grouped tensor weights is an opt-in feature. - if self.single_grouped_parameter: + # Grouped tensor weights / biases are opt-in features. 
+ if self.single_grouped_weight: self.make_grouped_weights(defer_init=defer_init) + elif self.single_grouped_bias: + self._make_grouped_biases() def set_tensor_parallel_attributes(self, defer_init=False) -> None: """Set attributes needed for TP""" @@ -836,15 +864,24 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: # Set parallelism attributes for linear biases if self.use_bias: - for i in range(self.num_gemms): + grouped_bias = getattr(self, "bias", None) + if grouped_bias is not None: if self.parallel_mode == "row": - setattr( - getattr(self, f"bias{i}"), - "sequence_parallel", - self.sequence_parallel, - ) + setattr(grouped_bias, "sequence_parallel", self.sequence_parallel) elif self.parallel_mode == "column": - set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1) + set_tensor_model_parallel_attributes(grouped_bias, True, 0, 1) + else: + for i in range(self.num_gemms): + if self.parallel_mode == "row": + setattr( + getattr(self, f"bias{i}"), + "sequence_parallel", + self.sequence_parallel, + ) + elif self.parallel_mode == "column": + set_tensor_model_parallel_attributes( + getattr(self, f"bias{i}"), True, 0, 1 + ) def _remap_grouped_weight_state_dict_keys(self, state_dict, prefix: str) -> None: """Remap weight keys between single and per-GEMM checkpoint formats.""" @@ -853,8 +890,8 @@ def _remap_grouped_weight_state_dict_keys(self, state_dict, prefix: str) -> None has_grouped_weight = grouped_weight_key in state_dict has_per_gemm_weights = all(key in state_dict for key in per_gemm_weight_keys) - if self.single_grouped_parameter: - # Backward compatibility: checkpoints saved without single_grouped_parameter + if self.single_grouped_weight: + # Backward compatibility: checkpoints saved without single_grouped_weight # store one weight tensor per GEMM (weight0..weightN). Convert them into a # single stacked grouped weight expected by this module configuration. 
if not has_grouped_weight and has_per_gemm_weights: @@ -869,7 +906,7 @@ def _remap_grouped_weight_state_dict_keys(self, state_dict, prefix: str) -> None for key in per_gemm_weight_keys: state_dict.pop(key, None) else: - # Forward compatibility: checkpoints saved with single_grouped_parameter + # Forward compatibility: checkpoints saved with single_grouped_weight # store one grouped `weight`. Convert it back to weight0..weightN. if not has_per_gemm_weights and has_grouped_weight: grouped_weight = state_dict.pop(grouped_weight_key) @@ -898,6 +935,40 @@ def _remap_grouped_weight_state_dict_keys(self, state_dict, prefix: str) -> None # Drop any redundant grouped key to avoid strict-load unexpected-key errors. state_dict.pop(grouped_weight_key, None) + def _remap_grouped_bias_state_dict_keys(self, state_dict, prefix: str) -> None: + """Remap bias keys between single grouped and per-GEMM checkpoint formats.""" + if not self.use_bias: + return + grouped_bias_key = f"{prefix}bias" + per_gemm_bias_keys = [f"{prefix}bias{i}" for i in range(self.num_gemms)] + has_grouped_bias = grouped_bias_key in state_dict + has_per_gemm_biases = all(key in state_dict for key in per_gemm_bias_keys) + + if self.single_grouped_bias: + if not has_grouped_bias and has_per_gemm_biases: + per_gemm = [state_dict.pop(key) for key in per_gemm_bias_keys] + state_dict[grouped_bias_key] = torch.stack(per_gemm, dim=0) + elif has_grouped_bias: + for key in per_gemm_bias_keys: + state_dict.pop(key, None) + val = state_dict[grouped_bias_key] + if isinstance(val, torch.Tensor) and val.dim() == 3 and val.shape[1] == 1: + state_dict[grouped_bias_key] = val.squeeze(1) + else: + if not has_per_gemm_biases and has_grouped_bias: + gb = state_dict.pop(grouped_bias_key) + if hasattr(gb, "split_into_quantized_tensors"): + members = gb.quantized_tensors + if members is None: + members = gb.split_into_quantized_tensors() + per_gemm = [m.reshape(-1) if m.dim() > 1 else m for m in members] + else: + per_gemm = 
list(gb.unbind(0)) + for i, b in enumerate(per_gemm): + state_dict[f"{prefix}bias{i}"] = b.reshape(-1) if b.dim() > 1 else b + elif has_per_gemm_biases: + state_dict.pop(grouped_bias_key, None) + def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): """Load state dict with grouped-weight format compatibility.""" state_dict_copy = state_dict.copy() @@ -905,6 +976,7 @@ def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False) if metadata is not None: state_dict_copy._metadata = metadata self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="") + self._remap_grouped_bias_state_dict_keys(state_dict_copy, prefix="") return super().load_state_dict(state_dict_copy, strict=strict, assign=assign) def _load_from_state_dict( @@ -912,6 +984,7 @@ def _load_from_state_dict( ): """Load state, including compatibility across grouped-weight checkpoint formats.""" self._remap_grouped_weight_state_dict_keys(state_dict, prefix) + self._remap_grouped_bias_state_dict_keys(state_dict, prefix) super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs @@ -962,7 +1035,7 @@ def forward( inp = self.prepare_forward(inp, num_gemms=self.num_gemms) try: weight_tensors = self._get_weight_tensors() - bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + bias_tensors = self._get_bias_tensors() quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() @@ -1026,18 +1099,28 @@ def backward_dw(self): """ if not self.need_backward_dw(): return + if self.wgrad_store.context is None or self.wgrad_store.context.empty(): + return with get_nvtx_range_context("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() wgrad_list = tensor_list[2] weight_params = self._get_weight_tensors() - bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: for i in 
range(self.num_gemms): weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) if self.use_bias: - for i in range(self.num_gemms): - if bias_params[i].grad is None: - bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype) + grouped_bias = getattr(self, "bias", None) + if grouped_bias is not None: + gstack = torch.stack(grad_biases_, dim=0).to(grouped_bias.dtype) + if grouped_bias.grad is None: + grouped_bias.grad = gstack + else: + grouped_bias.grad.add_(gstack) + else: + bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + for i in range(self.num_gemms): + if bias_params[i].grad is None: + bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype) del grad_biases_ del wgrad_list del tensor_list @@ -1099,6 +1182,16 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage ] return weight_tensors + def _get_bias_tensors(self) -> List[torch.Tensor]: + """Per-GEMM bias tensors (views into grouped storage when ``single_grouped_bias``).""" + grouped_bias = getattr(self, "bias", None) + if grouped_bias is not None: + parts = grouped_bias.quantized_tensors + if parts is None: + parts = grouped_bias.split_into_quantized_tensors() + return [p.reshape(-1) for p in parts] + return [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration and not self.primary_weights_in_fp8: diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 4520dbc313..0642fdfec1 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -71,3 +71,116 @@ def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, i fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=tensor.device) fp8_meta.scale_inv = tensor._scale_inv return fp8_meta, 0 + + +def 
validate_grouped_mlp_dims(fc1, swiglu, fc2) -> None: + """Validate FC1/SwiGLU/FC2 dimensions and interleave size for fused grouped MLP.""" + if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, " + f"in_features={fc1.in_features}, out_features={fc1.out_features})." + ) + if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " + f"in_features={fc2.in_features}, out_features={fc2.out_features})." + ) + if fc1.out_features != 2 * fc2.in_features or fc1.num_groups != fc2.num_groups: + raise ValueError( + f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " + f"out_features={fc1.out_features}) " + f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " + f"out_features={fc2.out_features}) do not match." + ) + if swiglu.glu_interleave_size != 32: + raise ValueError( + "Fused kernel requires 32-wide GLU interleaving, " + f"but got glu_interleave_size={swiglu.glu_interleave_size}." + ) + + +def fuse_grouped_mlp_ops( + ops, + *, + recipe, + fused_op_cls, +): + """Sliding-window fusion for GroupedLinear + ScaledSwiGLU + GroupedLinear. + + Parameters + ---------- + ops : list of FusibleOperation + Operations to scan. + recipe : Recipe or None + Quantization recipe. + fused_op_cls : type + Fused operation class with ``is_supported()`` classmethod and + constructor accepting ``fc1``, ``swiglu``, ``fc2`` keyword args. + May also expose ``is_fc1_bias_supported()`` and/or + ``is_fc2_bias_supported()`` classmethods for bias eligibility. + + Returns + ------- + list of FusibleOperation + Updated operations with matched triples replaced by fused ops. 
+ """ + from .basic import GroupedLinear, ScaledSwiGLU # pylint: disable=import-outside-toplevel + + if not fused_op_cls.is_supported(): + return ops + if recipe is None or not recipe.mxfp8(): + return ops + + fc1_bias_ok = ( + not hasattr(fused_op_cls, "is_fc1_bias_supported") or fused_op_cls.is_fc1_bias_supported() + ) + fc2_bias_ok = ( + not hasattr(fused_op_cls, "is_fc2_bias_supported") or fused_op_cls.is_fc2_bias_supported() + ) + + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + + matches_pattern = True + if not ( + isinstance(window[0], GroupedLinear) + and isinstance(window[1], ScaledSwiGLU) + and isinstance(window[2], GroupedLinear) + ): + matches_pattern = False + elif window[0].num_groups != window[2].num_groups: + matches_pattern = False + elif ( + window[0].in_features % 256 != 0 + or window[0].out_features % 256 != 0 + or window[2].in_features % 256 != 0 + or window[2].out_features % 256 != 0 + ): + matches_pattern = False + elif window[1].glu_interleave_size != 32: + matches_pattern = False + elif window[0].has_bias and not fc1_bias_ok: + matches_pattern = False + elif window[2].has_bias and not fc2_bias_ok: + matches_pattern = False + + if matches_pattern: + op = fused_op_cls( + fc1=window[0], + swiglu=window[1], + fc2=window[2], + ) + window = [op] + else: + out.extend(window[:-2]) + window = window[-2:] + + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + out.extend(window) + return out diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index b44e77b0c6..f26a337a4d 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -7,6 +7,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Sequence import contextlib +import functools import math from typing import Any, Optional @@ 
-15,6 +16,7 @@ import transformer_engine_torch as tex from ...cpp_extensions import general_grouped_gemm from ...distributed import CudaRNGStatesTracker +from ...module._common import WeightGradStore from ...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -32,6 +34,7 @@ ) from .._common import is_quantized_tensor, maybe_dequantize from ..op import BasicOperation, OperationContext +from ...tensor import GroupedTensor class GroupedLinear(BasicOperation): @@ -69,6 +72,13 @@ class GroupedLinear(BasicOperation): Megatron-LM. This argument along with weight tensor having attribute ``overwrite_main_grad`` set to True will overwrite ``main_grad`` instead of accumulating. + single_grouped_weight : bool, default = ``False`` + Store all expert weights as one ``GroupedTensor`` parameter ``weight``. + delay_wgrad_compute : bool, default = ``False`` + Whether to delay weight gradient computation + single_grouped_bias : bool, default = ``False`` + If ``True`` (and ``bias=True``), store all expert biases as one ``GroupedTensor`` + parameter named ``bias`` instead of ``bias0``..``bias{N-1}``. 
""" @@ -86,13 +96,21 @@ def __init__( dtype: Optional[torch.dtype] = None, rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, + single_grouped_weight: bool = False, + single_grouped_bias: bool = False, + delay_wgrad_compute: bool = False, ) -> None: super().__init__() + self.wgrad_store = WeightGradStore(delay_wgrad_compute) + # Weight tensor dimensions self.num_groups: int = num_groups self.in_features: int = in_features self.out_features: int = out_features + self.single_grouped_weight: bool = single_grouped_weight + self.single_grouped_bias: bool = single_grouped_bias + self.use_bias: bool = bias if self.num_groups <= 0: raise ValueError(f"Invalid number of groups ({self.num_groups})") if self.in_features <= 0: @@ -116,12 +134,15 @@ def __init__( self._rng_state_tracker_function = rng_state_tracker_function # Register weights + # TODO(ksivaman): Proper support for meta device. + # We do not want to reset params later as it wipes off + # main_grad and related attributes. 
self.weight0: torch.nn.Parameter for group_idx in range(self.num_groups): weight_tensor = torch.empty( self.out_features, self.in_features, - device="meta", + device=device, dtype=dtype, ) self.register_parameter( @@ -136,7 +157,7 @@ def __init__( if bias: bias_tensor = torch.empty( self.out_features, - device="meta", + device=device, dtype=dtype, ) bias_tensor = torch.nn.Parameter(bias_tensor) @@ -149,6 +170,57 @@ def __init__( # Whether to accumulate weight gradient into main_grad self._accumulate_into_main_grad: bool = accumulate_into_main_grad + self._apply_delay_wgrad_param_hooks() + + def _apply_delay_wgrad_param_hooks(self) -> None: + """Set ``skip_backward_post_hook`` on weights when delaying wgrad (bias uses main backward).""" + if not self.wgrad_store.delay_wgrad_compute(): + return + if self.single_grouped_weight: + self.weight.skip_backward_post_hook = True + else: + for group_idx in range(self.num_groups): + getattr(self, f"weight{group_idx}").skip_backward_post_hook = True + + def need_backward_dw(self) -> bool: + """Return whether :meth:`backward_dw` must run to finish weight gradients.""" + return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute() + + def backward_dw(self) -> None: + """Execute delayed weight gradient grouped GEMMs (see ``delay_wgrad_compute``).""" + if not self.need_backward_dw(): + return + if self.wgrad_store.context is None or self.wgrad_store.context.empty(): + return + _, tensor_list = self.wgrad_store.pop() + activations = tensor_list[0] + grad_weights = tensor_list[2] + if isinstance(activations, list): + clear_tensor_data(*activations) + else: + # Fused MXFP8 grouped MLP saves `GroupedTensor` activations for wgrad. 
+ clear_tensor_data( + activations.data, + activations.columnwise_data, + activations.scale_inv, + activations.columnwise_scale_inv, + ) + if self._accumulate_into_main_grad: + return + if self.single_grouped_weight: + if isinstance(grad_weights, list): + self.weight.grad = torch.stack(grad_weights, dim=0).to(self.weight.dtype) + else: + self.weight.grad = grad_weights.rowwise_data.view( + self.num_groups, + self.out_features, + self.in_features, + ).to(self.weight.dtype) + else: + for group_idx in range(self.num_groups): + w = getattr(self, f"weight{group_idx}") + w.grad = grad_weights[group_idx].to(w.dtype) + def num_quantizers(self, mode: str) -> int: if mode == "forward": return 2 * self.num_groups @@ -159,7 +231,7 @@ def num_quantizers(self, mode: str) -> int: @property def has_bias(self) -> bool: """Whether an additive bias is being applied""" - return self.bias0 is not None + return self.use_bias def reset_parameters(self) -> None: """Initialize parameter buffers and values""" @@ -221,16 +293,92 @@ def reset_parameters(self) -> None: setattr(self, f"weight{group_idx}", weight) # Initialize biases if needed - if self.bias0 is not None: + packed_biases: Optional[torch.Tensor] = None + if self.use_bias: + if self.bias0 is not None: + bias_dtype = self.bias0.dtype + elif getattr(self, "bias", None) is not None: + bias_dtype = self.bias.dtype + elif getattr(self, "weight", None) is not None: + bias_dtype = self.weight.dtype + else: + bias_dtype = self.weight0.dtype packed_biases = torch.zeros( self.num_groups, self.out_features, - dtype=self.bias0.dtype, + dtype=bias_dtype, device=device, ) + if not self.single_grouped_bias: + for group_idx in range(self.num_groups): + bias = torch.nn.Parameter(packed_biases[group_idx]) + setattr(self, f"bias{group_idx}", bias) + else: for group_idx in range(self.num_groups): - bias = torch.nn.Parameter(packed_biases[group_idx]) - setattr(self, f"bias{group_idx}", bias) + self.register_parameter(f"bias{group_idx}", None) + + if 
self.single_grouped_weight: + self.make_grouped_weights() + if self.use_bias and self.single_grouped_bias: + assert packed_biases is not None + self._make_grouped_biases_from_packed(packed_biases) + self._apply_delay_wgrad_param_hooks() + + def make_grouped_weights(self) -> None: + """ + Convert parameters into a GroupedTensor and re-register them as parameters. + """ + + weights = [getattr(self, f"weight{idx}") for idx in range(self.num_groups)] + quantizer = self.get_quantizer("forward", 1) + + recipe = None if quantizer is None else quantizer._get_compatible_recipe() + if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()): + raise RuntimeError( + "Delayed scaling or float8 current scaling is not supported with" + " single_grouped_weight=True" + ) + + grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=self.num_groups, + shapes=[(self.out_features, self.in_features)] * self.num_groups, + quantizer=quantizer, + dtype=self.weight0.dtype, + device=self.weight0.device, + ) + + # Copy existing params into storage. + with torch.no_grad(): + for i in range(self.num_groups): + if self._with_quantized_weight: + grouped_weights.quantized_tensors[i].copy_from_storage(weights[i]) + else: + grouped_weights.quantized_tensors[i].copy_(weights[i]) + + assert isinstance(grouped_weights, torch.Tensor) and ( + quantizer is None or not quantizer.internal + ), "Found internal quantizer with `single_grouped_weight=True`." + + # Re-register as a single grouped weight parameter. 
+ self.register_parameter("weight", torch.nn.Parameter(grouped_weights)) + for group_idx in range(self.num_groups): + self.register_parameter(f"weight{group_idx}", None) + + self._apply_delay_wgrad_param_hooks() + + def _make_grouped_biases_from_packed(self, packed_biases: torch.Tensor) -> None: + """Replace per-group bias parameters with one ``GroupedTensor`` (``single_grouped_bias``).""" + bias_data = packed_biases.detach().clone().contiguous() + grouped_bias = GroupedTensor.make_grouped_tensor_from_rowwise_data( + num_tensors=self.num_groups, + tensor_shape=(self.out_features,), + rowwise_data=bias_data, + dtype=bias_data.dtype, + ) + grouped_bias.requires_grad_(True) + self.register_parameter("bias", torch.nn.Parameter(grouped_bias)) + for group_idx in range(self.num_groups): + self.register_parameter(f"bias{group_idx}", None) def _quantize_weights( self, @@ -328,63 +476,102 @@ def pre_first_fuser_forward(self) -> None: if any(param.device.type == "meta" for param in self.parameters()): self.reset_parameters() - # Check that weights are consistent - dtype = self.weight0.dtype - device = self.weight0.device - weight_requires_grad = self.weight0.requires_grad - weight_tensor_type = type(self.weight0.data) - for group_idx in range(self.num_groups): - weight = getattr(self, f"weight{group_idx}") - if weight.dtype != dtype: - raise RuntimeError( - f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})." - ) - if not devices_match(weight.device, device): - raise RuntimeError( - f"Weight {group_idx} has invalid device " - f"(expected {device}, got {weight.device})." - ) - if weight.requires_grad != weight_requires_grad: - raise RuntimeError( - f"Weight {group_idx} has requires_grad={weight.requires_grad}, " - f"but expected requires_grad={weight_requires_grad}." 
- ) - if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck - raise RuntimeError( - f"Weight {group_idx} has invalid tensor type " - f"(expected {weight_tensor_type.__name__}, " - f"got {type(weight.data).__name__})." - ) + # Check that all weight params are consistent + if not self.single_grouped_weight: + dtype = self.weight0.dtype + device = self.weight0.device + weight_requires_grad = self.weight0.requires_grad + weight_tensor_type = type(self.weight0.data) + for group_idx in range(self.num_groups): + weight = getattr(self, f"weight{group_idx}") + if weight.dtype != dtype: + raise RuntimeError( + f"Weight {group_idx} has invalid dtype (expected {dtype}, got" + f" {weight.dtype})." + ) + if not devices_match(weight.device, device): + raise RuntimeError( + f"Weight {group_idx} has invalid device " + f"(expected {device}, got {weight.device})." + ) + if weight.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Weight {group_idx} has requires_grad={weight.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck + raise RuntimeError( + f"Weight {group_idx} has invalid tensor type " + f"(expected {weight_tensor_type.__name__}, " + f"got {type(weight.data).__name__})." + ) + else: + dtype = self.weight.dtype + device = self.weight.device + weight_requires_grad = self.weight.requires_grad + weight_tensor_type = type(self.weight.data) # Check that biases are consistent - for group_idx in range(self.num_groups): - bias = getattr(self, f"bias{group_idx}") - if self.has_bias: - if bias is None: - raise RuntimeError(f"Expected biases, but bias {group_idx} is uninitialized") + if self.has_bias: + if self.single_grouped_bias: + bias = self.bias if bias.dtype != dtype: raise RuntimeError( - f"Bias {group_idx} has invalid dtype (expected {dtype}, got {bias.dtype})." 
+ f"Bias has invalid dtype (expected {dtype}, got {bias.dtype})." ) if not devices_match(bias.device, device): raise RuntimeError( - f"Bias {group_idx} has invalid device " - f"(expected {device}, got {bias.device})." + f"Bias has invalid device (expected {device}, got {bias.device})." ) if bias.requires_grad != weight_requires_grad: raise RuntimeError( - f"Bias {group_idx} has requires_grad={bias.requires_grad}, " + f"Bias has requires_grad={bias.requires_grad}, " f"but expected requires_grad={weight_requires_grad}." ) else: - if bias is not None: - raise RuntimeError(f"Expected no biases, but bias {group_idx} is initialized") + for group_idx in range(self.num_groups): + bias = getattr(self, f"bias{group_idx}") + if bias is None: + raise RuntimeError( + f"Expected biases, but bias {group_idx} is uninitialized" + ) + if bias.dtype != dtype: + raise RuntimeError( + f"Bias {group_idx} has invalid dtype (expected {dtype}, got" + f" {bias.dtype})." + ) + if not devices_match(bias.device, device): + raise RuntimeError( + f"Bias {group_idx} has invalid device " + f"(expected {device}, got {bias.device})." + ) + if bias.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Bias {group_idx} has requires_grad={bias.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." 
+ ) + else: + if self.single_grouped_bias: + if getattr(self, "bias", None) is not None: + raise RuntimeError("Expected no biases, but grouped `bias` is registered") + else: + for group_idx in range(self.num_groups): + bias = getattr(self, f"bias{group_idx}") + if bias is not None: + raise RuntimeError( + f"Expected no biases, but bias {group_idx} is initialized" + ) def pre_fuser_forward(self, *, requires_grad: bool) -> None: super().pre_fuser_forward(requires_grad=requires_grad) if FP8GlobalStateManager.is_fp8_enabled(): # Assume weights have consistent grad requirement - weight_requires_grad = requires_grad and self.weight0.requires_grad + weight_requires_grad = ( + self.weight.requires_grad + if self.single_grouped_weight + else self.weight0.requires_grad + ) + weight_requires_grad = requires_grad and weight_requires_grad # Configure quantizer usages # Note: We cache the quantized input for backward pass, @@ -419,13 +606,17 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: # Make sure weight param has correct quantizer weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) weight_quantizer.internal = False - getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) + if self.single_grouped_weight: + self.weight.quantizer = weight_quantizer.copy() + else: + getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) else: # Use internal tensors if quantized weights will not be # exposed externally weight_quantizer.internal = ( not FP8GlobalStateManager.with_fp8_parameters() and not getattr(self, "_with_quantized_weight", False) + and not self.single_grouped_weight ) # Recipe-specific configuration @@ -472,12 +663,19 @@ def fuser_forward( ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: num_groups = self.num_groups has_bias = self.has_bias - device = self.weight0.device + weight_param = self.weight if self.single_grouped_weight else self.weight0 + device = 
weight_param.device + + if self._accumulate_into_main_grad: + if not hasattr(weight_param, "main_grad"): + raise RuntimeError("MAIN GRAD NOT FOUND") + if weight_param.main_grad is None: + raise RuntimeError("MAIN GRAD IS NONE") # Check which grads are required ctx = basic_op_ctxs[0] input_requires_grad = ctx.requires_grad - weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad + weight_requires_grad = ctx.requires_grad and weight_param.requires_grad # Quantizers input_quantizers = [None] * num_groups @@ -494,7 +692,7 @@ def fuser_forward( if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") else: - dtype = self.weight0.dtype + dtype = weight_param.dtype # Extract split sizes from extra input split_sizes = basic_op_extra_inputs[0][0] @@ -503,10 +701,24 @@ def fuser_forward( raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.") # Extract params - weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)] + if self.single_grouped_weight: + weights = self.weight.quantized_tensors + if weights is None: + weights = self.weight.split_into_quantized_tensors() + else: + weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)] bs = None if has_bias: - bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups)] + if self.single_grouped_bias: + bias_parts = self.bias.quantized_tensors + if bias_parts is None: + bias_parts = self.bias.split_into_quantized_tensors() + bs = [maybe_dequantize(p.reshape(-1), dtype) for p in bias_parts] + else: + bs = [ + maybe_dequantize(getattr(self, f"bias{idx}"), dtype) + for idx in range(num_groups) + ] # Convert weight dtype if needed ws = [] @@ -589,7 +801,8 @@ def fuser_backward( ]: num_groups = self.num_groups has_bias = self.has_bias - device = self.weight0.device + weight_param = self.weight if self.single_grouped_weight else self.weight0 + device = weight_param.device # Saved tensors from forward pass ctx = 
basic_op_ctxs[0] @@ -628,14 +841,42 @@ def fuser_backward( # Megatron-LM wgrad fusion # Note: Get grad tensors from params so we can # accumulate directly into it. - for group_idx in range(num_groups): - weight_param = getattr(self, f"weight{group_idx}") + if self.single_grouped_weight: if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() - grad_weights[group_idx] = weight_param.main_grad - accumulate_into_main_grad = not getattr(self.weight0, "overwrite_main_grad", False) + main_grad = weight_param.main_grad + if isinstance(main_grad, GroupedTensor): + grad_weights = main_grad.quantized_tensors + if grad_weights is None: + grad_weights = main_grad.split_into_quantized_tensors() + else: + # main_grad may be [num_groups, out, in] or a flat buffer. + # Canonicalize to grouped layout before slicing per-group views. + weight_shape = (self.out_features, self.in_features) + grouped_shape = (num_groups, *weight_shape) + if main_grad.shape != grouped_shape: + if main_grad.numel() != math.prod(grouped_shape): + raise RuntimeError( + "GroupedLinear expected grouped weight main_grad to have " + f"shape {grouped_shape} or matching numel, " + f"but got shape {tuple(main_grad.shape)}" + ) + main_grad = main_grad.reshape(grouped_shape) + grad_weights = [main_grad[idx] for idx in range(num_groups)] + accumulate_into_main_grad = not getattr( + weight_param, "overwrite_main_grad", False + ) + else: + for group_idx in range(num_groups): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + grad_weights[group_idx] = weight_param.main_grad + accumulate_into_main_grad = not getattr( + self.weight0, "overwrite_main_grad", False + ) else: - weight_shape = ws[0].size() + weight_shape = (self.out_features, self.in_features) for group_idx in range(num_groups): grad_weights[group_idx] = torch.empty( weight_shape, @@ -668,26 +909,63 @@ def 
fuser_backward( ) # Perform wgrad GEMMs + delay_wgrad = ( + ctx.weight_requires_grad + and self.wgrad_store is not None + and self.wgrad_store.delay_wgrad_compute() + ) if ctx.weight_requires_grad: - general_grouped_gemm( - xs, - dys, - grad_weights, - [None] * num_groups, # quantization_params - ctx.dtype, - layout="NT", - m_splits=split_sizes_int, - use_split_accumulator=_2X_ACC_WGRAD, - accumulate=accumulate_into_main_grad, - ) + if delay_wgrad: + grouped_gemm_wgrad = functools.partial( + general_grouped_gemm, + quantization_params=[None] * num_groups, + out_dtype=ctx.dtype, + layout="NT", + m_splits=split_sizes_int, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_into_main_grad, + ) + self.wgrad_store.put([xs, dys, grad_weights], grouped_gemm_wgrad) + else: + general_grouped_gemm( + xs, + dys, + grad_weights, + [None] * num_groups, # quantization_params + ctx.dtype, + layout="NT", + m_splits=split_sizes_int, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_into_main_grad, + ) - # Clear input tensors if possible - clear_tensor_data(*xs) + if not delay_wgrad: + clear_tensor_data(*xs) # Megatron-LM wgrad fusion # Note: Return dummy tensor for grad weight if needed. if accumulate_into_main_grad: grad_weights = [None] * num_groups + if self.single_grouped_weight: + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + else: + grad_weight = None + # Be mindful of param registration order. 
+ if has_bias: + if self.single_grouped_bias: + final_bias_grads = torch.stack(grad_biases, dim=0).to(ctx.dtype) + grad_params = [grad_weight, final_bias_grads] + else: + grad_params = grad_biases + [grad_weight] + else: + grad_params = [grad_weight] + return grad_input, [grad_params], [(None,)] for group_idx in range(num_groups): weight_param = getattr(self, f"weight{group_idx}") if hasattr(weight_param, "grad_added_to_main_grad"): @@ -698,5 +976,29 @@ def fuser_backward( zero=getattr(weight_param, "zero_out_wgrad", False), ) - grad_params = grad_weights + grad_biases if has_bias else grad_weights + if self.single_grouped_weight: + grad_weight = None + if ctx.weight_requires_grad: + if delay_wgrad: + grad_weight = None + else: + grad_weight = torch.stack(grad_weights, dim=0) + final_weight_grads = [grad_weight] + else: + if delay_wgrad and ctx.weight_requires_grad: + final_weight_grads = [None] * num_groups + else: + final_weight_grads = grad_weights + + if not has_bias: + grad_params = list(final_weight_grads) + elif self.single_grouped_bias: + final_bias_grads = torch.stack(grad_biases, dim=0).to(ctx.dtype) + grad_params = list(final_weight_grads) + [final_bias_grads] + else: + if self.single_grouped_weight: + grad_params = list(grad_biases) + list(final_weight_grads) + else: + grad_params = list(final_weight_grads) + list(grad_biases) + return grad_input, [grad_params], [(None,)] diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 19608894e0..19a090f121 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -28,3 +28,12 @@ register_backward_fusion(BackwardLinearScale.fuse_backward_ops) register_backward_fusion(BackwardActivationBias.fuse_backward_ops) register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops) + +# Import experimental fusions +# Note: Registration logic is non-trivial, so submodule handles it internally. 
+from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position + ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, +) +from .backward_grouped_mlp import ( # pylint: disable=wrong-import-position + BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, +) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py new file mode 100644 index 0000000000..21cd4be0bf --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -0,0 +1,673 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused operation for MoE grouped MLP.""" + +from __future__ import annotations +from collections.abc import Callable +import functools +import inspect +import math +from typing import Optional + +import torch + +import transformer_engine_torch as tex +from ...cpp_extensions import ( + general_grouped_gemm_for_grouped_tensor, +) +from ...module.base import get_dummy_wgrad +from ...quantization import Recipe +from ...tensor.grouped_tensor import GroupedTensor +from ...tensor.mxfp8_tensor import MXFP8Quantizer +from ...utils import clear_tensor_data, get_cached_ones_tensor, get_device_compute_capability +from ...constants import MXFP8_BLOCK_SCALING_SIZE +from ..basic import GroupedLinear, ScaledSwiGLU +from ..fuser import register_backward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import ( + fuse_grouped_mlp_ops, + maybe_dequantize, + validate_grouped_mlp_dims, +) + + +@functools.lru_cache(maxsize=1) +def _dglu_wrapper_has_generate_dbias_arg() -> bool: + """True if cudnn-frontend SM100 dGLU wrapper accepts ``generate_dbias``.""" + try: + from cudnn import grouped_gemm_dglu_wrapper_sm100 # pylint: disable=import-outside-toplevel + except ImportError: + return False + try: + params = inspect.signature(grouped_gemm_dglu_wrapper_sm100).parameters + except (TypeError, 
ValueError): + return False + return "generate_dbias" in params + + +def _compute_grad_params( + fc_op, + ctx, + num_groups, + weight_shape, + grouped_x, + grouped_dy, + dtype, + device, + bias_grads, + bias_grad_packed, + label="", +): + """Compute weight gradients and build grad_params for a GroupedLinear layer. + Returns the grad_params list in parameter registration order. + """ + + # Allocate grad buffers, determine accumulate flag + accumulate_into_main_grad = False + grouped_wgrad = None + wgrad_output = None + if fc_op.single_grouped_weight: + w_list = [None] + if ctx.weight_requires_grad: + weight_param = fc_op.weight + if fc_op._accumulate_into_main_grad: + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + main_grad = weight_param.main_grad + grouped_shape = (num_groups, *weight_shape) + if main_grad.shape != grouped_shape: + if main_grad.numel() != math.prod(grouped_shape): + raise RuntimeError( + f"Grouped MLP fused backward expected {label} main_grad to have " + f"shape {grouped_shape} or matching numel, " + f"but got shape {tuple(main_grad.shape)}" + ) + try: + main_grad = main_grad.view(grouped_shape) + except RuntimeError as e: + raise RuntimeError( + f"Grouped MLP fused backward requires {label} main_grad to be " + f"viewable as {grouped_shape} without copy, but got shape" + f" {tuple(main_grad.shape)} and stride" + f" {tuple(main_grad.stride())}" + ) from e + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) + if accumulate_into_main_grad: + grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data( + num_tensors=num_groups, + tensor_shape=weight_shape, + rowwise_data=main_grad, + dtype=main_grad.dtype, + ) + + if grouped_wgrad is None: + grouped_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_groups, + shapes=[weight_shape] * num_groups, + quantizer=None, + device=device, + dtype=dtype, + ) + wgrad_output = grouped_wgrad + 
else: + w_list = [None] * num_groups + if ctx.weight_requires_grad: + if fc_op._accumulate_into_main_grad: + for idx in range(num_groups): + wp = getattr(fc_op, f"weight{idx}") + if hasattr(wp, "__fsdp_param__"): + wp.main_grad = wp.get_main_grad() + w_list[idx] = wp.main_grad + accumulate_into_main_grad = not getattr(fc_op.weight0, "overwrite_main_grad", False) + else: + for idx in range(num_groups): + w_list[idx] = torch.empty(weight_shape, dtype=dtype, device=device) + wgrad_output = w_list + + if ctx.weight_requires_grad: + # Launch or defer the GEMM + delay_wgrad = fc_op.wgrad_store is not None and fc_op.wgrad_store.delay_wgrad_compute() + gemm_fn = functools.partial( + general_grouped_gemm_for_grouped_tensor, + layout="NT", + accumulate=accumulate_into_main_grad, + ) + if delay_wgrad: + fc_op.wgrad_store.put([grouped_x, grouped_dy, wgrad_output], gemm_fn) + else: + gemm_fn(grouped_x, grouped_dy, wgrad_output) + + # Extract results, mark accumulated if needed + if fc_op.single_grouped_weight: + packed_wgrad = None + if not delay_wgrad: + packed_wgrad = grouped_wgrad.rowwise_data.view(num_groups, *weight_shape) + if accumulate_into_main_grad and hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + packed_wgrad = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + w_list = [packed_wgrad] + else: + if delay_wgrad: + w_list = list(w_list) if accumulate_into_main_grad else [None] * num_groups + if accumulate_into_main_grad: + for idx in range(num_groups): + wp = getattr(fc_op, f"weight{idx}") + if hasattr(wp, "grad_added_to_main_grad"): + wp.grad_added_to_main_grad = True + w_list[idx] = get_dummy_wgrad( + list(wp.size()), + wp.dtype, + zero=getattr(wp, "zero_out_wgrad", False), + ) + + # Assemble grad_params in parameter registration order. 
+ if not fc_op.has_bias: + return w_list + + if fc_op.single_grouped_bias: + return w_list + [bias_grad_packed] + + bias_list = bias_grads if bias_grads is not None else [None] * num_groups + if fc_op.single_grouped_weight: + return bias_list + w_list + return w_list + bias_list + + +class BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8(FusedOperation): + """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU + GroupedLinear + + Uses experimental CuTe DSL kernel from cuDNN front-end. + + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_dglu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, GLU activation backward, and scale grad.""" + from cudnn import grouped_gemm_dglu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_dglu_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_quant_kernel(cls) -> Callable: + """Grouped GEMM quant kernel for block-scaled inputs.""" + from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_quant_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether this fused operation is supported on the current system.""" + if get_device_compute_capability()[0] != 10: + return False + try: + cls.grouped_gemm_dglu_kernel() + cls.grouped_gemm_quant_kernel() + except ImportError: + return False + return True + + @classmethod + def is_fc1_bias_supported(cls) -> bool: + """Whether cudnn-frontend exposes ``generate_dbias`` on the dGLU SM100 wrapper (FC1 bias grad only).""" + if not cls.is_supported(): + return False + return _dglu_wrapper_has_generate_dbias_arg() + + def __init__( + self, + *, + fc1: GroupedLinear, + swiglu: ScaledSwiGLU, + fc2: GroupedLinear, + ) -> None: + super().__init__((fc1, swiglu, fc2)) + if not self.is_supported(): + self.grouped_gemm_dglu_kernel() # Try triggering import error + raise RuntimeError(f"{self.__class__.__name__} is 
not supported on this system.") + validate_grouped_mlp_dims(fc1, swiglu, fc2) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + **unused, # pylint: disable=unused-argument + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + fc1_op, _, fc2_op = self.basic_ops + fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + + # Tensor properties + fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) + fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) + grad_output = grad_output.reshape(-1, fc2_weight_shape[0]) + out_shape = list(grad_output.size()) + num_groups = fc1_op.num_groups + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 + device = fc1_weight_param.device + dtype = fc1_ctx.dtype + + # Saved tensors from FC1 forward + saved_tensors = fc1_ctx.saved_tensors + split_sizes, split_points, saved_tensors = ( + saved_tensors[0], + saved_tensors[1], + saved_tensors[2:], + ) + + if fc1_op.single_grouped_weight: + grouped_fc1_weight, saved_tensors = saved_tensors[0], saved_tensors[1:] + else: + grouped_fc1_weight, saved_tensors = ( + saved_tensors[:num_groups], + saved_tensors[num_groups:], + ) + + ( + fc1_x_col_data, + fc1_x_col_scale, + fc1_x_tensor_offsets, + ), saved_tensors = ( + saved_tensors[:3], + saved_tensors[3:], + ) + + # Saved tensors from scaled SwiGLU forward + swiglu_in, scales = swiglu_ctx.saved_tensors + + # Saved tensors from FC2 forward + saved_tensors = fc2_ctx.saved_tensors + _, saved_tensors = saved_tensors[0], saved_tensors[1:] # Assume same split sizes as FC1 + if fc2_op.single_grouped_weight: + grouped_fc2_weight, saved_tensors = saved_tensors[0], saved_tensors[1:] + else: + grouped_fc2_weight, saved_tensors = ( + saved_tensors[:num_groups], + saved_tensors[num_groups:], + ) + + ( + fc2_x_col_data, + fc2_x_col_scale, + fc2_x_tensor_offsets, + ), saved_tensors = ( + 
saved_tensors[:3], + saved_tensors[3:], + ) + + # Group splits + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") + split_sizes = split_sizes.to(dtype=torch.int64, device=device) + split_points = split_points.to(dtype=torch.int, device=device) + + grouped_fc1_x = None + if fc1_ctx.weight_requires_grad: + grouped_fc1_x = GroupedTensor( + shape=(out_shape[0], fc1_weight_shape[1]), + dtype=dtype, + num_tensors=num_groups, + quantizer=fc1_ctx.input_quantizer, + columnwise_data=fc1_x_col_data, + columnwise_scale_inv=fc1_x_col_scale, + first_dims=split_sizes, + tensor_offsets=fc1_x_tensor_offsets, + with_gemm_swizzled_scales=True, + ) + + grouped_fc2_x = None + if fc2_ctx.weight_requires_grad: + grouped_fc2_x = GroupedTensor( + shape=(out_shape[0], fc2_weight_shape[1]), + dtype=dtype, + num_tensors=num_groups, + quantizer=fc2_ctx.input_quantizer, + columnwise_data=fc2_x_col_data, + columnwise_scale_inv=fc2_x_col_scale, + first_dims=split_sizes, + tensor_offsets=fc2_x_tensor_offsets, + with_gemm_swizzled_scales=True, + ) + + # Split grad output tensor and convert dtypes if needed + fc2_ctx.grad_output_quantizer.set_usage( + rowwise=True, columnwise=fc2_ctx.weight_requires_grad + ) + fc2_ctx.grad_output_quantizer.optimize_for_gemm = True + output_fc2_dbias = fc2_op.has_bias + fc2_dbias_packed = None + if ( + not output_fc2_dbias + and isinstance(grad_output, GroupedTensor) + and isinstance(getattr(grad_output, "quantizer", None), MXFP8Quantizer) + ): + grouped_fc2_dy = grad_output + else: + fc2_dy = maybe_dequantize(grad_output, dtype) + gq_ret = tex.group_quantize( + fc2_dy, + fc2_ctx.grad_output_quantizer, + num_groups, + split_sizes, + output_fc2_dbias, + ) + if output_fc2_dbias: + grouped_fc2_dy, fc2_dbias_packed = gq_ret + else: + grouped_fc2_dy = gq_ret + + fc2_bias_grads: Optional[list[Optional[torch.Tensor]]] = None + fc2_bias_grad_packed: Optional[torch.Tensor] = None + if 
fc2_dbias_packed is not None: + if fc2_op.single_grouped_bias: + fc2_bias_grad_packed = fc2_dbias_packed.to(dtype=dtype) + else: + fc2_bias_grads = [ + fc2_dbias_packed[idx].to(dtype=dtype) for idx in range(num_groups) + ] + + # Pack data tensors + # Note: Fused kernel expects tensor with non-contiguous + # logical dims. + # Data actual shape: (1, sum(m), k) + # Scale actual shape: (1, sum(m)/128, k/128, 32 (block row), + # 4 (block row), 4 (block col)) + # Data logical shape: (sum(m), k, 1) + # Scale logical shape: (32 (block row), 4 (block row), + # sum(m)/128, 4 (block col), k/128, 1) + fc2_dy_data = grouped_fc2_dy.rowwise_data.view(out_shape[0], out_shape[1]) + fc2_dy_data = fc2_dy_data.view(dtype=torch.float8_e4m3fn) + fc2_dy_data = fc2_dy_data.unsqueeze(0).permute(1, 2, 0) + fc2_dy_scales = grouped_fc2_dy.scale_inv + fc2_dy_scales = fc2_dy_scales.view(dtype=torch.float8_e8m0fnu) + fc2_dy_scales = fc2_dy_scales.view( + 1, + out_shape[0] // 128, + out_shape[1] // 128, + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc2_dy_scales = fc2_dy_scales.permute(3, 4, 1, 5, 2, 0) + + # Kernel scaling factors + alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) + norm_const_tensor = get_cached_ones_tensor(1, dtype, device) + current_stream = torch.cuda.current_stream().cuda_stream + + prob_tensor = scales.detach().to(dtype=torch.float32).reshape(-1, 1, 1) + dprob_tensor = torch.zeros_like(prob_tensor) + + fc2_dglu_kwargs = { + "a_tensor": fc2_dy_data, + "c_tensor": swiglu_in.unsqueeze(0).permute(1, 2, 0), + "sfa_tensor": fc2_dy_scales, + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor, + "beta_tensor": alpha_tensor, + "prob_tensor": prob_tensor, + "dprob_tensor": dprob_tensor, + "generate_dbias": fc1_op.has_bias, + "norm_const_tensor": norm_const_tensor, + "d_dtype": torch.float8_e4m3fn, + "cd_major": "n", + "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "discrete_col_sfd": True, + "act_func": "dswiglu", + 
"use_dynamic_sched": True, + } + + if fc2_op.single_grouped_weight: + # Clone and swizzle scales for GEMM + fc2_weight_for_gemm = grouped_fc2_weight.copy() + tex.grouped_swizzle_for_gemm(fc2_weight_for_gemm, rowwise=False, columnwise=True) + # Pack weight tensors for stacked kernel + # Data actual shape: (num_groups, k, n) + # Data logical shape: (n, k, num_groups) + fc2_w_data = fc2_weight_for_gemm.columnwise_data + fc2_w_data = fc2_w_data.view(dtype=torch.float8_e4m3fn) + fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1]) + fc2_w_data = fc2_w_data.permute(2, 1, 0) + fc2_w_scales = fc2_weight_for_gemm.columnwise_scale_inv.view(dtype=torch.float8_e8m0fnu) + fc2_w_scales = fc2_w_scales.view( + num_groups, + fc2_weight_shape[1] // 128, + fc2_weight_shape[0] // 128, + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + + fc2_dglu_kwargs["b_tensor"] = fc2_w_data + fc2_dglu_kwargs["sfb_tensor"] = fc2_w_scales + else: + fc2_b_ptrs, fc2_sfb_ptrs, _fc2_sw = tex.get_device_pointer_for_data_and_scales( + [w._columnwise_data for w in grouped_fc2_weight], + [w._columnwise_scale_inv for w in grouped_fc2_weight], + swizzle=True, + rowwise=False, + data_dtype=grouped_fc2_weight[0]._fp8_dtype, + ) + fc2_dglu_kwargs["b_ptrs"] = fc2_b_ptrs + fc2_dglu_kwargs["sfb_ptrs"] = fc2_sfb_ptrs + fc2_dglu_kwargs["n"] = fc2_weight_shape[1] + fc2_dglu_kwargs["b_dtype"] = torch.float8_e4m3fn + fc2_dglu_kwargs["b_major"] = "n" + + fc2_dgrad_kernel_out = self.grouped_gemm_dglu_kernel()(**fc2_dglu_kwargs) + + fc1_dy_row_data = fc2_dgrad_kernel_out["d_row_tensor"] + fc1_dy_row_data = fc1_dy_row_data.view(out_shape[0], fc1_weight_shape[0]).contiguous() + fc1_dy_row_scale = fc2_dgrad_kernel_out["sfd_row_tensor"] + fc1_dy_col_data = fc2_dgrad_kernel_out["d_col_tensor"] + fc1_dy_col_data = fc1_dy_col_data.view(out_shape[0], fc1_weight_shape[0]).contiguous() + fc1_dy_col_scale = fc2_dgrad_kernel_out["sfd_col_tensor"] + 
grad_scales = fc2_dgrad_kernel_out["dprob_tensor"] + grad_scales = grad_scales.view(-1).to(dtype=dtype) + + fc1_bias_grads: Optional[list[Optional[torch.Tensor]]] = None + fc1_bias_grad_packed: Optional[torch.Tensor] = None + if fc1_op.has_bias: + dbias_t = fc2_dgrad_kernel_out["dbias_tensor"] + if dbias_t is not None: + dbias_2d = dbias_t.squeeze(-1) + if fc1_op.single_grouped_bias: + fc1_bias_grad_packed = dbias_2d.to(dtype=dtype) + else: + fc1_bias_grads = [ + dbias_2d[group_idx].to(dtype=dtype) for group_idx in range(num_groups) + ] + + # FC1 grad output for dgrad and wgrad GEMMs + fc1_dy_tensor_offsets = fc1_ctx.base_split_offsets * fc1_weight_shape[0] + grouped_fc1_dy = GroupedTensor( + shape=(out_shape[0], fc1_weight_shape[0]), + dtype=dtype, + num_tensors=num_groups, + quantizer=fc1_ctx.grad_output_quantizer, + data=fc1_dy_row_data, + columnwise_data=fc1_dy_col_data, + scale_inv=fc1_dy_row_scale, + columnwise_scale_inv=fc1_dy_col_scale, + first_dims=split_sizes, + tensor_offsets=fc1_dy_tensor_offsets, + with_gemm_swizzled_scales=True, + ) + + # FC2 wgrad GEMM + fc2_grad_params = _compute_grad_params( + fc_op=fc2_op, + ctx=fc2_ctx, + num_groups=num_groups, + weight_shape=fc2_weight_shape, + grouped_x=grouped_fc2_x, + grouped_dy=grouped_fc2_dy, + dtype=dtype, + device=device, + bias_grads=fc2_bias_grads, + bias_grad_packed=fc2_bias_grad_packed, + label="FC2", + ) + + # Clear FC2 input tensor if possible + if grouped_fc2_x is not None and not ( + fc2_ctx.weight_requires_grad + and fc2_op.wgrad_store is not None + and fc2_op.wgrad_store.delay_wgrad_compute() + ): + clear_tensor_data( + grouped_fc2_x.data, + grouped_fc2_x.columnwise_data, + grouped_fc2_x.scale_inv, + grouped_fc2_x.columnwise_scale_inv, + ) + + # FC1 dgrad GEMM + grad_input = None + if fc1_ctx.input_requires_grad: + in_shape = out_shape[:-1] + [fc1_weight_shape[1]] + + fc1_dgrad_a_data = fc2_dgrad_kernel_out["d_row_tensor"] + fc1_dgrad_a_scales = fc2_dgrad_kernel_out["sfd_row_tensor"] + + 
fc1_dgrad_kwargs = { + "a_tensor": fc1_dgrad_a_data, + "sfa_tensor": fc1_dgrad_a_scales, + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor.float(), + "norm_const_tensor": None, + "prob_tensor": torch.ones((out_shape[0], 1, 1), dtype=torch.float32, device=device), + "acc_dtype": torch.float32, + "c_dtype": dtype, + "d_dtype": dtype, + "cd_major": "n", + "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "discrete_col_sfd": True, + "use_dynamic_sched": True, + } + + if fc1_op.single_grouped_weight: + # Clone and swizzle scales for GEMM + fc1_weight_for_gemm = grouped_fc1_weight.copy() + tex.grouped_swizzle_for_gemm(fc1_weight_for_gemm, rowwise=False, columnwise=True) + + fc1_w_data = fc1_weight_for_gemm.columnwise_data + fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) + fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.permute(2, 1, 0) + fc1_w_scales = fc1_weight_for_gemm.columnwise_scale_inv.view( + dtype=torch.float8_e8m0fnu + ) + fc1_w_scales = fc1_w_scales.view( + num_groups, + fc1_weight_shape[1] // 128, + fc1_weight_shape[0] // 128, + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) + + fc1_dgrad_kwargs["b_tensor"] = fc1_w_data + fc1_dgrad_kwargs["sfb_tensor"] = fc1_w_scales + else: + fc1_b_ptrs, fc1_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( + [w._columnwise_data for w in grouped_fc1_weight], + [w._columnwise_scale_inv for w in grouped_fc1_weight], + swizzle=True, + rowwise=False, + data_dtype=grouped_fc1_weight[0]._fp8_dtype, + ) + + fc1_dgrad_kwargs["b_ptrs"] = fc1_b_ptrs + fc1_dgrad_kwargs["sfb_ptrs"] = fc1_sfb_ptrs + fc1_dgrad_kwargs["n"] = fc1_weight_shape[1] + fc1_dgrad_kwargs["b_dtype"] = torch.float8_e4m3fn + fc1_dgrad_kwargs["b_major"] = "n" + + fc1_dgrad_kernel_out = self.grouped_gemm_quant_kernel()(**fc1_dgrad_kwargs) + grad_input = fc1_dgrad_kernel_out["d_tensor"].view(in_shape) 
+ + # FC1 wgrad GEMM + fc1_grad_params = _compute_grad_params( + fc_op=fc1_op, + ctx=fc1_ctx, + num_groups=num_groups, + weight_shape=fc1_weight_shape, + grouped_x=grouped_fc1_x, + grouped_dy=grouped_fc1_dy, + dtype=dtype, + device=device, + bias_grads=fc1_bias_grads, + bias_grad_packed=fc1_bias_grad_packed, + label="FC1", + ) + + # Clear FC1 input tensor if possible + if grouped_fc1_x is not None and not ( + fc1_ctx.weight_requires_grad + and fc1_op.wgrad_store is not None + and fc1_op.wgrad_store.delay_wgrad_compute() + ): + clear_tensor_data( + grouped_fc1_x.data, + grouped_fc1_x.columnwise_data, + grouped_fc1_x.scale_inv, + grouped_fc1_x.columnwise_scale_inv, + ) + + return ( + grad_input, + [fc1_grad_params, (), fc2_grad_params], + [(None,), (grad_scales,), (None,)], + ) + + +def fuse_backward_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + recipe : Recipe, optional + Quantization recipe. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + return fuse_grouped_mlp_ops( + ops, + recipe=recipe, + fused_op_cls=BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + ) + + +# Register fusion if available +if BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): + register_backward_fusion(fuse_backward_ops, prepend=True) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py new file mode 100644 index 0000000000..29b204cd67 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -0,0 +1,570 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fused operation for MoE grouped MLP.""" + +from __future__ import annotations +from collections.abc import Callable, Iterable +import functools +import inspect +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...quantization import Recipe +from ...tensor import Quantizer +from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor +from ...tensor.grouped_tensor import GroupedTensor +from ...tensor.mxfp8_tensor import MXFP8Quantizer +from ...constants import MXFP8_BLOCK_SCALING_SIZE +from ..basic import GroupedLinear, ScaledSwiGLU +from ..fuser import register_forward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import ( + fuse_grouped_mlp_ops, + is_quantized_tensor, + maybe_dequantize, + validate_grouped_mlp_dims, +) + + +def _pack_grouped_linear_bias_for_cudnn(linear_op: GroupedLinear) -> Optional[torch.Tensor]: + """Bias layout expected by cuDNN grouped GEMM: shape (n, num_groups), stride (1, n).""" + if not linear_op.has_bias: + return None + num_groups = linear_op.num_groups + grouped_bias = getattr(linear_op, "bias", None) + if grouped_bias is not None: + packed = grouped_bias.rowwise_data.view(num_groups, -1) + return packed.transpose(0, 1) + rows = [getattr(linear_op, f"bias{group_idx}") for group_idx in range(num_groups)] + # stack to [num_groups, n] but cuDNN expects [n, num_groups] with stride [1, n]. + return torch.stack(rows, dim=0).transpose(0, 1) + + +class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): + """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU + GroupedLinear + + Uses experimental CuTe DSL kernel from cuDNN front-end. 
+ + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_glu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, GLU activation, and post-multiplication.""" + from cudnn import grouped_gemm_glu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_glu_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_quant_kernel(cls) -> Callable: + """Grouped GEMM quant kernel for block-scaled inputs.""" + from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_quant_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether this fused operation is supported on the current system.""" + if get_device_compute_capability()[0] != 10: + return False + try: + cls.grouped_gemm_glu_kernel() + cls.grouped_gemm_quant_kernel() + except ImportError: + return False + return True + + @classmethod + @functools.lru_cache(maxsize=1) + def is_fc1_bias_supported(cls) -> bool: + """Whether cudnn-frontend exposes ``bias_tensor`` on the grouped GEMM GLU SM100 wrapper (FC1).""" + if not cls.is_supported(): + return False + try: + from cudnn import ( + grouped_gemm_glu_wrapper_sm100, + ) # pylint: disable=import-outside-toplevel + except ImportError: + return False + try: + params = inspect.signature(grouped_gemm_glu_wrapper_sm100).parameters + except (TypeError, ValueError): + return False + return "bias_tensor" in params + + @classmethod + @functools.lru_cache(maxsize=1) + def is_fc2_bias_supported(cls) -> bool: + """Whether cudnn-frontend exposes ``bias_tensor`` on the grouped GEMM Quant SM100 wrapper (FC2).""" + if not cls.is_supported(): + return False + try: + from cudnn import ( + grouped_gemm_quant_wrapper_sm100, + ) # pylint: disable=import-outside-toplevel + except ImportError: + return False + try: + params = inspect.signature(grouped_gemm_quant_wrapper_sm100).parameters + except (TypeError, 
ValueError): + return False + return "bias_tensor" in params + + def __init__( + self, + *, + fc1: GroupedLinear, + swiglu: ScaledSwiGLU, + fc2: GroupedLinear, + ) -> None: + super().__init__((fc1, swiglu, fc2)) + if not self.is_supported(): + self.grouped_gemm_glu_kernel() # Try triggering import error + raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") + validate_grouped_mlp_dims(fc1, swiglu, fc2) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + # Get basic operations + fc1_op, _, fc2_op = self.basic_ops + fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + + # Tensor properties + fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) + fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) + input_ = input_.reshape(-1, fc1_weight_shape[1]) + in_shape = list(input_.size()) + + num_groups = fc1_op.num_groups + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 + fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 + device = fc1_weight_param.device + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = fc1_weight_param.dtype + + # Check which grads are required + requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) + input_requires_grad = requires_grad + weight_requires_grad = requires_grad and ( + fc1_weight_param.requires_grad or fc2_weight_param.requires_grad + ) + + # Quantizers + fc1_input_quantizer = fc1_op.get_quantizer("forward", 0) + fc1_weight_quantizer = fc1_op.get_quantizer("forward", 1) + fc1_grad_output_quantizer = fc1_op.get_quantizer("backward", 0) + fc2_input_quantizer = 
fc2_op.get_quantizer("forward", 0) + fc2_weight_quantizer = fc2_op.get_quantizer("forward", 1) + fc2_grad_output_quantizer = fc2_op.get_quantizer("backward", 0) + + # Extract split sizes from extra input + fc1_split_sizes = basic_op_extra_inputs[0][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] + if ( + fc1_split_sizes.size() != fc2_split_sizes.size() + or fc1_split_sizes.data_ptr() != fc2_split_sizes.data_ptr() + ): + raise RuntimeError( + f"{self.__class__.__name__} got different split points for FC1 and FC2." + ) + split_sizes = fc1_split_sizes + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") + split_sizes = split_sizes.to(dtype=torch.int64, device=device) + split_points = torch.cumsum(split_sizes, 0, dtype=torch.int) + split_points_offsets = torch.cumsum(split_sizes, 0) + base_offsets = torch.cat( + [ + torch.zeros(1, device=split_sizes.device, dtype=split_sizes.dtype), + split_points_offsets, + ] + ) + fc1_x_tensor_offsets = base_offsets * fc1_weight_shape[1] + fc2_x_tensor_offsets = base_offsets * fc2_weight_shape[1] + + # Extract post-scales from extra input + scales = basic_op_extra_inputs[1][0] + + # Prepare FC1 grouped weight tensor for fused kernels. + # - single_grouped_weight=True: op.weight is already a GroupedTensor + # - single_grouped_weight=False: cute DSL kernel works with discrete weight tensors + # as long as host pointers for addresses are packed as contiguous device tensor. + if fc1_op.single_grouped_weight: + if not isinstance(fc1_op.weight, GroupedTensor): + raise RuntimeError( + "FC1 expected GroupedTensor weight with single_grouped_weight=True." 
+ ) + if fc1_op.weight.quantizer is not None: + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + fc1_op.weight.quantizer = fc1_weight_quantizer + grouped_fc1_weight = fc1_op.weight + else: + if fc1_op.weight.rowwise_data is None: + raise RuntimeError("FC1 grouped weight has no rowwise_data to quantize.") + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + grouped_fc1_weight = tex.group_quantize( + fc1_op.weight.rowwise_data.view(fc1_op.weight.logical_shape), + fc1_weight_quantizer, + num_groups, + None, + ) + else: + fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] + quantized_fc1_weights = [] + for idx, weight in enumerate(fc1_weights): + quantizer = fc1_op.get_quantizer("forward", 2 * idx + 1) + if not is_quantized_tensor(weight): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + quantized_fc1_weights.append(quantizer(weight)) + else: + quantized_fc1_weights.append(weight) + grouped_fc1_weight = quantized_fc1_weights + + # Prepare FC2 grouped weight tensor for fused kernels. + if fc2_op.single_grouped_weight: + if not isinstance(fc2_op.weight, GroupedTensor): + raise RuntimeError( + "FC2 expected GroupedTensor weight with single_grouped_weight=True." 
+ ) + if fc2_op.weight.quantizer is not None: + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + fc2_op.weight.quantizer = fc2_weight_quantizer + grouped_fc2_weight = fc2_op.weight + else: + if fc2_op.weight.rowwise_data is None: + raise RuntimeError("FC2 grouped weight has no rowwise_data to quantize.") + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + grouped_fc2_weight = tex.group_quantize( + fc2_op.weight.rowwise_data.view(fc2_op.weight.logical_shape), + fc2_weight_quantizer, + num_groups, + None, + ) + else: + fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] + quantized_fc2_weights = [] + for idx, weight in enumerate(fc2_weights): + quantizer = fc2_op.get_quantizer("forward", 2 * idx + 1) + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + if not is_quantized_tensor(weight): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + quantized_fc2_weights.append(quantizer(weight)) + else: + quantized_fc2_weights.append(weight) + grouped_fc2_weight = quantized_fc2_weights + + # Some wrapper-copy paths may drop grouped storage metadata; enforce defaults. 
+ if getattr(grouped_fc1_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( + grouped_fc1_weight, GroupedTensor + ): + grouped_fc1_weight._with_gemm_swizzled_scales = False + if getattr(grouped_fc2_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( + grouped_fc2_weight, GroupedTensor + ): + grouped_fc2_weight._with_gemm_swizzled_scales = False + + # Group-quantize input tensor and convert dtypes if needed + fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + fc1_input_quantizer.optimize_for_gemm = True + if isinstance(input_, GroupedTensor) and isinstance( + getattr(input_, "quantizer", None), MXFP8Quantizer + ): + grouped_fc1_x = input_ + else: + fc1_x = maybe_dequantize(input_, dtype) + grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes) + + # Pack data tensors + # Note: Fused kernel expects tensor with non-contiguous + # logical dims. + # Data actual shape: (1, sum(m), k) + # Scale actual shape: (1, sum(m)/128, k/128, 32 (block row), + # 4 (block row), 4 (block col)) + # Data logical shape: (sum(m), k, 1) + # Scale logical shape: (32 (block row), 4 (block row), + # sum(m)/128, 4 (block col), k/128, 1) + fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1]) + fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) + fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) + fc1_x_scales = grouped_fc1_x.scale_inv + fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) + fc1_x_scales = fc1_x_scales.view( + 1, + in_shape[0] // 128, + in_shape[1] // 128, + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + + alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) + norm_const_tensor = get_cached_ones_tensor(1, dtype, device) + current_stream = torch.cuda.current_stream().cuda_stream + + fc1_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc1_op) + fc2_bias_packed = 
_pack_grouped_linear_bias_for_cudnn(fc2_op) + + fc1_glu_kwargs = { + "a_tensor": fc1_x_data, + "sfa_tensor": fc1_x_scales, + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor, + "bias_tensor": fc1_bias_packed, + "norm_const_tensor": norm_const_tensor, + "prob_tensor": scales.detach().to(dtype=dtype).reshape(-1, 1, 1), + "acc_dtype": torch.float32, + "c_dtype": torch.bfloat16, + "d_dtype": torch.float8_e4m3fn, + "cd_major": "n", + "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "discrete_col_sfd": True, + "act_func": "swiglu", + "use_dynamic_sched": True, + } + + if fc1_op.single_grouped_weight: + # Clone and swizzle scales for GEMM. + fc1_weight_for_gemm = grouped_fc1_weight.copy() + tex.grouped_swizzle_for_gemm(fc1_weight_for_gemm, rowwise=True, columnwise=False) + + # Pack weight tensors for stacked kernel + # Data actual shape: (num_groups, n, k) + # Data logical shape: (n, k, num_groups) + fc1_w_data = fc1_weight_for_gemm.rowwise_data + fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) + fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.permute(1, 2, 0) + fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) + fc1_w_scales = fc1_w_scales.view( + num_groups, + fc1_weight_shape[0] // 128, + fc1_weight_shape[1] // 128, + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) + + fc1_glu_kwargs["b_tensor"] = fc1_w_data + fc1_glu_kwargs["sfb_tensor"] = fc1_w_scales + else: + # Discrete-weight kernel: per-expert data/scale pointers + fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sw = tex.get_device_pointer_for_data_and_scales( + [w._rowwise_data for w in grouped_fc1_weight], + [w._rowwise_scale_inv for w in grouped_fc1_weight], + swizzle=True, + rowwise=True, + data_dtype=grouped_fc1_weight[0]._fp8_dtype, + ) + fc1_glu_kwargs["b_ptrs"] = fc1_b_ptrs + fc1_glu_kwargs["sfb_ptrs"] = fc1_sfb_ptrs + 
fc1_glu_kwargs["n"] = fc1_weight_shape[0] + fc1_glu_kwargs["b_dtype"] = torch.float8_e4m3fn + fc1_glu_kwargs["b_major"] = "k" + + fc1_kernel_out = self.grouped_gemm_glu_kernel()(**fc1_glu_kwargs) + + # Unpack kernel outputs + # Note: Fused kernel outputs tensors with non-contiguous + # logical dims. + # Row-wise data logical shape: (sum(m_splits), k, 1) + # Row-wise scale logical shape: (32 (block row), 4 (block row), + # sum(m_splits)/128, 4 (block col), k/128, 1) + # Column-wise data logical shape: (sum(m_splits), k, 1) + # Column-wise scale logical shape: (32 (block col), 4 (block col), + # k/128, 4 (block row), sum(m_splits)/128, 1) + swiglu_in = fc1_kernel_out["c_tensor"] + swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0]) + fc2_in_row_data = fc1_kernel_out["d_tensor"] + fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1]) + fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] + fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 4, 0, 1, 3) + + fc2_in_col_data = fc1_kernel_out["d_col_tensor"] + fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1]) + fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] + fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) + # Repack columnwise scales on GPU to preserve group ordering. 
+ + # FC2 inputs scales are already swizzled/optimized for GEMM + grouped_fc2_x = GroupedTensor( + shape=(in_shape[0], fc2_weight_shape[1]), + dtype=dtype, + num_tensors=num_groups, + quantizer=fc2_input_quantizer, + data=fc2_in_row_data.reshape(-1), + columnwise_data=fc2_in_col_data.reshape(-1), + scale_inv=fc2_in_row_scale.reshape(-1), + columnwise_scale_inv=fc2_in_col_scale.reshape(-1), + first_dims=split_sizes, + tensor_offsets=fc2_x_tensor_offsets, + with_gemm_swizzled_scales=True, + ) + + # FC2 GEMM + fc2_out_shape = in_shape[:-1] + [fc2_weight_shape[0]] + fc2_quant_kwargs = { + "a_tensor": fc1_kernel_out["d_tensor"], + "sfa_tensor": fc1_kernel_out["sfd_row_tensor"], + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor.float(), + "norm_const_tensor": None, + "prob_tensor": torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device), + "acc_dtype": torch.float32, + "c_dtype": dtype, + "d_dtype": dtype, + "cd_major": "n", + "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "use_dynamic_sched": True, + } + if self.is_fc2_bias_supported(): + fc2_quant_kwargs["bias_tensor"] = fc2_bias_packed + + if fc2_op.single_grouped_weight: + # Clone and swizzle scales for GEMM (original stays unmodified for save_for_backward) + fc2_weight_for_gemm = grouped_fc2_weight.copy() + tex.grouped_swizzle_for_gemm(fc2_weight_for_gemm, rowwise=True, columnwise=False) + + fc2_w_data = fc2_weight_for_gemm.rowwise_data + fc2_w_data = fc2_w_data.view(dtype=torch.float8_e4m3fn) + fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1]) + fc2_w_data = fc2_w_data.permute(1, 2, 0) + + fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) + fc2_w_scales = fc2_w_scales.view( + num_groups, + fc2_weight_shape[0] // 128, + fc2_weight_shape[1] // 128, + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + fc2_quant_kwargs["b_tensor"] = fc2_w_data + 
fc2_quant_kwargs["sfb_tensor"] = fc2_w_scales + else: + fc2_b_ptrs, fc2_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( + [w._rowwise_data for w in grouped_fc2_weight], + [w._rowwise_scale_inv for w in grouped_fc2_weight], + swizzle=True, + rowwise=True, + data_dtype=grouped_fc2_weight[0]._fp8_dtype, + ) + fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs + fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs + fc2_quant_kwargs["n"] = fc2_weight_shape[0] + fc2_quant_kwargs["b_dtype"] = torch.float8_e4m3fn + fc2_quant_kwargs["b_major"] = "k" + + fc2_kernel_out = self.grouped_gemm_quant_kernel()(**fc2_quant_kwargs) + fc2_out = fc2_kernel_out["d_tensor"].permute(2, 0, 1).view(fc2_out_shape).contiguous() + + # Save state for backward pass + if requires_grad: + mark_grouped_tensor(grouped_fc1_x, swiglu_in, scales, grouped_fc2_x) + fc1_input_tensors = ( + grouped_fc1_x.columnwise_data, + grouped_fc1_x.columnwise_scale_inv, + fc1_x_tensor_offsets, + ) + # FC1 + fc1_weight_tensors = ( + [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight + ) + fc1_ctx.save_for_backward( + split_sizes, split_points, *fc1_weight_tensors, *fc1_input_tensors + ) + fc1_ctx.with_quantized_compute = True + fc1_ctx.input_quantizer = fc1_input_quantizer + fc1_ctx.weight_quantizer = fc1_weight_quantizer + fc1_ctx.grad_output_quantizer = fc1_grad_output_quantizer + fc1_ctx.grad_input_quantizers = None + fc1_ctx.dtype = dtype + fc1_ctx.input_requires_grad = input_requires_grad + fc1_ctx.weight_requires_grad = weight_requires_grad + fc1_ctx.base_split_offsets = base_offsets + + # Scaled SwiGLU + swiglu_ctx.save_for_backward(swiglu_in, scales) + swiglu_ctx.input_requires_grad = True + swiglu_ctx.extra_input_requires_grad = True + swiglu_ctx.dtype = dtype + + # FC2 state + if grouped_fc2_x is not None: + fc2_input_tensors = ( + grouped_fc2_x.columnwise_data, + grouped_fc2_x.columnwise_scale_inv, + fc2_x_tensor_offsets, + ) + else: + fc2_input_tensors = (None, None, None) + + if 
fc2_op.single_grouped_weight: + fc2_ctx.save_for_backward(split_sizes, grouped_fc2_weight, *fc2_input_tensors) + else: + fc2_ctx.save_for_backward(split_sizes, *grouped_fc2_weight, *fc2_input_tensors) + + fc2_ctx.with_quantized_compute = True + fc2_ctx.input_quantizer = fc2_input_quantizer + fc2_ctx.weight_quantizer = fc2_weight_quantizer + fc2_ctx.grad_output_quantizer = fc2_grad_output_quantizer + fc2_ctx.grad_input_quantizers = None + fc2_ctx.dtype = dtype + fc2_ctx.input_requires_grad = input_requires_grad + fc2_ctx.weight_requires_grad = weight_requires_grad + + return fc2_out, [(), (), ()] + + +def fuse_forward_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + recipe : Recipe, optional + Quantization recipe. + + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + return fuse_grouped_mlp_ops( + ops, + recipe=recipe, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ) + + +# Register fusion if available +if ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + register_forward_fusion(fuse_forward_ops, prepend=True) diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index 2fce9a38e2..ab0c7484fc 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -74,7 +74,7 @@ def __new__( dtype: torch.dtype, *, num_tensors: int, - shapes: Optional[List[Tuple[int, int]]] = None, + shapes: Optional[List[Tuple[int, ...]]] = None, quantizer: Optional[Quantizer] = None, data: Optional[torch.Tensor] = None, columnwise_data: Optional[torch.Tensor] = None, @@ -99,7 +99,15 @@ def __new__( and num_tensors > 0 and all(shapes[0] == s for s in shapes) ): - 
wrapper_shape = (num_tensors, shapes[0][0], shapes[0][1]) + s0 = shapes[0] + if len(s0) == 2: + wrapper_shape = (num_tensors, s0[0], s0[1]) + elif len(s0) == 1: + wrapper_shape = (num_tensors, s0[0]) + else: + raise ValueError( + f"GroupedTensor member shapes must be 1D or 2D, got {len(s0)}-D shape {s0!r}" + ) else: wrapper_shape = shape @@ -186,6 +194,7 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.columnwise_scale_inv_offsets = src.columnwise_scale_inv_offsets dst.logical_shape = src.logical_shape dst.quantized_tensors = src.quantized_tensors + dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 68097259c6..ff1c78f695 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -54,7 +54,7 @@ def _initialize_storage_fields( shape: Tuple[int, int], dtype: torch.dtype, num_tensors: int, - shapes: Optional[List[Tuple[int, int]]] = None, + shapes: Optional[List[Tuple[int, ...]]] = None, quantizer: Optional[Quantizer] = None, data: Optional[torch.Tensor] = None, columnwise_data: Optional[torch.Tensor] = None, @@ -153,7 +153,7 @@ def __new__( dtype: torch.dtype, *, num_tensors: int, - shapes: Optional[List[Tuple[int, int]]] = None, + shapes: Optional[List[Tuple[int, ...]]] = None, quantizer: Optional[Quantizer] = None, data: Optional[torch.Tensor] = None, columnwise_data: Optional[torch.Tensor] = None, @@ -383,6 +383,128 @@ def make_grouped_tensor_with_shapes( dtype=dtype, ) + @staticmethod + def make_grouped_tensor_from_rowwise_data( + *, + num_tensors: int, + tensor_shape: Tuple[int, ...], + rowwise_data: torch.Tensor, 
+ dtype: Optional[torch.dtype] = None, + internal: bool = False, + ) -> GroupedTensorStorage: + """Wrap pre-existing contiguous rowwise data as a grouped tensor. + + This helper creates grouped metadata over `rowwise_data`, which is expected to + contain `num_tensors` tensors of shape ``tensor_shape`` in packed layout. No storage + is allocated unless `rowwise_data` is non-contiguous, in which case it is copied. + + ``tensor_shape`` may be: + + * ``(rows, cols)`` — each member is a 2D matrix; wrapper shape + ``(num_tensors, rows, cols)``. + * ``(n,)`` — each member is a 1D vector of length ``n``; logical storage + uses ``logical_shape = (num_tensors * n, 1)`` and the wrapper shape is + ``(num_tensors, n)``. + """ + if num_tensors <= 0: + raise ValueError(f"num_tensors must be positive, got {num_tensors}") + if rowwise_data is None: + raise ValueError("rowwise_data must not be None") + if not rowwise_data.is_contiguous(): + rowwise_data = rowwise_data.contiguous() + + if len(tensor_shape) == 2: + rows, cols = tensor_shape + expected_numel = num_tensors * rows * cols + logical_shape = (num_tensors * rows, cols) + shapes_list: List[Tuple[int, ...]] = [tensor_shape] * num_tensors + elif len(tensor_shape) == 1: + (n,) = tensor_shape + expected_numel = num_tensors * n + logical_shape = (num_tensors * n, 1) + shapes_list = [tensor_shape] * num_tensors + else: + raise ValueError( + "tensor_shape must be 1D (n,) or 2D (rows, cols), " + f"got {tensor_shape!r} with length {len(tensor_shape)}" + ) + + if rowwise_data.numel() != expected_numel: + raise ValueError( + "Grouped rowwise buffer size mismatch: expected " + f"{expected_numel} elements for {num_tensors}x{tensor_shape}, " + f"but got {rowwise_data.numel()}" + ) + if dtype is None: + dtype = rowwise_data.dtype + grouped_tensor_class = GroupedTensorStorage + if not internal: + from ..grouped_tensor import GroupedTensor + + grouped_tensor_class = GroupedTensor + + return grouped_tensor_class( + shape=logical_shape, + dtype=dtype, + num_tensors=num_tensors, +
+ shapes=shapes_list, + quantizer=None, + data=rowwise_data.view(-1), + columnwise_data=None, + scale_inv=None, + columnwise_scale_inv=None, + amax=None, + columnwise_amax=None, + scale=None, + first_dims=None, + last_dims=None, + tensor_offsets=None, + offsets=None, + scale_inv_offsets=None, + columnwise_scale_inv_offsets=None, + with_gemm_swizzled_scales=False, + requires_grad=False, + ) + + def copy(self) -> "GroupedTensorStorage": + """Create a shallow ``GroupedTensorStorage`` copy sharing all data buffers with *self*. + No tensor data is copied; the returned object references the same + underlying storage for every buffer (data, scales, offsets, etc.). + This is useful when you need to mutate metadata (e.g. swizzle + scales in-place) without affecting the original object. + """ + return GroupedTensorStorage( + shape=self.logical_shape, + dtype=self.fake_dtype, + num_tensors=self.num_tensors, + shapes=self.tensor_shapes, + quantizer=self.quantizer, + data=self.rowwise_data, + columnwise_data=self.columnwise_data, + scale_inv=self.scale_inv, + columnwise_scale_inv=self.columnwise_scale_inv, + amax=self.amax, + columnwise_amax=self.columnwise_amax, + scale=self.scale, + first_dims=self.first_dims, + last_dims=self.last_dims, + tensor_offsets=self.tensor_offsets, + offsets=self.offsets, + scale_inv_offsets=self.scale_inv_offsets, + columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + ) + + @staticmethod + def make_tensor_offsets(first_dims: torch.Tensor, logical_last_dim: int) -> torch.Tensor: + """Calculate element offsets from first dim splits (on the device of ``first_dims``).""" + return torch.cat( + [ + torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), + torch.cumsum(first_dims * logical_last_dim, dim=0), + ] + ) + @staticmethod def make_grouped_tensor( num_tensors: int, @@ -421,7 +543,7 @@ def make_grouped_tensor( all_same_last = last_dims is None assert all_same_last, "Last dim must be uniform for GroupedTensor" - assert
logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor" + assert logical_first_dim >= 0, "Logical first dim must be non-negative for GroupedTensor" assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor" # assert ( @@ -439,16 +561,20 @@ def make_grouped_tensor( # Kernels need to calculate precise pointers based on size of elements. # TODO(ksivaman): Single kernel + remove the host offset calculation. - tensor_offsets = torch.cat( - [ - torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), - torch.cumsum(first_dims * logical_last_dim, dim=0), - ] - ) - offsets = tensor_offsets.tolist() - first_dims_list = first_dims.tolist() - for i in range(num_tensors): - shape.append((first_dims_list[i], logical_last_dim)) + tensor_offsets = GroupedTensorStorage.make_tensor_offsets(first_dims, logical_last_dim) + if ( + first_dims.device.type == "cuda" + and torch.cuda.is_available() + and torch.cuda.is_current_stream_capturing() + ): + # Avoid host sync during CUDA graph capture. 
+ offsets = None + shape = None + else: + offsets = tensor_offsets.tolist() + first_dims_list = first_dims.tolist() + for i in range(num_tensors): + shape.append((first_dims_list[i], logical_last_dim)) else: offsets = [ i * logical_first_dim * logical_last_dim // num_tensors @@ -653,7 +779,6 @@ def make_grouped_tensor( quantizer.optimize_for_gemm if quantizer is not None else False ), ) - grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() return grouped_tensor @@ -709,7 +834,7 @@ def split_into_quantized_tensors( # Get tensor data slice if self.offsets is not None: start_offset = self.offsets[i] - numel = tensor_shape[0] * tensor_shape[1] + numel = math.prod(tensor_shape) end_offset = start_offset + numel if self.has_data(): @@ -724,7 +849,7 @@ def split_into_quantized_tensors( raise RuntimeError("GroupedTensor has no data to split") else: # All same shape case - numel = tensor_shape[0] * tensor_shape[1] + numel = math.prod(tensor_shape) start_offset = i * numel end_offset = start_offset + numel @@ -760,7 +885,7 @@ def split_into_quantized_tensors( quantizer = self.quantizer # Get tensor shape tensor_shape = self.tensor_shapes[i] - numel = tensor_shape[0] * tensor_shape[1] + numel = math.prod(tensor_shape) # Get data offsets if self.offsets is not None: diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index db2f28aa47..5a454e472f 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -19,6 +19,20 @@ __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"] +@functools.lru_cache(maxsize=None) +def get_cached_ones_tensor( + num_elements: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + """Return a cached ``torch.ones`` tensor. + + Tensors are cached by ``(num_elements, dtype, device)`` and kept alive + by the cache, ensuring stable data pointers across CUDA graph replays. 
+ """ + return torch.ones(num_elements, dtype=dtype, device=device) + + def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" for tensor in tensors: @@ -157,6 +171,25 @@ def divide(numerator: int, denominator: int) -> int: return numerator // denominator +def mark_grouped_tensor(*tensors: List[Any]): + """Mark tensors as grouped by setting their ``grouped_tensor_scale_inv`` attribute. + It is True only for columnwise scale-inverse buffers. Needed for paged stashing in MLM.""" + for tensor in tensors: + if tensor is None: + continue + if hasattr(tensor, "columnwise_data"): + assert ( + tensor.columnwise_data is not None + ), "Columnwise data is not set for grouped tensor" + assert ( + tensor.columnwise_scale_inv is not None + ), "Columnwise scale inverse is not set for grouped tensor" + setattr(tensor.columnwise_data, "grouped_tensor_scale_inv", False) + setattr(tensor.columnwise_scale_inv, "grouped_tensor_scale_inv", True) + else: + setattr(tensor, "grouped_tensor_scale_inv", False) + + def split_tensor_along_dim( tensor: torch.Tensor, dim: int, num_partitions: int, contiguous_split_chunks: bool = False ) -> Tuple[torch.Tensor, ...]: