import unittest

import numpy as np


def _remap_atom_types(lookup, atom_types):
    """Translate per-atom type indices through the ``lookup`` table.

    Parameters
    ----------
    lookup : np.ndarray
        Integer array where ``lookup[old_type]`` is the new type index.
    atom_types : array_like
        Per-atom old type indices.

    Returns
    -------
    np.ndarray
        Per-atom new type indices, same shape as ``atom_types``.
    """
    old_types = np.asarray(atom_types)
    if old_types.size == 0:
        # np.asarray([]) defaults to float64, which NumPy rejects as an
        # index array; an empty system simply maps to an empty result.
        return np.zeros(old_types.shape, dtype=lookup.dtype)
    return lookup[old_types]


def sort_atom_names(data, type_map=None):
    """Sort atom_names of the system and reorder atom_numbs and atom_types
    according to atom_names.

    If type_map is not given, atom_names will be sorted by alphabetical
    order. If type_map is given, atom_names will be set to type_map, and
    zero-count elements are kept.

    Parameters
    ----------
    data : dict
        System data holding the keys ``"atom_names"``, ``"atom_numbs"`` and
        ``"atom_types"``. It is updated in place.
    type_map : list, optional
        Target order of atom names. Every atom name with a non-zero count
        must be present in it; names with a zero count may be omitted and
        are then dropped, while names only present in type_map are added
        with a count of 0.

    Returns
    -------
    dict
        The same ``data`` object, after the in-place update.

    Raises
    ------
    ValueError
        If an atom name with a non-zero count is missing from type_map, or
        if atom_types references an atom name that is missing from type_map.
    """
    if type_map is not None:
        # assign atom_names index to the specified order
        # only active (numb > 0) atom names must be in type_map
        orig_names = data["atom_names"]
        orig_numbs = data["atom_numbs"]
        active_names = {name for name, numb in zip(orig_names, orig_numbs) if numb > 0}
        type_map_set = set(type_map)
        if not active_names.issubset(type_map_set):
            missing = active_names - type_map_set
            raise ValueError(f"Active atom types {missing} not in provided type_map.")

        # Precompute name -> new index once so the loop below stays linear
        # in the number of types instead of calling type_map.index(...).
        name_to_new_idx = {name: i for i, name in enumerate(type_map)}

        # Build new_numbs and the old->new lookup array in a single pass.
        # Old names absent from type_map keep the -1 sentinel.
        new_names = list(type_map)
        new_numbs = [0] * len(type_map)
        lookup = np.full(len(orig_names), -1, dtype=np.int64)
        for old_idx, name in enumerate(orig_names):
            new_idx = name_to_new_idx.get(name)
            if new_idx is not None:
                lookup[old_idx] = new_idx
                new_numbs[new_idx] = orig_numbs[old_idx]

        # Remap atom_types with a single vectorized fancy-index operation.
        old_types = np.asarray(data["atom_types"])
        new_types = _remap_atom_types(lookup, old_types)

        # The sentinel must never reach the output: a -1 here means that
        # atom_types uses a name whose atom_numbs entry claimed it was
        # unused. Fail loudly (before touching data) instead of letting a
        # negative type index silently alias the last type downstream.
        dropped = new_types < 0
        if dropped.any():
            offending = sorted({orig_names[i] for i in np.unique(old_types[dropped])})
            raise ValueError(
                f"atom_types references atom names {offending} that are not in "
                "provided type_map (atom_numbs is inconsistent with atom_types)."
            )

        # update data in-place
        data["atom_names"] = new_names
        data["atom_numbs"] = new_numbs
        data["atom_types"] = new_types
    else:
        # index that will sort an array by alphabetical order
        # idx = argsort(atom_names) --> atom_names[idx] is sorted
        idx = np.argsort(data["atom_names"], kind="stable")
        # sort atom_names and atom_numbs by idx
        data["atom_names"] = list(np.array(data["atom_names"])[idx])
        data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx])
        # atom_types needs the inverse permutation of idx: the name that
        # ends up at new position j came from old position idx[j], so an
        # atom of old type idx[j] must become type j.
        # inv_idx = argsort(idx) satisfies: inv_idx[idx[j]] = j
        inv_idx = np.argsort(idx, kind="stable")
        data["atom_types"] = _remap_atom_types(inv_idx, data["atom_types"])

    return data


