Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions tests/pytorch/test_grouped_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,3 +464,91 @@ def test_clear(self) -> None:
assert grouped_tensor.num_tensors == 0
assert grouped_tensor.rowwise_data is None
assert grouped_tensor.logical_shape == (0, 0)

def test_grouped_linear_load_state_dict_multi_to_single_param(self, tmp_path) -> None:
    """Load a per-GEMM checkpoint from disk into single grouped parameter format.

    Saves a ``GroupedLinear`` built with ``single_grouped_parameter=False``
    (one ``weight{i}`` tensor per GEMM), then loads that state dict into a
    module built with ``single_grouped_parameter=True`` and verifies each
    split weight matches the corresponding source weight.
    """
    num_gemms = 3
    in_features = 64
    out_features = 32
    dtype = torch.float32

    # Source module: per-GEMM parameter layout (weight0..weightN).
    src = te.GroupedLinear(
        num_gemms=num_gemms,
        in_features=in_features,
        out_features=out_features,
        params_dtype=dtype,
        single_grouped_parameter=False,
    ).cuda()
    with torch.no_grad():
        for i in range(num_gemms):
            getattr(src, f"weight{i}").copy_(
                torch.randn(out_features, in_features, device="cuda", dtype=dtype)
            )
            if src.use_bias:
                getattr(src, f"bias{i}").copy_(
                    torch.randn(out_features, device="cuda", dtype=dtype)
                )
    expected_weights = [getattr(src, f"weight{i}").detach().clone() for i in range(num_gemms)]
    ckpt_path = tmp_path / "grouped_linear_per_gemm.pt"
    torch.save(src.state_dict(), ckpt_path)
    del src

    # weights_only=True is safe here: with single_grouped_parameter=False every
    # saved tensor is a plain torch.Tensor. It also avoids the arbitrary-pickle
    # execution (and PyTorch 2.x FutureWarning) of weights_only=False.
    src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

    # Destination module: single grouped parameter layout.
    dst = te.GroupedLinear(
        num_gemms=num_gemms,
        in_features=in_features,
        out_features=out_features,
        params_dtype=dtype,
        single_grouped_parameter=True,
    ).cuda()
    load_result = dst.load_state_dict(src_state_dict, strict=True)
    assert len(load_result.missing_keys) == 0
    assert len(load_result.unexpected_keys) == 0

    assert getattr(dst, "weight", None) is not None
    loaded_weights = dst.weight.split_into_quantized_tensors()
    assert len(loaded_weights) == num_gemms
    for loaded_weight, expected_weight in zip(loaded_weights, expected_weights):
        assert torch.equal(loaded_weight, expected_weight)

def test_grouped_linear_load_state_dict_single_to_multi_param(self, tmp_path) -> None:
    """Load grouped-parameter checkpoint from disk into per-GEMM parameter format."""
    num_gemms = 3
    in_features = 64
    out_features = 32
    dtype = torch.float32

    # Source module: single grouped parameter layout.
    src = te.GroupedLinear(
        num_gemms=num_gemms,
        in_features=in_features,
        out_features=out_features,
        params_dtype=dtype,
        single_grouped_parameter=True,
    ).cuda()
    with torch.no_grad():
        # Fill each GEMM's slice of the grouped weight with random data.
        source_weights = src.weight.split_into_quantized_tensors()
        for i in range(num_gemms):
            source_weights[i].copy_(
                torch.randn(out_features, in_features, device="cuda", dtype=dtype)
            )
    expected_weights = [weight.detach().clone() for weight in source_weights]
    ckpt_path = tmp_path / "grouped_linear_single_param.pt"
    torch.save(src.state_dict(), ckpt_path)
    del src

    # NOTE(review): weights_only=False deserializes via pickle and can execute
    # arbitrary code. It appears to be required here because the saved grouped
    # weight is a GroupedTensor subclass rather than a plain torch.Tensor —
    # confirm, and switch to weights_only=True if the subclass loads without it.
    src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Destination module: per-GEMM parameter layout (weight0..weightN).
    dst = te.GroupedLinear(
        num_gemms=num_gemms,
        in_features=in_features,
        out_features=out_features,
        params_dtype=dtype,
        single_grouped_parameter=False,
    ).cuda()
    load_result = dst.load_state_dict(src_state_dict, strict=True)
    assert len(load_result.missing_keys) == 0
    assert len(load_result.unexpected_keys) == 0

    for i, expected_weight in enumerate(expected_weights):
        assert torch.equal(getattr(dst, f"weight{i}"), expected_weight)
