Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Summary

This PR adds MXFP8 (Microscaling FP8) support for JAX grouped GEMM operations, enabling CUDA-graph-safe quantization and GEMM execution paths. The changes span the full stack: CUDA kernels, C++ FFI handlers, Python primitives, and the Flax module interface.
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[grouped_gemm Python] --> B{Input Type?}
    B -->|GroupedNoScaleTensor| C[Extract data + first_dims/last_dims]
    B -->|GroupedScaledTensor1x| D[Extract data + scale_inv + first_dims/last_dims]
    C --> E{Quantizer Set?}
    D --> F{_can_use_v2?}
    E -->|Yes| G[grouped_quantize]
    E -->|No| F
    G --> H{_use_v2_kernel?}
    H -->|Yes MXFP8 + 128-aligned| I[V2 GroupedQuantize FFI\nCUDA-graph safe\nDevice-side offsets]
    H -->|No| J[V1 GroupedQuantize FFI\nD2H copy of group_sizes]
    I --> F
    J --> F
    F -->|Yes: BF16 or MXFP8 + SM100+ + 128-aligned| K[Pre-swizzle scales in JAX]
    F -->|No| L[V1 GroupedGemm FFI\nnvte_multi_tensor_gemm]
    K --> M[V2 GroupedGemm FFI\nnvte_grouped_gemm\nCUDA-graph safe]
    L --> N[Output]
    M --> N
```

Last reviewed commit: 833cb3e
```python
assert False, (
    "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
    " scaling_mode {}".format(scaling_mode)
)
```
assert False makes fallback unreachable
The assert False statements at lines 1028, 1036, and 1045 will always raise AssertionError before the return False on the next line, making those returns dead code. More critically, if Python is run with optimizations enabled (-O flag, which disables asserts), the assert False becomes a no-op and execution falls through — the function would silently skip the validation and continue to later checks or return True, potentially routing data to the V2 kernel under unsupported conditions.
These should be changed to raise an explicit exception or simply return False (if fallback to V1 is the intended behavior) without using assert:
```diff
-assert False, (
-    "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
-    " scaling_mode {}".format(scaling_mode)
-)
 return False
```
This same pattern repeats at lines 1036-1039 and 1044-1048.
```cpp
  cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
                  stream);
  // Note: This may break cudaGraph.
  cudaStreamSynchronize(stream);
}
// size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
// if (!is_rhs_ragged) {
//   NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
```
Commented-out group_sizes sum validation
The validation that sum(group_sizes) matches m (or k for wgrad) has been commented out entirely. While the new *_first_dims/*_last_dims interface changes how dimensions are communicated, removing this runtime sanity check eliminates a useful guard against dimension mismatches that could lead to silent data corruption or out-of-bounds memory access. Consider either adapting this validation to work with the new interface or adding an equivalent check.
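One option is to reinstate an equivalent guard on the host side, before dispatching to the FFI call. A minimal sketch, assuming `group_sizes` is available as a host NumPy array (the helper name is hypothetical):

```python
import numpy as np


def check_group_sizes(group_sizes: np.ndarray, m: int) -> None:
    """Hypothetical host-side equivalent of the commented-out NVTE_CHECK:
    the ragged dimension M (or K for wgrad) must equal sum(group_sizes)."""
    total = int(np.sum(group_sizes))
    if total != m:
        raise ValueError(
            f"Unexpected group_sizes: sum(group_sizes)={total} does not match M={m}"
        )
```

A check like this catches dimension mismatches early, before they can surface as out-of-bounds access inside the GEMM kernels; it does require `group_sizes` on the host, so it would not apply on the device-side-offsets V2 path without a D2H copy.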