
[JAX] MXFP8 Grouped GEMM #2763

Draft
jberchtold-nvidia wants to merge 26 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/gmm-mxfp8

Conversation

@jberchtold-nvidia
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 23 commits March 9, 2026 15:42
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 14, 2026 17:25
@greptile-apps
Contributor

greptile-apps bot commented Mar 14, 2026

Greptile Summary

This PR adds MXFP8 (Microscaling FP8) support for JAX grouped GEMM operations, enabling CUDA-graph-safe quantization and GEMM execution paths. The changes span the full stack: CUDA kernels, C++ FFI handlers, Python primitives, and the Flax module interface.

  • New V2 grouped quantize kernel (GroupedQuantizeV2FFI): CUDA-graph-safe MXFP8 quantization that performs group-size conversion and prefix-sum offset computation entirely on-device, eliminating the D2H copy that made V1 incompatible with CUDA graphs.
  • Refactored grouped GEMM interface: Replaces single group_sizes array and scalar M/N/K parameters with per-dimension ragged arrays (first_dims/last_dims) for LHS, RHS, and output. This enables N-D tensor shapes to flow through to the C++ layer without manual flattening.
  • GroupedNoScaleTensor: New unquantized grouped tensor type that carries per-group dimension metadata, unifying the BF16 and FP8 code paths for grouped_gemm().
  • MXFP8 V2 GEMM path: Pre-swizzles scale tensors in JAX (CUDA-graph safe) and sets WithGEMMSwizzledScales on the grouped tensor wrapper for the cuBLAS call.
  • Bug fixes: Zero-sized group TMA descriptor crash fix in group_quantize_mxfp8.cuh; improved gradient sanitization in permutation.py (isfinite instead of isnan).
  • Flax integration: make_grouped_dense_cls now accepts MXFP8BlockScaling recipes (previously raised ValueError) and supports quantization_checkpoint_name.

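The device-side prefix-sum offset computation mentioned above can be illustrated in plain Python. This is a hedged sketch of the arithmetic only (the real work happens in a CUDA kernel in cublaslt_grouped_gemm.cu; the function name and shapes here are illustrative, not the actual FFI interface):

```python
def group_offsets(group_sizes):
    """Exclusive prefix sum of per-group row counts.

    offsets[i] is where group i starts along the ragged dimension;
    offsets[-1] is the total row count. Doing this on-device (rather
    than copying group_sizes to the host) is what makes the V2 path
    CUDA-graph safe.
    """
    offsets = [0]
    for n in group_sizes:
        offsets.append(offsets[-1] + n)
    return offsets

# Zero-sized groups are legal and simply repeat the previous offset:
# group_offsets([128, 256, 0, 64]) -> [0, 128, 384, 384, 448]
```

Note that a zero-sized group produces a repeated offset rather than being dropped, which is why the quantize kernel needs a separate guard for empty groups (see the group_quantize_mxfp8.cuh fix below).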
Confidence Score: 3/5

  • Large feature PR with a critical assert False bug in the V2 kernel selection logic that could silently bypass validation in optimized Python mode.
  • The PR is a substantial feature addition (~1500 lines) touching CUDA kernels, C++ FFI handlers, and Python primitives across 16 files. The architecture is well-designed and includes good test coverage for V1/V2 kernel selection and gradients. However, the assert False pattern in _use_v2_kernel is a logic bug that would cause incorrect behavior when Python runs with -O, and the commented-out validation in the V1 C++ path weakens safety checks. The core CUDA and C++ changes appear sound.
  • transformer_engine/jax/cpp_extensions/quantization.py (assert False logic bug), transformer_engine/jax/csrc/extensions/gemm.cpp (commented-out validation)

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/gemm.py Major refactor of grouped GEMM to use N-D tensor shapes and per-dim ragged group sizes (first_dims/last_dims) instead of flattened M/N/K scalars. Adds MXFP8 V2 support with pre-swizzled scales.
transformer_engine/jax/cpp_extensions/quantization.py Adds V2 MXFP8 grouped quantize kernel selection. Contains assert False statements that create dead code and will silently skip validation if asserts are disabled.
transformer_engine/jax/csrc/extensions/gemm.cpp Complete rewrite of both V1 and V2 grouped GEMM FFI handlers to use per-dim ragged group sizes. Removes hardcoded M/N/K attrs and derives shapes from N-D XLA buffers. Commented-out validation is a concern.
transformer_engine/jax/csrc/extensions/quantization.cpp New GroupedQuantizeV2FFI handler that uses device-side int64 workspace for group sizes and offsets (CUDA-graph safe). Well-structured with proper bounds checking.
transformer_engine/jax/quantize/tensor.py Replaces single group_sizes with first_dims/last_dims on GroupedScaledTensor1x. Adds GroupedNoScaleTensor for unquantized grouped tensors. Proper pytree registration with tree_flatten/unflatten.
transformer_engine/jax/dense.py Updated fwd/bwd rules to wrap unquantized tensors in GroupedNoScaleTensor and use keyword args for grouped_gemm. Consistent with the new interface.
transformer_engine/jax/flax/module.py Enables MXFP8 recipes for grouped GEMM (previously raised ValueError). Adds quantization_checkpoint_name parameter. Uses assert instead of raise for recipe validation.
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Adds two new CUDA kernels: int32-to-int64 conversion with multiplier, and sequential prefix-sum offset computation. Both are simple and correct for small n_groups.
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh Important fix: skip TMA descriptor update for zero-sized groups to avoid CUDA_ERROR_ILLEGAL_ADDRESS from invalid zero-dimension TMA descriptors.
tests/jax/test_custom_call_compute.py Comprehensive test additions for MXFP8 V1/V2 kernel selection, including gradient tests. Updated existing tests to use GroupedNoScaleTensor and keyword args.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[grouped_gemm Python] --> B{Input Type?}
    B -->|GroupedNoScaleTensor| C[Extract data + first_dims/last_dims]
    B -->|GroupedScaledTensor1x| D[Extract data + scale_inv + first_dims/last_dims]
    C --> E{Quantizer Set?}
    D --> F{_can_use_v2?}
    E -->|Yes| G[grouped_quantize]
    E -->|No| F
    G --> H{_use_v2_kernel?}
    H -->|Yes MXFP8 + 128-aligned| I[V2 GroupedQuantize FFI\nCUDA-graph safe\nDevice-side offsets]
    H -->|No| J[V1 GroupedQuantize FFI\nD2H copy of group_sizes]
    I --> F
    J --> F
    F -->|Yes: BF16 or MXFP8 + SM100+ + 128-aligned| K[Pre-swizzle scales in JAX]
    F -->|No| L[V1 GroupedGemm FFI\nnvte_multi_tensor_gemm]
    K --> M[V2 GroupedGemm FFI\nnvte_grouped_gemm\nCUDA-graph safe]
    L --> N[Output]
    M --> N

Last reviewed commit: 833cb3e

Comment on lines +1028 to +1031
assert False, (
"V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
" scaling_mode {}".format(scaling_mode)
)
assert False makes fallback unreachable

The assert False statements at lines 1028, 1036, and 1045 will always raise AssertionError before the return False on the next line, making those returns dead code. More critically, if Python is run with optimizations enabled (-O flag, which disables asserts), the assert False becomes a no-op and execution falls through — the function would silently skip the validation and continue to later checks or return True, potentially routing data to the V2 kernel under unsupported conditions.

These should be changed to raise an explicit exception or simply return False (if fallback to V1 is the intended behavior) without using assert:

Suggested change
assert False, (
"V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
" scaling_mode {}".format(scaling_mode)
)
return False

This same pattern repeats at lines 1036-1039 and 1044-1048.
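For context, a minimal standalone demonstration of the pitfall (function names are hypothetical, not the actual `_use_v2_kernel` code): under `python -O`, asserts are compiled out, so an `assert False` guard silently disappears and control falls through.

```python
def use_v2_buggy(scaling_mode):
    if scaling_mode != "MXFP8_1D":
        # Raises when asserts are enabled, but becomes a no-op under
        # `python -O` -- execution then falls through past the guard.
        assert False, f"unsupported scaling mode {scaling_mode}"
        return False  # dead code whenever the assert fires
    return True

def use_v2_fixed(scaling_mode):
    # Explicit fallback to the V1 kernel: identical behavior with
    # or without the -O flag.
    if scaling_mode != "MXFP8_1D":
        return False
    return True
```

If an unsupported mode is truly a programming error rather than a fallback case, `raise NotImplementedError(...)` gives the same hard failure without depending on assertions being enabled.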

Comment on lines +1078 to +1085
cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
}
// size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
// if (!is_rhs_ragged) {
// NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,

Commented-out group_sizes sum validation

The validation that sum(group_sizes) matches m (or k for wgrad) has been commented out entirely. While the new *_first_dims/*_last_dims interface changes how dimensions are communicated, removing this runtime sanity check eliminates a useful guard against dimension mismatches that could lead to silent data corruption or out-of-bounds memory access. Consider either adapting this validation to work with the new interface or adding an equivalent check.
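One way to restore the check without reintroducing a device-to-host sync is to validate in the Python layer, where group sizes are host-visible before dispatch. A hedged sketch under assumed names (`lhs_first_dims` and the function name are illustrative of the new per-dim ragged interface, not the actual code):

```python
def check_ragged_dims(lhs_first_dims, m):
    """Sanity check: per-group first dims must sum to the flattened M.

    Equivalent in spirit to the commented-out C++ validation; raising
    here catches dimension mismatches before they can cause silent
    corruption or out-of-bounds access in the GEMM.
    """
    total = sum(lhs_first_dims)
    if total != m:
        raise ValueError(
            f"Unexpected group sizes: M = {m} but sum(first_dims) = {total}"
        )
```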