71 changes: 71 additions & 0 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,77 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None:
elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)

def _remap_grouped_weight_state_dict_keys(self, state_dict, prefix: str) -> None:
    """Remap weight keys between single and per-GEMM checkpoint formats.

    Mutates ``state_dict`` in place: pops keys in the checkpoint's layout and
    inserts keys in the layout this module expects, scoped to ``prefix``.

    Args:
        state_dict: state dict being loaded (modified in place).
        prefix: key prefix for this module within the state dict.

    Raises:
        ValueError: if a grouped checkpoint weight does not contain exactly
            ``self.num_gemms`` member tensors.
    """
    grouped_weight_key = f"{prefix}weight"
    per_gemm_weight_keys = [f"{prefix}weight{i}" for i in range(self.num_gemms)]
    has_grouped_weight = grouped_weight_key in state_dict
    has_per_gemm_weights = all(key in state_dict for key in per_gemm_weight_keys)

    if self.single_grouped_parameter:
        # Backward compatibility: checkpoints saved without single_grouped_parameter
        # store one weight tensor per GEMM (weight0..weightN). Convert them into a
        # single stacked grouped weight expected by this module configuration.
        if not has_grouped_weight and has_per_gemm_weights:
            per_gemm_weights = [state_dict.pop(key) for key in per_gemm_weight_keys]
            per_gemm_weights = [
                weight.dequantize() if isinstance(weight, QuantizedTensorStorage) else weight
                for weight in per_gemm_weights
            ]
            state_dict[grouped_weight_key] = torch.stack(per_gemm_weights, dim=0)
        elif has_grouped_weight:
            # Drop any redundant per-GEMM keys to avoid strict-load unexpected-key errors.
            for key in per_gemm_weight_keys:
                state_dict.pop(key, None)
    else:
        # Forward compatibility: checkpoints saved with single_grouped_parameter
        # store one grouped `weight`. Convert it back to weight0..weightN.
        if not has_per_gemm_weights and has_grouped_weight:
            grouped_weight = state_dict.pop(grouped_weight_key)
            if hasattr(grouped_weight, "split_into_quantized_tensors"):
                grouped_members = grouped_weight.quantized_tensors
                if grouped_members is None:
                    grouped_members = grouped_weight.split_into_quantized_tensors()
                # Fail fast on a GEMM-count mismatch; otherwise strict loading
                # reports confusing missing/unexpected weight{i} keys.
                if len(grouped_members) != self.num_gemms:
                    raise ValueError(
                        f"Checkpoint grouped weight contains {len(grouped_members)} GEMMs "
                        f"but this module was configured with num_gemms={self.num_gemms}."
                    )
                per_gemm_weights = [
                    (
                        weight.dequantize()
                        if isinstance(weight, QuantizedTensorStorage)
                        else weight
                    )
                    for weight in grouped_members
                ]
            else:
                grouped_weight = (
                    grouped_weight.dequantize()
                    if isinstance(grouped_weight, QuantizedTensorStorage)
                    else grouped_weight
                )
                per_gemm_weights = list(grouped_weight.unbind(dim=0))
                # Same fail-fast check for the plain stacked-tensor path.
                if len(per_gemm_weights) != self.num_gemms:
                    raise ValueError(
                        f"Checkpoint stacked weight has {len(per_gemm_weights)} slices "
                        f"along dim=0 but this module was configured with "
                        f"num_gemms={self.num_gemms}."
                    )
            for i, weight in enumerate(per_gemm_weights):
                state_dict[f"{prefix}weight{i}"] = weight
        elif has_per_gemm_weights:
            # Drop any redundant grouped key to avoid strict-load unexpected-key errors.
            state_dict.pop(grouped_weight_key, None)