class TestSortAtomNames(unittest.TestCase):
    def test_sort_atom_names_type_map(self):
        # Test basic functionality with type_map
        data = {
            "atom_names": ["H", "O"],
            "atom_numbs": [2, 1],
            "atom_types": np.array([1, 0, 0]),
        }
        type_map = ["O", "H"]
        result = sort_atom_names(data, type_map=type_map)

        self.assertEqual(result["atom_names"], ["O", "H"])
        self.assertEqual(result["atom_numbs"], [1, 2])
        np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))

    def test_sort_atom_names_type_map_with_zero_atoms(self):
        # Test with type_map that includes elements with zero atoms
        data = {
            "atom_names": ["H", "O"],
            "atom_numbs": [2, 1],
            "atom_types": np.array([1, 0, 0]),
        }
        type_map = ["O", "H", "C"]  # C is not in atom_names but in type_map
        result = sort_atom_names(data, type_map=type_map)

        self.assertEqual(result["atom_names"], ["O", "H", "C"])
        self.assertEqual(result["atom_numbs"], [1, 2, 0])
        np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))

    def test_sort_atom_names_type_map_missing_active_types(self):
        # Test that ValueError is raised when active atom types are missing
        # from type_map
        data = {
            "atom_names": ["H", "O"],
            "atom_numbs": [2, 1],  # Both H and O are active (numb > 0)
            "atom_types": np.array([1, 0, 0]),
        }
        type_map = ["H"]  # O is active but missing from type_map

        with self.assertRaises(ValueError) as cm:
            sort_atom_names(data, type_map=type_map)

        self.assertIn("Active atom types", str(cm.exception))
        self.assertIn("not in provided type_map", str(cm.exception))
        self.assertIn("O", str(cm.exception))

    def test_sort_atom_names_without_type_map(self):
        # Test sorting without type_map (alphabetical order)
        data = {
            "atom_names": ["Zn", "O", "H"],
            "atom_numbs": [1, 1, 2],
            "atom_types": np.array([0, 1, 2, 2]),
        }
        result = sort_atom_names(data)

        self.assertEqual(result["atom_names"], ["H", "O", "Zn"])
        self.assertEqual(result["atom_numbs"], [2, 1, 1])
        np.testing.assert_array_equal(result["atom_types"], np.array([2, 1, 0, 0]))

    def test_sort_atom_names_with_zero_count_elements_removed(self):
        # Test the case where original elements are A B C, but counts are
        # 0 1 2, which should be able to map to B C (removing A which has
        # count 0). Example: A, B, C = Cl, O, C
        data = {
            "atom_names": ["Cl", "O", "C"],
            "atom_numbs": [0, 1, 2],
            "atom_types": np.array([1, 2, 2]),
        }
        type_map = ["O", "C"]  # Cl is omitted because it has 0 atoms
        result = sort_atom_names(data, type_map=type_map)

        self.assertEqual(result["atom_names"], ["O", "C"])
        self.assertEqual(result["atom_numbs"], [1, 2])
        np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))

    def test_sort_atom_names_inconsistent_numbs_and_types(self):
        # Cl claims a count of 0 but is still referenced by atom_types;
        # dropping it must raise instead of emitting a -1 type index.
        data = {
            "atom_names": ["Cl", "O"],
            "atom_numbs": [0, 1],
            "atom_types": np.array([0, 1]),
        }

        with self.assertRaises(ValueError) as cm:
            sort_atom_names(data, type_map=["O"])

        self.assertIn("Cl", str(cm.exception))
        # data must be left untouched when validation fails
        self.assertEqual(data["atom_names"], ["Cl", "O"])
        self.assertEqual(data["atom_numbs"], [0, 1])

    def test_sort_atom_names_empty_atom_types(self):
        # An empty (float64 by default) atom_types array must not break
        # the remapping in either branch.
        data = {
            "atom_names": ["H"],
            "atom_numbs": [0],
            "atom_types": np.array([]),
        }
        result = sort_atom_names(data, type_map=["O", "H"])

        self.assertEqual(result["atom_names"], ["O", "H"])
        self.assertEqual(result["atom_numbs"], [0, 0])
        self.assertEqual(result["atom_types"].size, 0)


if __name__ == "__main__":
    unittest.main()