-
Notifications
You must be signed in to change notification settings - Fork 695
[PyTorch] Backwards compatible single param checkpointing in GroupedLinear
#2761
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
76c21b4
2066983
2fb67c6
9cd0dbe
540fa2f
dfd6f4b
ebd23b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -464,3 +464,91 @@ def test_clear(self) -> None: | |||||
| assert grouped_tensor.num_tensors == 0 | ||||||
| assert grouped_tensor.rowwise_data is None | ||||||
| assert grouped_tensor.logical_shape == (0, 0) | ||||||
|
|
||||||
| def test_grouped_linear_load_state_dict_multi_to_single_param(self, tmp_path) -> None: | ||||||
| """Load per-GEMM checkpoint from disk into single grouped parameter format.""" | ||||||
| num_gemms = 3 | ||||||
| in_features = 64 | ||||||
| out_features = 32 | ||||||
| dtype = torch.float32 | ||||||
|
|
||||||
| src = te.GroupedLinear( | ||||||
| num_gemms=num_gemms, | ||||||
| in_features=in_features, | ||||||
| out_features=out_features, | ||||||
| params_dtype=dtype, | ||||||
| single_grouped_parameter=False, | ||||||
| ).cuda() | ||||||
| with torch.no_grad(): | ||||||
| for i in range(num_gemms): | ||||||
| getattr(src, f"weight{i}").copy_( | ||||||
| torch.randn(out_features, in_features, device="cuda", dtype=dtype) | ||||||
| ) | ||||||
| if src.use_bias: | ||||||
| getattr(src, f"bias{i}").copy_( | ||||||
| torch.randn(out_features, device="cuda", dtype=dtype) | ||||||
| ) | ||||||
| expected_weights = [getattr(src, f"weight{i}").detach().clone() for i in range(num_gemms)] | ||||||
| ckpt_path = tmp_path / "grouped_linear_per_gemm.pt" | ||||||
| torch.save(src.state_dict(), ckpt_path) | ||||||
| del src | ||||||
|
|
||||||
| src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
For the multi-to-single test (
Suggested change
The same concern applies to line 540 in |
||||||
|
|
||||||
| dst = te.GroupedLinear( | ||||||
| num_gemms=num_gemms, | ||||||
| in_features=in_features, | ||||||
| out_features=out_features, | ||||||
| params_dtype=dtype, | ||||||
| single_grouped_parameter=True, | ||||||
| ).cuda() | ||||||
| load_result = dst.load_state_dict(src_state_dict, strict=True) | ||||||
| assert len(load_result.missing_keys) == 0 | ||||||
| assert len(load_result.unexpected_keys) == 0 | ||||||
|
|
||||||
| assert getattr(dst, "weight", None) is not None | ||||||
| loaded_weights = dst.weight.split_into_quantized_tensors() | ||||||
| assert len(loaded_weights) == num_gemms | ||||||
| for loaded_weight, expected_weight in zip(loaded_weights, expected_weights): | ||||||
| assert torch.equal(loaded_weight, expected_weight) | ||||||
|
|
||||||
| def test_grouped_linear_load_state_dict_single_to_multi_param(self, tmp_path) -> None: | ||||||
| """Load grouped-parameter checkpoint from disk into per-GEMM parameter format.""" | ||||||
| num_gemms = 3 | ||||||
| in_features = 64 | ||||||
| out_features = 32 | ||||||
| dtype = torch.float32 | ||||||
|
|
||||||
| src = te.GroupedLinear( | ||||||
| num_gemms=num_gemms, | ||||||
| in_features=in_features, | ||||||
| out_features=out_features, | ||||||
| params_dtype=dtype, | ||||||
| single_grouped_parameter=True, | ||||||
| ).cuda() | ||||||
| with torch.no_grad(): | ||||||
| source_weights = src.weight.split_into_quantized_tensors() | ||||||
| for i in range(num_gemms): | ||||||
| source_weights[i].copy_( | ||||||
| torch.randn(out_features, in_features, device="cuda", dtype=dtype) | ||||||
| ) | ||||||
| expected_weights = [weight.detach().clone() for weight in source_weights] | ||||||
| ckpt_path = tmp_path / "grouped_linear_single_param.pt" | ||||||
| torch.save(src.state_dict(), ckpt_path) | ||||||
| del src | ||||||
|
|
||||||
| src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False) | ||||||
|
|
||||||
| dst = te.GroupedLinear( | ||||||
| num_gemms=num_gemms, | ||||||
| in_features=in_features, | ||||||
| out_features=out_features, | ||||||
| params_dtype=dtype, | ||||||
| single_grouped_parameter=False, | ||||||
| ).cuda() | ||||||
| load_result = dst.load_state_dict(src_state_dict, strict=True) | ||||||
| assert len(load_result.missing_keys) == 0 | ||||||
| assert len(load_result.unexpected_keys) == 0 | ||||||
|
|
||||||
| for i, expected_weight in enumerate(expected_weights): | ||||||
| assert torch.equal(getattr(dst, f"weight{i}"), expected_weight) | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -846,6 +846,77 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: | |
| elif self.parallel_mode == "column": | ||
| set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1) | ||
|
|
||
| def _remap_grouped_weight_state_dict_keys(self, state_dict, prefix: str) -> None: | ||
| """Remap weight keys between single and per-GEMM checkpoint formats.""" | ||
| grouped_weight_key = f"{prefix}weight" | ||
| per_gemm_weight_keys = [f"{prefix}weight{i}" for i in range(self.num_gemms)] | ||
| has_grouped_weight = grouped_weight_key in state_dict | ||
| has_per_gemm_weights = all(key in state_dict for key in per_gemm_weight_keys) | ||
|
|
||
| if self.single_grouped_parameter: | ||
| # Backward compatibility: checkpoints saved without single_grouped_parameter | ||
| # store one weight tensor per GEMM (weight0..weightN). Convert them into a | ||
| # single stacked grouped weight expected by this module configuration. | ||
| if not has_grouped_weight and has_per_gemm_weights: | ||
| per_gemm_weights = [state_dict.pop(key) for key in per_gemm_weight_keys] | ||
| per_gemm_weights = [ | ||
| weight.dequantize() if isinstance(weight, QuantizedTensorStorage) else weight | ||
| for weight in per_gemm_weights | ||
| ] | ||
| state_dict[grouped_weight_key] = torch.stack(per_gemm_weights, dim=0) | ||
| elif has_grouped_weight: | ||
| # Drop any redundant per-GEMM keys to avoid strict-load unexpected-key errors. | ||
| for key in per_gemm_weight_keys: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We might need this even for TE sequential checkpointing right? Maybe putting it in utils and using it in both places make sense to avoid code duplication |
||
| state_dict.pop(key, None) | ||
| else: | ||
| # Forward compatibility: checkpoints saved with single_grouped_parameter | ||
| # store one grouped `weight`. Convert it back to weight0..weightN. | ||
| if not has_per_gemm_weights and has_grouped_weight: | ||
| grouped_weight = state_dict.pop(grouped_weight_key) | ||
| if hasattr(grouped_weight, "split_into_quantized_tensors"): | ||
| grouped_members = grouped_weight.quantized_tensors | ||
| if grouped_members is None: | ||
| grouped_members = grouped_weight.split_into_quantized_tensors() | ||
| per_gemm_weights = [ | ||
| ( | ||
| weight.dequantize() | ||
| if isinstance(weight, QuantizedTensorStorage) | ||
| else weight | ||
| ) | ||
| for weight in grouped_members | ||
| ] | ||
| else: | ||
| grouped_weight = ( | ||
| grouped_weight.dequantize() | ||
| if isinstance(grouped_weight, QuantizedTensorStorage) | ||
| else grouped_weight | ||
| ) | ||
| per_gemm_weights = list(grouped_weight.unbind(dim=0)) | ||
| for i, weight in enumerate(per_gemm_weights): | ||
| state_dict[f"{prefix}weight{i}"] = weight | ||
|
Comment on lines
+874
to
+896
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No validation of GEMM count after splitting When splitting a grouped checkpoint into per-GEMM weights, neither the Adding an explicit early check here improves debuggability: if hasattr(grouped_weight, "split_into_quantized_tensors"):
grouped_members = grouped_weight.quantized_tensors
if grouped_members is None:
grouped_members = grouped_weight.split_into_quantized_tensors()
if len(grouped_members) != self.num_gemms:
raise ValueError(
f"Checkpoint grouped weight contains {len(grouped_members)} GEMMs "
f"but this module was configured with num_gemms={self.num_gemms}."
)
...
else:
per_gemm_weights = list(grouped_weight.unbind(dim=0))
if len(per_gemm_weights) != self.num_gemms:
raise ValueError(
f"Checkpoint stacked weight has {len(per_gemm_weights)} slices along dim=0 "
f"but this module was configured with num_gemms={self.num_gemms}."
) |
||
| elif has_per_gemm_weights: | ||
| # Drop any redundant grouped key to avoid strict-load unexpected-key errors. | ||
| state_dict.pop(grouped_weight_key, None) | ||
|
|
||
| def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): | ||
| """Load state dict with grouped-weight format compatibility.""" | ||
| state_dict_copy = state_dict.copy() | ||
| metadata = getattr(state_dict, "_metadata", None) | ||
| if metadata is not None: | ||
| state_dict_copy._metadata = metadata | ||
| self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="") | ||
| return super().load_state_dict(state_dict_copy, strict=strict, assign=assign) | ||
|
Comment on lines
+901
to
+908
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Double remapping of weight keys
The second call is idempotent — after the first remap the state dict is already in the expected format, so the second remap is a no-op — but the redundancy is a maintenance hazard: a future change that makes the remap non-idempotent could silently introduce data corruption (e.g. double-stacking weights). A straightforward fix is to skip the remap inside def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
"""Load state dict with grouped-weight format compatibility."""
state_dict_copy = state_dict.copy()
metadata = getattr(state_dict, "_metadata", None)
if metadata is not None:
state_dict_copy._metadata = metadata
# Key remapping is performed in _load_from_state_dict which PyTorch
# calls internally; no need to remap again here.
return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)This keeps the copy (protecting the caller's dict) and relies on
Comment on lines
+901
to
+908
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When A fix is to either document that def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
"""Load state dict with grouped-weight format compatibility."""
if assign:
warnings.warn(
"GroupedLinear.load_state_dict with assign=True does not support "
"cross-format checkpoint loading. Use assign=False (default).",
UserWarning,
)
state_dict_copy = state_dict.copy()
metadata = getattr(state_dict, "_metadata", None)
if metadata is not None:
state_dict_copy._metadata = metadata
self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
return super().load_state_dict(state_dict_copy, strict=strict, assign=assign) |
||
|
|
||
| def _load_from_state_dict( | ||
| self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs | ||
| ): | ||
| """Load state, including compatibility across grouped-weight checkpoint formats.""" | ||
| self._remap_grouped_weight_state_dict_keys(state_dict, prefix) | ||
|
|
||
| super()._load_from_state_dict( | ||
| state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs | ||
| ) | ||
|
Comment on lines
+910
to
+918
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When While the key operations are scoped by
A defensive pattern is to work on a shallow copy of only the module's relevant key-space, similar to what |
||
|
|
||
| @no_torch_dynamo() | ||
| def forward( | ||
| self, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we also add test case for quantized_model_init(mxfp8)? Shouldnt be a blocker for this PR though.