diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index f8c3685b78..d586d4e679 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -75,38 +75,116 @@ def make_stat_input( except StopIteration: iterator = iter(dataloaders[i]) stat_data = next(iterator) - if ( - "find_fparam" in stat_data - and "fparam" in stat_data - and stat_data["find_fparam"] == 0.0 - ): - # for model using default fparam - stat_data.pop("fparam") - stat_data.pop("find_fparam") - for dd in stat_data: - if stat_data[dd] is None: - sys_stat[dd] = None - elif isinstance(stat_data[dd], torch.Tensor): - if dd not in sys_stat: - sys_stat[dd] = [] - sys_stat[dd].append(stat_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat[dd] = stat_data[dd] - else: - pass + _append_stat_data(sys_stat, stat_data) + _append_missing_type_frames(sys_stat, datasets[i]) for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass elif sys_stat[key] is None or sys_stat[key][0] is None: sys_stat[key] = None - elif isinstance(stat_data[dd], torch.Tensor): + elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) return lst +def _append_stat_data(sys_stat: dict[str, Any], stat_data: dict[str, Any]) -> None: + """Append one statistics batch to the per-system accumulator.""" + if ( + "find_fparam" in stat_data + and "fparam" in stat_data + and stat_data["find_fparam"] == 0.0 + ): + # for model using default fparam + stat_data.pop("fparam") + stat_data.pop("find_fparam") + for dd, value in stat_data.items(): + if value is None: + sys_stat[dd] = None + elif isinstance(value, torch.Tensor): + if dd not in sys_stat: + sys_stat[dd] = [] + sys_stat[dd].append(value) + elif isinstance(value, np.float32): + sys_stat[dd] = value + + +def _append_missing_type_frames(sys_stat: dict[str, Any], dataset: Any) -> None: + """Add representative mixed-type frames for atom types missed by sampling. + + Global output statistics solve a per-type linear regression from the sampled + frame compositions. In mixed-type datasets a random small sample can miss a + type that exists elsewhere in the dataset, making that type's bias + unconstrained. We therefore append a minimal set of real frames, one by one, + until every type present in the full system is also present in the statistics + sample. Non-mixed systems have a fixed composition, so the initially sampled + frames already cover the system-level types. + """ + if "real_natoms_vec" not in sys_stat or sys_stat["real_natoms_vec"] is None: + return + if not hasattr(dataset, "data_system"): + return + sampled_natoms_vec = sys_stat["real_natoms_vec"] + if len(sampled_natoms_vec) == 0: + return + sampled_counts = torch.cat(sampled_natoms_vec, dim=0)[:, 2:].sum(dim=0) + dataset_counts, first_frame_for_type = _mixed_type_coverage(dataset) + if dataset_counts is None or first_frame_for_type is None: + return + + missing_types = np.flatnonzero((dataset_counts > 0) & (sampled_counts.numpy() == 0)) + if len(missing_types) == 0: + return + + # Import lazily to keep this utility independent from dataloader import time. + from deepmd.pt.utils.dataloader import ( + collate_batch, + ) + + used_frames: set[int] = set() + while len(missing_types) > 0: + frame_idx = first_frame_for_type[int(missing_types[0])] + if frame_idx < 0 or frame_idx in used_frames: + break + used_frames.add(frame_idx) + # Reuse the dataset and collate path so that augmented frames have + # exactly the same tensor layout as ordinary DataLoader batches. + extra_batch = collate_batch([dataset[frame_idx]]) + _append_stat_data(sys_stat, extra_batch) + sampled_counts += extra_batch["real_natoms_vec"][:, 2:].sum(dim=0) + missing_types = np.flatnonzero( + (dataset_counts > 0) & (sampled_counts.numpy() == 0) + ) + + +def _mixed_type_coverage(dataset: Any) -> tuple[np.ndarray | None, np.ndarray | None]: + """Return full-dataset type counts and one representative frame per type.""" + data_system = dataset.data_system + if not getattr(data_system, "mixed_type", False): + return None, None + ntypes = data_system.get_ntypes() + counts = np.zeros(ntypes, dtype=np.int64) + first_frame_for_type = np.full(ntypes, -1, dtype=np.int64) + frame_offset = 0 + for set_dir, frame_end in zip( + data_system.dirs, data_system.prefix_sum, strict=True + ): + type_path = set_dir / "real_atom_types.npy" + real_type = type_path.load_numpy() + if getattr(data_system, "enforce_type_map", False): + real_type = data_system.type_idx_map[real_type].astype(np.int32) + real_type = real_type.reshape(frame_end - frame_offset, data_system.natoms) + for type_i in range(ntypes): + frame_hits = np.flatnonzero((real_type == type_i).any(axis=1)) + counts[type_i] += int((real_type == type_i).sum()) + if first_frame_for_type[type_i] < 0 and len(frame_hits) > 0: + first_frame_for_type[type_i] = frame_offset + int(frame_hits[0]) + frame_offset = frame_end + return counts, first_frame_for_type + + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], diff --git a/deepmd/utils/model_stat.py b/deepmd/utils/model_stat.py index 33ebbcae57..0464758024 100644 --- a/deepmd/utils/model_stat.py +++ b/deepmd/utils/model_stat.py @@ -63,6 +63,7 @@ def collect_batches( if dd == "natoms_vec": stat_data[dd] = stat_data[dd].astype(np.int32) sys_stat[dd].append(stat_data[dd]) + _append_missing_type_frames(data, ii, sys_stat) for dd in sys_stat: if merge_sys: for bb in sys_stat[dd]: @@ -72,6 +73,86 @@ def collect_batches( return all_stat +def _append_missing_type_frames( + data: Any, sys_idx: int, sys_stat: dict[str, list[Any]] +) -> None: + """Append representative mixed-type frames for types missed by sampling. + + Energy/output bias statistics regress one bias per atom type from the sampled + frame compositions. Mixed-type systems can contain types that do not appear + in the small random statistics sample. When that happens, append the first + frame containing each missing type so the regression is constrained for every + type that exists in the underlying system. Standard (non-mixed) systems have + fixed composition and do not need augmentation. + """ + if "real_natoms_vec" not in sys_stat or not hasattr(data, "data_systems"): + return + if getattr(data, "mixed_systems", False): + # In mixed-system batching sys_idx is intentionally ignored by get_batch; + # keep the historical sampling behaviour rather than guessing ownership. + return + data_system = data.data_systems[sys_idx] + dataset_counts, first_frame_for_type = _mixed_type_coverage(data_system) + if dataset_counts is None or first_frame_for_type is None: + return + sampled_counts = np.concatenate(sys_stat["real_natoms_vec"], axis=0)[:, 2:].sum( + axis=0 + ) + missing_types = np.flatnonzero((dataset_counts > 0) & (sampled_counts == 0)) + if len(missing_types) == 0: + return + + used_frames: set[int] = set() + while len(missing_types) > 0: + frame_idx = first_frame_for_type[int(missing_types[0])] + if frame_idx < 0 or frame_idx in used_frames: + break + used_frames.add(frame_idx) + extra_batch = data_system.get_single_frame(frame_idx, num_worker=1) + extra_batch["natoms_vec"] = data.natoms_vec[sys_idx].astype(np.int32) + extra_batch["default_mesh"] = data.default_mesh[sys_idx] + for key, value in extra_batch.items(): + if key == "natoms_vec": + value = value.astype(np.int32) + if ( + key not in {"natoms_vec", "default_mesh"} + and isinstance(value, np.ndarray) + and value.ndim >= 1 + ): + value = value.reshape((1, *value.shape)) + sys_stat[key].append(value) + sampled_counts += ( + extra_batch["real_natoms_vec"].reshape(1, -1)[:, 2:].sum(axis=0) + ) + missing_types = np.flatnonzero((dataset_counts > 0) & (sampled_counts == 0)) + + +def _mixed_type_coverage( + data_system: Any, +) -> tuple[np.ndarray | None, np.ndarray | None]: + """Return full mixed-type counts and a representative frame per type.""" + if not getattr(data_system, "mixed_type", False): + return None, None + ntypes = data_system.get_ntypes() + counts = np.zeros(ntypes, dtype=np.int64) + first_frame_for_type = np.full(ntypes, -1, dtype=np.int64) + frame_offset = 0 + for set_dir, frame_end in zip( + data_system.dirs, data_system.prefix_sum, strict=True + ): + real_type = (set_dir / "real_atom_types.npy").load_numpy() + if getattr(data_system, "enforce_type_map", False): + real_type = data_system.type_idx_map[real_type].astype(np.int32) + real_type = real_type.reshape(frame_end - frame_offset, data_system.natoms) + for type_i in range(ntypes): + frame_hits = np.flatnonzero((real_type == type_i).any(axis=1)) + counts[type_i] += int((real_type == type_i).sum()) + if first_frame_for_type[type_i] < 0 and len(frame_hits) > 0: + first_frame_for_type[type_i] = frame_offset + int(frame_hits[0]) + frame_offset = frame_end + return counts, first_frame_for_type + + def make_stat_input( data: Any, nbatches: int, diff --git a/source/tests/common/test_model_stat.py b/source/tests/common/test_model_stat.py new file mode 100644 index 0000000000..b108684c3b --- /dev/null +++ b/source/tests/common/test_model_stat.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for backend-agnostic statistics sampling helpers.""" + +import unittest + +import numpy as np + +from deepmd.utils.model_stat import ( + make_stat_input, +) + + +class _FakeTypePath: + def __init__(self, real_types: np.ndarray) -> None: + self.real_types = real_types + + def load_numpy(self) -> np.ndarray: + return self.real_types + + +class _FakeSetDir: + def __init__(self, real_types: np.ndarray) -> None: + self.real_types = real_types + + def __truediv__(self, name: str) -> _FakeTypePath: + assert name == "real_atom_types.npy" + return _FakeTypePath(self.real_types) + + +class _FakeMixedDataSystem: + mixed_type = True + enforce_type_map = False + natoms = 2 + dirs: list[_FakeSetDir] + prefix_sum: list[int] + + def __init__(self) -> None: + self.dirs = [_FakeSetDir(np.array([[0, -1], [1, -1]], dtype=np.int32))] + self.prefix_sum = [2] + + def get_ntypes(self) -> int: + return 2 + + def get_single_frame(self, index: int, num_worker: int = 1) -> dict: + assert index == 1 + return { + "coord": np.zeros((6,), dtype=np.float32), + "type": np.array([1, -1], dtype=np.int32), + "atype": np.array([1, -1], dtype=np.int32), + "box": np.eye(3, dtype=np.float32).reshape(-1), + "real_natoms_vec": np.array([2, 2, 0, 1], dtype=np.int32), + "find_energy": np.float32(1.0), + "energy": np.array([1.0], dtype=np.float64), + } + + +class _FakeMixedData: + mixed_systems = False + natoms_vec: list[np.ndarray] + default_mesh: list[np.ndarray] + + def __init__(self) -> None: + self.data_systems = [_FakeMixedDataSystem()] + self.natoms_vec = [np.array([2, 2, 1, 0], dtype=np.int32)] + self.default_mesh = [np.array([], dtype=np.int32)] + + def get_nsystems(self) -> int: + return 1 + + def get_batch(self, sys_idx: int | None = None) -> dict: + assert sys_idx == 0 + return { + "coord": np.zeros((1, 6), dtype=np.float32), + "type": np.array([[0, -1]], dtype=np.int32), + "atype": np.array([[0, -1]], dtype=np.int32), + "box": np.eye(3, dtype=np.float32).reshape(1, 9), + "real_natoms_vec": np.array([[2, 2, 1, 0]], dtype=np.int32), + "natoms_vec": np.array([2, 2, 1, 0], dtype=np.int32), + "default_mesh": np.array([], dtype=np.int32), + "find_energy": np.float32(1.0), + "energy": np.array([[0.0]], dtype=np.float64), + } + + +class TestModelStatSamplingCoverage(unittest.TestCase): + """Mixed-type make_stat_input should cover types beyond initial batches.""" + + def test_make_stat_input_appends_missing_mixed_type_frame(self) -> None: + sampled = make_stat_input(_FakeMixedData(), nbatches=1) + + self.assertEqual(len(sampled), 1) + counts = sampled[0]["real_natoms_vec"][:, 2:].sum(axis=0) + self.assertTrue(np.all(counts > 0)) + self.assertEqual(sampled[0]["energy"].shape[0], 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_observed_type.py b/source/tests/pt/test_observed_type.py index c834ec567e..25cf58b07a 100644 --- a/source/tests/pt/test_observed_type.py +++ b/source/tests/pt/test_observed_type.py @@ -18,6 +18,7 @@ import torch from deepmd.pt.utils.stat import ( + _append_missing_type_frames, _restore_observed_type_from_file, _save_observed_type_to_file, collect_observed_types, @@ -72,6 +73,72 @@ def test_out_of_range_index_ignored(self) -> None: self.assertEqual(result, ["O"]) +class _FakeTypePath: + def __init__(self, real_types: np.ndarray) -> None: + self.real_types = real_types + + def load_numpy(self) -> np.ndarray: + return self.real_types + + +class _FakeSetDir: + def __init__(self, real_types: np.ndarray) -> None: + self.real_types = real_types + + def __truediv__(self, name: str) -> _FakeTypePath: + assert name == "real_atom_types.npy" + return _FakeTypePath(self.real_types) + + +class _FakeMixedDataSystem: + mixed_type = True + enforce_type_map = False + natoms = 2 + dirs: list[_FakeSetDir] + prefix_sum: list[int] + + def __init__(self) -> None: + self.dirs = [_FakeSetDir(np.array([[0, -1], [1, -1]], dtype=np.int32))] + self.prefix_sum = [2] + + def get_ntypes(self) -> int: + return 2 + + +class _FakeMixedDataset: + data_system: _FakeMixedDataSystem + + def __init__(self) -> None: + self.data_system = _FakeMixedDataSystem() + + def __getitem__(self, index: int) -> dict: + assert index == 1 + return { + "coord": np.zeros((6,), dtype=np.float32), + "atype": np.array([1, -1], dtype=np.int32), + "box": np.eye(3, dtype=np.float32).reshape(-1), + "real_natoms_vec": np.array([2, 2, 0, 1], dtype=np.int32), + } + + +class TestStatSamplingCoverage(unittest.TestCase): + """Mixed-type statistics samples should cover all dataset types.""" + + def test_append_missing_mixed_type_frame(self) -> None: + sys_stat = { + "coord": [torch.zeros((1, 6), dtype=torch.float32)], + "atype": [torch.tensor([[0, -1]], dtype=torch.int32)], + "box": [torch.eye(3, dtype=torch.float32).reshape(1, 9)], + "real_natoms_vec": [torch.tensor([[2, 2, 1, 0]], dtype=torch.int32)], + } + + _append_missing_type_frames(sys_stat, _FakeMixedDataset()) + + self.assertEqual(len(sys_stat["real_natoms_vec"]), 2) + sampled_counts = torch.cat(sys_stat["real_natoms_vec"], dim=0)[:, 2:].sum(dim=0) + self.assertTrue(torch.all(sampled_counts > 0)) + + class TestObservedTypeStatFile(unittest.TestCase): """Test stat file save/load round-trip for observed_type."""