diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 9dd965fa94..225c6f6759 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -464,3 +464,91 @@ def test_clear(self) -> None: assert grouped_tensor.num_tensors == 0 assert grouped_tensor.rowwise_data is None assert grouped_tensor.logical_shape == (0, 0) + + def test_grouped_linear_load_state_dict_multi_to_single_param(self, tmp_path) -> None: + """Load per-GEMM checkpoint from disk into single grouped parameter format.""" + num_gemms = 3 + in_features = 64 + out_features = 32 + dtype = torch.float32 + + src = te.GroupedLinear( + num_gemms=num_gemms, + in_features=in_features, + out_features=out_features, + params_dtype=dtype, + single_grouped_parameter=False, + ).cuda() + with torch.no_grad(): + for i in range(num_gemms): + getattr(src, f"weight{i}").copy_( + torch.randn(out_features, in_features, device="cuda", dtype=dtype) + ) + if src.use_bias: + getattr(src, f"bias{i}").copy_( + torch.randn(out_features, device="cuda", dtype=dtype) + ) + expected_weights = [getattr(src, f"weight{i}").detach().clone() for i in range(num_gemms)] + ckpt_path = tmp_path / "grouped_linear_per_gemm.pt" + torch.save(src.state_dict(), ckpt_path) + del src + + src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False) + + dst = te.GroupedLinear( + num_gemms=num_gemms, + in_features=in_features, + out_features=out_features, + params_dtype=dtype, + single_grouped_parameter=True, + ).cuda() + load_result = dst.load_state_dict(src_state_dict, strict=True) + assert len(load_result.missing_keys) == 0 + assert len(load_result.unexpected_keys) == 0 + + assert getattr(dst, "weight", None) is not None + loaded_weights = dst.weight.split_into_quantized_tensors() + assert len(loaded_weights) == num_gemms + for loaded_weight, expected_weight in zip(loaded_weights, expected_weights): + assert torch.equal(loaded_weight, expected_weight) + + def test_grouped_linear_load_state_dict_single_to_multi_param(self, tmp_path) -> None: + """Load grouped-parameter checkpoint from disk into per-GEMM parameter format.""" + num_gemms = 3 + in_features = 64 + out_features = 32 + dtype = torch.float32 + + src = te.GroupedLinear( + num_gemms=num_gemms, + in_features=in_features, + out_features=out_features, + params_dtype=dtype, + single_grouped_parameter=True, + ).cuda() + with torch.no_grad(): + source_weights = src.weight.split_into_quantized_tensors() + for i in range(num_gemms): + source_weights[i].copy_( + torch.randn(out_features, in_features, device="cuda", dtype=dtype) + ) + expected_weights = [weight.detach().clone() for weight in source_weights] + ckpt_path = tmp_path / "grouped_linear_single_param.pt" + torch.save(src.state_dict(), ckpt_path) + del src + + src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False) + + dst = te.GroupedLinear( + num_gemms=num_gemms, + in_features=in_features, + out_features=out_features, + params_dtype=dtype, + single_grouped_parameter=False, + ).cuda() + load_result = dst.load_state_dict(src_state_dict, strict=True) + assert len(load_result.missing_keys) == 0 + assert len(load_result.unexpected_keys) == 0 + + for i, expected_weight in enumerate(expected_weights): + assert torch.equal(getattr(dst, f"weight{i}"), expected_weight) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index fade2957d5..30c1dbf408 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -846,6 +846,77 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1) + def _remap_grouped_weight_state_dict_keys(self, state_dict, prefix: str) -> None: + """Remap weight keys between single and per-GEMM checkpoint formats.""" + grouped_weight_key = f"{prefix}weight" + per_gemm_weight_keys = [f"{prefix}weight{i}" for i in range(self.num_gemms)] + has_grouped_weight = grouped_weight_key in state_dict + has_per_gemm_weights = all(key in state_dict for key in per_gemm_weight_keys) + + if self.single_grouped_parameter: + # Backward compatibility: checkpoints saved without single_grouped_parameter + # store one weight tensor per GEMM (weight0..weightN). Convert them into a + # single stacked grouped weight expected by this module configuration. + if not has_grouped_weight and has_per_gemm_weights: + per_gemm_weights = [state_dict.pop(key) for key in per_gemm_weight_keys] + per_gemm_weights = [ + weight.dequantize() if isinstance(weight, QuantizedTensorStorage) else weight + for weight in per_gemm_weights + ] + state_dict[grouped_weight_key] = torch.stack(per_gemm_weights, dim=0) + elif has_grouped_weight: + # Drop any redundant per-GEMM keys to avoid strict-load unexpected-key errors. + for key in per_gemm_weight_keys: + state_dict.pop(key, None) + else: + # Forward compatibility: checkpoints saved with single_grouped_parameter + # store one grouped `weight`. Convert it back to weight0..weightN. + if not has_per_gemm_weights and has_grouped_weight: + grouped_weight = state_dict.pop(grouped_weight_key) + if hasattr(grouped_weight, "split_into_quantized_tensors"): + grouped_members = grouped_weight.quantized_tensors + if grouped_members is None: + grouped_members = grouped_weight.split_into_quantized_tensors() + per_gemm_weights = [ + ( + weight.dequantize() + if isinstance(weight, QuantizedTensorStorage) + else weight + ) + for weight in grouped_members + ] + else: + grouped_weight = ( + grouped_weight.dequantize() + if isinstance(grouped_weight, QuantizedTensorStorage) + else grouped_weight + ) + per_gemm_weights = list(grouped_weight.unbind(dim=0)) + for i, weight in enumerate(per_gemm_weights): + state_dict[f"{prefix}weight{i}"] = weight + elif has_per_gemm_weights: + # Drop any redundant grouped key to avoid strict-load unexpected-key errors. + state_dict.pop(grouped_weight_key, None) + + def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): + """Load state dict with grouped-weight format compatibility.""" + state_dict_copy = state_dict.copy() + metadata = getattr(state_dict, "_metadata", None) + if metadata is not None: + state_dict_copy._metadata = metadata + self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="") + return super().load_state_dict(state_dict_copy, strict=strict, assign=assign) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + """Load state, including compatibility across grouped-weight checkpoint formats.""" + self._remap_grouped_weight_state_dict_keys(state_dict, prefix) + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + @no_torch_dynamo() def forward( self,