Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 98 additions & 20 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,38 +75,116 @@ def make_stat_input(
except StopIteration:
iterator = iter(dataloaders[i])
stat_data = next(iterator)
if (
"find_fparam" in stat_data
and "fparam" in stat_data
and stat_data["find_fparam"] == 0.0
):
# for model using default fparam
stat_data.pop("fparam")
stat_data.pop("find_fparam")
for dd in stat_data:
if stat_data[dd] is None:
sys_stat[dd] = None
elif isinstance(stat_data[dd], torch.Tensor):
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(stat_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat[dd] = stat_data[dd]
else:
pass
_append_stat_data(sys_stat, stat_data)
_append_missing_type_frames(sys_stat, datasets[i])

for key in sys_stat:
if isinstance(sys_stat[key], np.float32):
pass
elif sys_stat[key] is None or sys_stat[key][0] is None:
sys_stat[key] = None
elif isinstance(stat_data[dd], torch.Tensor):
elif isinstance(sys_stat[key][0], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)
return lst


def _append_stat_data(sys_stat: dict[str, Any], stat_data: dict[str, Any]) -> None:
"""Append one statistics batch to the per-system accumulator."""
if (
"find_fparam" in stat_data
and "fparam" in stat_data
and stat_data["find_fparam"] == 0.0
):
# for model using default fparam
stat_data.pop("fparam")
stat_data.pop("find_fparam")
for dd, value in stat_data.items():
if value is None:
sys_stat[dd] = None
elif isinstance(value, torch.Tensor):
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(value)
elif isinstance(value, np.float32):
sys_stat[dd] = value


def _append_missing_type_frames(sys_stat: dict[str, Any], dataset: Any) -> None:
"""Add representative mixed-type frames for atom types missed by sampling.
Global output statistics solve a per-type linear regression from the sampled
frame compositions. In mixed-type datasets a random small sample can miss a
type that exists elsewhere in the dataset, making that type's bias
unconstrained. We therefore append a minimal set of real frames, one by one,
until every type present in the full system is also present in the statistics
sample. Non-mixed systems have a fixed composition, so the initially sampled
frames already cover the system-level types.
"""
if "real_natoms_vec" not in sys_stat or sys_stat["real_natoms_vec"] is None:
return
if not hasattr(dataset, "data_system"):
return
sampled_natoms_vec = sys_stat["real_natoms_vec"]
if len(sampled_natoms_vec) == 0:
return
sampled_counts = torch.cat(sampled_natoms_vec, dim=0)[:, 2:].sum(dim=0)
dataset_counts, first_frame_for_type = _mixed_type_coverage(dataset)
if dataset_counts is None or first_frame_for_type is None:
return

missing_types = np.flatnonzero((dataset_counts > 0) & (sampled_counts.numpy() == 0))
if len(missing_types) == 0:
return

# Import lazily to keep this utility independent from dataloader import time.
from deepmd.pt.utils.dataloader import (
collate_batch,
)

used_frames: set[int] = set()
while len(missing_types) > 0:
frame_idx = first_frame_for_type[int(missing_types[0])]
if frame_idx < 0 or frame_idx in used_frames:
break
used_frames.add(frame_idx)
# Reuse the dataset and collate path so that augmented frames have
# exactly the same tensor layout as ordinary DataLoader batches.
extra_batch = collate_batch([dataset[frame_idx]])
_append_stat_data(sys_stat, extra_batch)
sampled_counts += extra_batch["real_natoms_vec"][:, 2:].sum(dim=0)
missing_types = np.flatnonzero(
(dataset_counts > 0) & (sampled_counts.numpy() == 0)
)


def _mixed_type_coverage(dataset: Any) -> tuple[np.ndarray | None, np.ndarray | None]:
"""Return full-dataset type counts and one representative frame per type."""
data_system = dataset.data_system
if not getattr(data_system, "mixed_type", False):
return None, None
ntypes = data_system.get_ntypes()
counts = np.zeros(ntypes, dtype=np.int64)
first_frame_for_type = np.full(ntypes, -1, dtype=np.int64)
frame_offset = 0
for set_dir, frame_end in zip(
data_system.dirs, data_system.prefix_sum, strict=True
):
type_path = set_dir / "real_atom_types.npy"
real_type = type_path.load_numpy()
if getattr(data_system, "enforce_type_map", False):
real_type = data_system.type_idx_map[real_type].astype(np.int32)
real_type = real_type.reshape(frame_end - frame_offset, data_system.natoms)
for type_i in range(ntypes):
frame_hits = np.flatnonzero((real_type == type_i).any(axis=1))
counts[type_i] += int((real_type == type_i).sum())
if first_frame_for_type[type_i] < 0 and len(frame_hits) > 0:
first_frame_for_type[type_i] = frame_offset + int(frame_hits[0])
frame_offset = frame_end
return counts, first_frame_for_type


def _restore_from_file(
stat_file_path: DPPath,
keys: list[str] = ["energy"],
Expand Down
81 changes: 81 additions & 0 deletions deepmd/utils/model_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def collect_batches(
if dd == "natoms_vec":
stat_data[dd] = stat_data[dd].astype(np.int32)
sys_stat[dd].append(stat_data[dd])
_append_missing_type_frames(data, ii, sys_stat)
for dd in sys_stat:
if merge_sys:
for bb in sys_stat[dd]:
Expand All @@ -72,6 +73,86 @@ def collect_batches(
return all_stat


def _append_missing_type_frames(
data: Any, sys_idx: int, sys_stat: dict[str, list[Any]]
) -> None:
"""Append representative mixed-type frames for types missed by sampling.

Energy/output bias statistics regress one bias per atom type from the sampled
frame compositions. Mixed-type systems can contain types that do not appear
in the small random statistics sample. When that happens, append the first
frame containing each missing type so the regression is constrained for every
type that exists in the underlying system. Standard (non-mixed) systems have
fixed composition and do not need augmentation.
"""
if "real_natoms_vec" not in sys_stat or not hasattr(data, "data_systems"):
return
if getattr(data, "mixed_systems", False):
# In mixed-system batching sys_idx is intentionally ignored by get_batch;
# keep the historical sampling behaviour rather than guessing ownership.
return
data_system = data.data_systems[sys_idx]
dataset_counts, first_frame_for_type = _mixed_type_coverage(data_system)
if dataset_counts is None or first_frame_for_type is None:
return
sampled_counts = np.concatenate(sys_stat["real_natoms_vec"], axis=0)[:, 2:].sum(
axis=0
)
missing_types = np.flatnonzero((dataset_counts > 0) & (sampled_counts == 0))
if len(missing_types) == 0:
return

used_frames: set[int] = set()
while len(missing_types) > 0:
frame_idx = first_frame_for_type[int(missing_types[0])]
if frame_idx < 0 or frame_idx in used_frames:
break
used_frames.add(frame_idx)
extra_batch = data_system.get_single_frame(frame_idx, num_worker=1)
extra_batch["natoms_vec"] = data.natoms_vec[sys_idx].astype(np.int32)
extra_batch["default_mesh"] = data.default_mesh[sys_idx]
for key, value in extra_batch.items():
if key == "natoms_vec":
value = value.astype(np.int32)
if (
key not in {"natoms_vec", "default_mesh"}
and isinstance(value, np.ndarray)
and value.ndim >= 1
):
value = value.reshape((1, *value.shape))
sys_stat[key].append(value)
sampled_counts += (
extra_batch["real_natoms_vec"].reshape(1, -1)[:, 2:].sum(axis=0)
)
missing_types = np.flatnonzero((dataset_counts > 0) & (sampled_counts == 0))


def _mixed_type_coverage(
data_system: Any,
) -> tuple[np.ndarray | None, np.ndarray | None]:
"""Return full mixed-type counts and a representative frame per type."""
if not getattr(data_system, "mixed_type", False):
return None, None
ntypes = data_system.get_ntypes()
counts = np.zeros(ntypes, dtype=np.int64)
first_frame_for_type = np.full(ntypes, -1, dtype=np.int64)
frame_offset = 0
for set_dir, frame_end in zip(
data_system.dirs, data_system.prefix_sum, strict=True
):
real_type = (set_dir / "real_atom_types.npy").load_numpy()
if getattr(data_system, "enforce_type_map", False):
real_type = data_system.type_idx_map[real_type].astype(np.int32)
real_type = real_type.reshape(frame_end - frame_offset, data_system.natoms)
for type_i in range(ntypes):
frame_hits = np.flatnonzero((real_type == type_i).any(axis=1))
counts[type_i] += int((real_type == type_i).sum())
if first_frame_for_type[type_i] < 0 and len(frame_hits) > 0:
first_frame_for_type[type_i] = frame_offset + int(frame_hits[0])
frame_offset = frame_end
return counts, first_frame_for_type


def make_stat_input(
data: Any,
nbatches: int,
Expand Down
98 changes: 98 additions & 0 deletions source/tests/common/test_model_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Tests for backend-agnostic statistics sampling helpers."""

import unittest

import numpy as np

from deepmd.utils.model_stat import (
make_stat_input,
)


class _FakeTypePath:
def __init__(self, real_types: np.ndarray) -> None:
self.real_types = real_types

def load_numpy(self) -> np.ndarray:
return self.real_types


class _FakeSetDir:
def __init__(self, real_types: np.ndarray) -> None:
self.real_types = real_types

def __truediv__(self, name: str) -> _FakeTypePath:
assert name == "real_atom_types.npy"
return _FakeTypePath(self.real_types)


class _FakeMixedDataSystem:
mixed_type = True
enforce_type_map = False
natoms = 2
dirs: list[_FakeSetDir]
prefix_sum: list[int]

def __init__(self) -> None:
self.dirs = [_FakeSetDir(np.array([[0, -1], [1, -1]], dtype=np.int32))]
self.prefix_sum = [2]

def get_ntypes(self) -> int:
return 2

def get_single_frame(self, index: int, num_worker: int = 1) -> dict:
assert index == 1
return {
"coord": np.zeros((6,), dtype=np.float32),
"type": np.array([1, -1], dtype=np.int32),
"atype": np.array([1, -1], dtype=np.int32),
"box": np.eye(3, dtype=np.float32).reshape(-1),
"real_natoms_vec": np.array([2, 2, 0, 1], dtype=np.int32),
"find_energy": np.float32(1.0),
"energy": np.array([1.0], dtype=np.float64),
}


class _FakeMixedData:
mixed_systems = False
natoms_vec: list[np.ndarray]
default_mesh: list[np.ndarray]

def __init__(self) -> None:
self.data_systems = [_FakeMixedDataSystem()]
self.natoms_vec = [np.array([2, 2, 1, 0], dtype=np.int32)]
self.default_mesh = [np.array([], dtype=np.int32)]

def get_nsystems(self) -> int:
return 1

def get_batch(self, sys_idx: int | None = None) -> dict:
assert sys_idx == 0
return {
"coord": np.zeros((1, 6), dtype=np.float32),
"type": np.array([[0, -1]], dtype=np.int32),
"atype": np.array([[0, -1]], dtype=np.int32),
"box": np.eye(3, dtype=np.float32).reshape(1, 9),
"real_natoms_vec": np.array([[2, 2, 1, 0]], dtype=np.int32),
"natoms_vec": np.array([2, 2, 1, 0], dtype=np.int32),
"default_mesh": np.array([], dtype=np.int32),
"find_energy": np.float32(1.0),
"energy": np.array([[0.0]], dtype=np.float64),
}


class TestModelStatSamplingCoverage(unittest.TestCase):
"""Mixed-type make_stat_input should cover types beyond initial batches."""

def test_make_stat_input_appends_missing_mixed_type_frame(self) -> None:
sampled = make_stat_input(_FakeMixedData(), nbatches=1)

self.assertEqual(len(sampled), 1)
counts = sampled[0]["real_natoms_vec"][:, 2:].sum(axis=0)
self.assertTrue(np.all(counts > 0))
self.assertEqual(sampled[0]["energy"].shape[0], 2)


if __name__ == "__main__":
unittest.main()
Loading
Loading