def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility.

    Works on a shallow copy so the caller's ``state_dict`` is never mutated.
    Key remapping between single-grouped and per-GEMM formats is performed in
    ``_load_from_state_dict``, which ``super().load_state_dict`` invokes
    internally — remapping here as well would apply the conversion twice.
    """
    if assign:
        # With assign=True, PyTorch setattr()s the (possibly remapped, plain)
        # tensor in place of the parameter instead of copying into it, which
        # can replace a grouped parameter with a plain tensor.
        import warnings

        warnings.warn(
            "GroupedLinear.load_state_dict with assign=True does not support "
            "cross-format checkpoint loading. Use assign=False (default).",
            UserWarning,
        )
    state_dict_copy = state_dict.copy()
    # Preserve versioning metadata that a plain dict copy would drop.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)
Comment on lines +901 to +908
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double remapping of weight keys

_remap_grouped_weight_state_dict_keys is applied twice whenever GroupedLinear.load_state_dict is the entry point:

  1. Explicitly in load_state_dict (line 907).
  2. Again inside GroupedLinear._load_from_state_dict (line 914), which PyTorch's super().load_state_dict() invokes internally as part of its recursive loading loop.

The second call is idempotent — after the first remap the state dict is already in the expected format, so the second remap is a no-op — but the redundancy is a maintenance hazard: a future change that makes the remap non-idempotent could silently introduce data corruption (e.g. double-stacking weights).

A straightforward fix is to skip the remap inside load_state_dict and let _load_from_state_dict handle it exclusively (which already covers the nested-module case). The copy is still needed to avoid mutating the caller's dict, so it should be preserved:

def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    # Key remapping is performed in _load_from_state_dict which PyTorch
    # calls internally; no need to remap again here.
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)

This keeps the copy (protecting the caller's dict) and relies on _load_from_state_dict for the single, canonical remap path in all cases.

Comment on lines +901 to +908
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assign=True replaces GroupedTensor with a plain tensor

When assign=True is passed to load_state_dict and the multi-to-single conversion is active (single_grouped_parameter=True, checkpoint has weight0..N), _remap_grouped_weight_state_dict_keys writes a plain torch.Tensor (from torch.stack) into state_dict_copy["weight"]. PyTorch's assign=True path then calls setattr(module, "weight", plain_tensor) instead of param.copy_(plain_tensor), so the GroupedTensor parameter is silently replaced by a plain tensor. Any subsequent forward pass that calls self.weight.split_into_quantized_tensors() or relies on the GroupedTensor.__torch_dispatch__ mechanism will crash or silently compute incorrect results.

A fix is to either document that assign=True is unsupported for cross-format loading, or reconstruct a proper GroupedTensor inside the remap helper when the target format is single_grouped_parameter=True:

def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    if assign:
        warnings.warn(
            "GroupedLinear.load_state_dict with assign=True does not support "
            "cross-format checkpoint loading. Use assign=False (default).",
            UserWarning,
        )
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)


def _load_from_state_dict(
    self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
    """Load state, including compatibility across grouped-weight checkpoint formats.

    NOTE: ``_remap_grouped_weight_state_dict_keys`` mutates the shared
    ``state_dict`` in place (it pops checkpoint-format keys and inserts
    module-format keys). The mutation is intentional and scoped to this
    module's ``prefix``; when loading through a parent module, the caller's
    dict will contain the remapped keys afterwards.
    """
    # Convert between single-grouped and per-GEMM key layouts before the base
    # class consumes the keys.
    self._remap_grouped_weight_state_dict_keys(state_dict, prefix)

    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    )
Comment on lines +910 to +918
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_load_from_state_dict mutates the shared state dict in-place

When GroupedLinear is used as a submodule, PyTorch passes the same state_dict object (with the full prefix tree) to every module's _load_from_state_dict. This override calls _remap_grouped_weight_state_dict_keys, which modifies that shared dict in-place — popping old keys and inserting new ones (e.g. swapping "parent.grouped.weight" out for "parent.grouped.weight0..N").

While the key operations are scoped by prefix and don't touch other modules' keys, the mutation is a side-effect that:

  1. Permanently alters the caller's state dict after the fact (the user may not expect their dict to be modified when loading a submodule).
  2. Interacts unexpectedly with the unexpected_keys accounting in PyTorch's base _load_from_state_dict if the newly injected keys are not all consumed.

A defensive pattern is to work on a shallow copy of only the module's relevant key-space, similar to what load_state_dict already does at the top level. At minimum, adding a comment here that the mutation is intentional and scoped would reduce the maintenance burden.


@no_torch_dynamo()
def forward(
self,
Expand Down
Loading