Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 53 additions & 21 deletions dpdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def add_atom_names(data, atom_names):


def sort_atom_names(data, type_map=None):
    """Sort atom_names of the system and reorder atom_numbs and atom_types according
    to atom_names. If type_map is not given, atom_names will be sorted by
    alphabetical order. If type_map is given, atom_names will be set to type_map,
    and zero-count elements are kept.

    Parameters
    ----------
    data : dict
        system data; the keys "atom_names", "atom_numbs" and "atom_types"
        are read and overwritten
    type_map : list, optional
        target order of the atom names. Every atom name whose count is
        greater than zero must appear in it. Names in type_map that are
        absent from the data are added with a count of zero; zero-count
        names of the data that are absent from type_map are dropped.

    Returns
    -------
    dict
        the same ``data`` object, updated in place

    Raises
    ------
    ValueError
        if type_map is given and an atom name with a non-zero count is
        missing from it
    """
    if type_map is not None:
        # assign atom_names index to the specified order
        # only active (numb > 0) atom names must be in type_map
        orig_names = data["atom_names"]
        orig_numbs = data["atom_numbs"]
        active_names = {name for name, numb in zip(orig_names, orig_numbs) if numb > 0}
        type_map_set = set(type_map)
        if not active_names.issubset(type_map_set):
            missing = active_names - type_map_set
            raise ValueError(f"Active atom types {missing} not in provided type_map.")

        # for the condition that type_map is a proper superset of atom_names,
        # we allow new elements with atom_numb = 0.
        # Precompute name -> new index once to avoid repeated O(n_types)
        # type_map.index(...) calls (which would make the loop O(n_types^2)).
        name_to_new_idx = {name: i for i, name in enumerate(type_map)}

        # Build new_numbs and the old->new lookup array in a single pass.
        # Old names absent from type_map have atom_numb == 0 (validated above)
        # and are therefore not expected to appear in atom_types, so -1 is a
        # harmless sentinel for their slots in the lookup table.
        new_names = list(type_map)
        new_numbs = [0] * len(type_map)
        lookup = np.full(len(orig_names), -1, dtype=np.int64)
        for old_idx, name in enumerate(orig_names):
            new_idx = name_to_new_idx.get(name)
            if new_idx is not None:
                lookup[old_idx] = new_idx
                new_numbs[new_idx] = orig_numbs[old_idx]

        # Remap atom_types with a single vectorized fancy-index operation
        # (O(n_atoms + n_types) instead of O(n_types * n_atoms)).
        old_types = np.asarray(data["atom_types"])
        new_types = lookup[old_types]

        # update data in-place
        data["atom_names"] = new_names
        data["atom_numbs"] = new_numbs
        data["atom_types"] = new_types

    else:
        # index that will sort an array by alphabetical order
        # idx = argsort(atom_names) --> atom_names[idx] is sorted
        idx = np.argsort(data["atom_names"], kind="stable")
        # sort atom_names and atom_numbs by idx
        data["atom_names"] = list(np.array(data["atom_names"])[idx])
        data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx])
        # to update atom_types: we need the inverse permutation of idx
        # because if old_type = i, and atom_names[i] moves to position j,
        # then the new type should be j.
        # inv_idx = argsort(idx) satisfies: inv_idx[idx[i]] = i
        inv_idx = np.argsort(idx, kind="stable")
        data["atom_types"] = inv_idx[data["atom_types"]]

    return data


Expand Down
86 changes: 86 additions & 0 deletions tests/test_type_map_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

import unittest

import numpy as np

from dpdata.utils import sort_atom_names


class TestSortAtomNames(unittest.TestCase):
    """Unit tests for ``sort_atom_names`` with and without a ``type_map``."""

    def test_sort_atom_names_type_map(self):
        """Names, counts and type indices are reordered to follow type_map."""
        # Test basic functionality with type_map
        data = {
            "atom_names": ["H", "O"],
            "atom_numbs": [2, 1],
            "atom_types": np.array([1, 0, 0]),
        }
        type_map = ["O", "H"]
        result = sort_atom_names(data, type_map=type_map)

        self.assertEqual(result["atom_names"], ["O", "H"])
        self.assertEqual(result["atom_numbs"], [1, 2])
        np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))

    def test_sort_atom_names_type_map_with_zero_atoms(self):
        """A type_map entry absent from the data is added with a zero count."""
        # Test with type_map that includes elements with zero atoms
        data = {
            "atom_names": ["H", "O"],
            "atom_numbs": [2, 1],
            "atom_types": np.array([1, 0, 0]),
        }
        type_map = ["O", "H", "C"]  # C is not in atom_names but in type_map
        result = sort_atom_names(data, type_map=type_map)

        self.assertEqual(result["atom_names"], ["O", "H", "C"])
        self.assertEqual(result["atom_numbs"], [1, 2, 0])
        np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))

    def test_sort_atom_names_type_map_missing_active_types(self):
        """A type_map that omits an atom type with non-zero count is rejected."""
        # Test that ValueError is raised when active atom types are missing from type_map
        data = {
            "atom_names": ["H", "O"],
            "atom_numbs": [2, 1],  # Both H and O are active (numb > 0)
            "atom_types": np.array([1, 0, 0]),
        }
        type_map = ["H"]  # O is active but missing from type_map

        with self.assertRaises(ValueError) as cm:
            sort_atom_names(data, type_map=type_map)

        self.assertIn("Active atom types", str(cm.exception))
        self.assertIn("not in provided type_map", str(cm.exception))
        self.assertIn("O", str(cm.exception))

    def test_sort_atom_names_without_type_map(self):
        """Without a type_map the atom names are sorted alphabetically."""
        # Test sorting without type_map (alphabetical order)
        data = {
            "atom_names": ["Zn", "O", "H"],
            "atom_numbs": [1, 1, 2],
            "atom_types": np.array([0, 1, 2, 2]),
        }
        result = sort_atom_names(data)

        self.assertEqual(result["atom_names"], ["H", "O", "Zn"])
        self.assertEqual(result["atom_numbs"], [2, 1, 1])
        np.testing.assert_array_equal(result["atom_types"], np.array([2, 1, 0, 0]))

    def test_sort_atom_names_with_zero_count_elements_removed(self):
        """A zero-count name of the data may be left out of type_map."""
        # Test the case where original elements are A B C, but counts are 0 1 2,
        # which should be able to map to B C (removing A which has count 0)
        # Example: A, B, C = Cl, O, C
        data = {
            "atom_names": ["Cl", "O", "C"],
            "atom_numbs": [0, 1, 2],
            "atom_types": np.array([1, 2, 2]),
        }
        type_map = ["O", "C"]  # Cl is omitted because it has 0 atoms
        result = sort_atom_names(data, type_map=type_map)

        self.assertEqual(result["atom_names"], ["O", "C"])
        self.assertEqual(result["atom_numbs"], [1, 2])
        np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()