diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 1a01b05fe9..43771c3c53 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -32,6 +32,9 @@ from deepmd.pt.model.task.sezm_ener import ( SeZMEnergyFittingNet, ) +from deepmd.pt.utils.multi_task import ( + preprocess_shared_params, +) from deepmd.utils.spin import ( Spin, ) @@ -155,19 +158,56 @@ def get_spin_model(model_params: dict) -> SpinModel: def get_linear_model(model_params: dict) -> LinearEnergyModel: model_params = copy.deepcopy(model_params) weights = model_params.get("weights", "mean") + shared_links = None + if "shared_dict" in model_params: + shared_config = { + "model_dict": { + f"model_{idx}": sub_model + for idx, sub_model in enumerate(model_params["models"]) + }, + "shared_dict": model_params.get("shared_dict", {}), + } + if "type_map" in model_params: + shared_config["type_map"] = copy.deepcopy(model_params["type_map"]) + shared_config, shared_links = preprocess_shared_params( + shared_config, + require_shared_type_map=False, + ) + model_params["models"] = list(shared_config["model_dict"].values()) + if "type_map" not in model_params: + for idx, sub_model_params in enumerate(model_params["models"]): + if "type_map" not in sub_model_params: + raise ValueError( + f"Linear sub-model {idx} must define type_map when " + "linear_ener has no top-level type_map." + ) + first_type_map = model_params["models"][0]["type_map"] + for idx, sub_model_params in enumerate(model_params["models"][1:], start=1): + if sub_model_params["type_map"] != first_type_map: + raise ValueError( + f"Linear sub-model {idx} type_map differs from sub-model 0. " + "All type_map values must be identical when linear_ener " + "has no top-level type_map." + ) + model_params["type_map"] = copy.deepcopy(first_type_map) + list_of_models = [] - ntypes = len(model_params["type_map"]) for sub_model_params in model_params["models"]: if "type_map" not in sub_model_params: sub_model_params["type_map"] = model_params["type_map"] if "descriptor" in sub_model_params: # descriptor - sub_model_params["descriptor"]["ntypes"] = ntypes + sub_ntypes = len(sub_model_params["type_map"]) + sub_model_params["descriptor"]["ntypes"] = sub_ntypes descriptor, fitting, _ = _get_standard_model_components( - sub_model_params, ntypes + sub_model_params, sub_ntypes ) list_of_models.append( - DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"]) + DPAtomicModel( + descriptor, + fitting, + type_map=copy.deepcopy(sub_model_params["type_map"]), + ) ) else: # must be pairtab @@ -179,19 +219,23 @@ def get_linear_model(model_params: dict) -> LinearEnergyModel: sub_model_params["tab_file"], sub_model_params["rcut"], sub_model_params["sel"], - type_map=model_params["type_map"], + type_map=copy.deepcopy(sub_model_params["type_map"]), ) ) atom_exclude_types = model_params.get("atom_exclude_types", []) pair_exclude_types = model_params.get("pair_exclude_types", []) - return LinearEnergyModel( + model = LinearEnergyModel( models=list_of_models, type_map=model_params["type_map"], weights=weights, atom_exclude_types=atom_exclude_types, pair_exclude_types=pair_exclude_types, ) + model.shared_links = shared_links + if shared_links: + model.share_params(shared_links, resume=True) + return model def get_zbl_model(model_params: dict) -> DPZBLModel: diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py index c3004b5a5d..3255e1da20 100644 --- a/deepmd/pt/model/model/dp_linear_model.py +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from copy import ( + deepcopy, +) from typing import ( Any, ) @@ -14,6 +18,9 @@ from deepmd.pt.model.model.model import ( BaseModel, ) +from deepmd.pt.utils.multi_task import ( + preprocess_shared_params, +) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -25,6 +32,8 @@ make_model, ) +log = logging.getLogger(__name__) + DPLinearModel_ = make_model(LinearEnergyAtomicModel) @@ -39,6 +48,108 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) + def share_params( + self, + shared_links: dict[str, Any], + model_key_prob_map: dict[str, float] | None = None, + data_stat_protect: float = 1e-2, + resume: bool = False, + ) -> None: + """Share parameters between linear sub-models. + + ``shared_links`` follows the same structure as the multi-task + preprocessor. Linear sub-model keys are named ``model_0``, ``model_1``, + ... by ``get_linear_model``. + """ + + def get_sub_model(model_key: str): # noqa: ANN202 + if not model_key.startswith("model_"): + raise RuntimeError(f"Unknown linear model key {model_key}!") + model_index = int(model_key.removeprefix("model_")) + return self.atomic_model.models[model_index] + + def get_descriptor_class(model_key: str, shared_type: str): # noqa: ANN202 + sub_model = get_sub_model(model_key) + if shared_type == "descriptor": + return sub_model.descriptor + if "hybrid" in shared_type: + hybrid_index = int(shared_type.split("_")[-1]) + return sub_model.descriptor.descrpt_list[hybrid_index] + raise RuntimeError(f"Unknown class_type {shared_type}!") + + for shared_item in shared_links: + shared_base = shared_links[shared_item]["links"][0] + class_type_base = shared_base["shared_type"] + model_key_base = shared_base["model_key"] + shared_level_base = int(shared_base["shared_level"]) + previous_shared_level = shared_level_base + if "descriptor" in class_type_base: + base_class = get_descriptor_class(model_key_base, class_type_base) + for link_item in shared_links[shared_item]["links"][1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + if shared_level_link < previous_shared_level: + raise ValueError( + "The shared_links must be sorted by shared_level!" + ) + previous_shared_level = shared_level_link + if "descriptor" not in class_type_link: + raise ValueError( + f"Class type mismatched: {class_type_base} vs {class_type_link}!" + ) + link_class = get_descriptor_class(model_key_link, class_type_link) + link_class.share_params( + base_class, shared_level_link, resume=resume + ) + log.warning( + "Shared params of %s.%s and %s.%s!", + model_key_base, + class_type_base, + model_key_link, + class_type_link, + ) + else: + base_model = get_sub_model(model_key_base) + if hasattr(base_model, class_type_base): + base_class = getattr(base_model, class_type_base) + for link_item in shared_links[shared_item]["links"][1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + if shared_level_link < previous_shared_level: + raise ValueError( + "The shared_links must be sorted by shared_level!" + ) + previous_shared_level = shared_level_link + if class_type_base != class_type_link: + raise ValueError( + f"Class type mismatched: {class_type_base} vs {class_type_link}!" + ) + link_model = get_sub_model(model_key_link) + link_class = getattr(link_model, class_type_link) + if model_key_prob_map is None: + frac_prob = 1.0 + else: + frac_prob = ( + model_key_prob_map[model_key_link] + / model_key_prob_map[model_key_base] + ) + link_class.share_params( + base_class, + shared_level_link, + model_prob=frac_prob, + protection=data_stat_protect, + resume=resume, + ) + log.warning( + "Shared params of %s.%s and %s.%s!", + model_key_base, + class_type_base, + model_key_link, + class_type_link, + ) + def translated_output_def(self) -> dict[str, OutputVariableDef]: out_def_data = self.model_output_def().get_data() output_def = { @@ -159,14 +270,71 @@ def update_sel( float The minimum distance between two atoms """ - local_jdata_cpy = local_jdata.copy() + local_jdata_cpy = deepcopy(local_jdata) + original_models = deepcopy(local_jdata_cpy["models"]) + has_shared_dict = "shared_dict" in local_jdata_cpy + if has_shared_dict: + shared_config = { + "model_dict": { + f"model_{idx}": sub_model + for idx, sub_model in enumerate(local_jdata_cpy["models"]) + }, + "shared_dict": local_jdata_cpy.get("shared_dict", {}), + } + if "type_map" in local_jdata_cpy: + shared_config["type_map"] = deepcopy(local_jdata_cpy["type_map"]) + shared_config, _ = preprocess_shared_params( + shared_config, + require_shared_type_map=False, + ) + local_jdata_cpy["models"] = list(shared_config["model_dict"].values()) + if "type_map" not in local_jdata_cpy: + local_jdata_cpy["type_map"] = deepcopy( + local_jdata_cpy["models"][0]["type_map"] + ) type_map = local_jdata_cpy["type_map"] min_nbor_dist = None for idx, sub_model in enumerate(local_jdata_cpy["models"]): if "tab_file" not in sub_model: - sub_model, temp_min = DPModelCommon.update_sel( - train_data, type_map, local_jdata["models"][idx] + sub_type_map = sub_model.get("type_map", type_map) + local_jdata_cpy["models"][idx], temp_min = DPModelCommon.update_sel( + train_data, sub_type_map, sub_model ) if min_nbor_dist is None or temp_min <= min_nbor_dist: min_nbor_dist = temp_min - return local_jdata_cpy, min_nbor_dist + if not has_shared_dict: + return local_jdata_cpy, min_nbor_dist + + def get_shared_key(shared_ref: str) -> str: + return shared_ref.split(":", maxsplit=1)[0] + + ret_jdata = deepcopy(local_jdata) + ret_jdata["models"] = original_models + if "type_map" not in ret_jdata: + ret_jdata["type_map"] = deepcopy(type_map) + for idx, original_sub_model in enumerate(original_models): + if "tab_file" in original_sub_model: + continue + updated_sub_model = local_jdata_cpy["models"][idx] + descriptor_ref = original_sub_model.get("descriptor") + if isinstance(descriptor_ref, str): + ret_jdata["shared_dict"][get_shared_key(descriptor_ref)] = ( + updated_sub_model["descriptor"] + ) + elif ( + isinstance(descriptor_ref, dict) + and descriptor_ref.get("type") == "hybrid" + ): + updated_descriptor = updated_sub_model["descriptor"] + for hybrid_idx, hybrid_ref in enumerate(descriptor_ref["list"]): + if isinstance(hybrid_ref, str): + ret_jdata["shared_dict"][get_shared_key(hybrid_ref)] = ( + updated_descriptor["list"][hybrid_idx] + ) + else: + ret_jdata["models"][idx]["descriptor"]["list"][hybrid_idx] = ( + updated_descriptor["list"][hybrid_idx] + ) + else: + ret_jdata["models"][idx]["descriptor"] = updated_sub_model["descriptor"] + return ret_jdata, min_nbor_dist diff --git a/deepmd/pt/utils/multi_task.py b/deepmd/pt/utils/multi_task.py index d99ac704c7..0b1c12ced0 100644 --- a/deepmd/pt/utils/multi_task.py +++ b/deepmd/pt/utils/multi_task.py @@ -36,11 +36,17 @@ def _cascade_top_level_defaults(model_config: dict[str, Any]) -> None: def preprocess_shared_params( model_config: dict[str, Any], + require_shared_type_map: bool = True, ) -> tuple[dict[str, Any], dict[str, Any]]: - """Preprocess the model params for multitask model, and generate the links dict for further sharing. + """Preprocess shared model params and generate links for parameter sharing. Args: - model_config: Model params of multitask model. + model_config: Model params containing ``model_dict`` and optional + ``shared_dict``. + require_shared_type_map: Whether exactly one ``type_map`` entry must be + referenced from ``shared_dict``. Multi-task training keeps this + requirement; linear models may inherit an explicit top-level + ``type_map`` instead. Returns ------- @@ -167,7 +173,10 @@ def replace_one_item( item_params = model_params_item[item_key] if isinstance(item_params, str): replace_one_item(model_params_item, item_key, item_params) - elif item_params.get("type", "") == "hybrid": + elif ( + isinstance(item_params, dict) + and item_params.get("type", "") == "hybrid" + ): for ii, hybrid_item in enumerate(item_params["list"]): if isinstance(hybrid_item, str): replace_one_item( @@ -187,7 +196,10 @@ def replace_one_item( ) # little trick to make spin models in the front to be the base models, # because its type embeddings are more general. - assert len(type_map_keys) == 1, "Multitask model must have only one type_map!" + if require_shared_type_map: + assert len(type_map_keys) == 1, "Multitask model must have only one type_map!" + else: + assert len(type_map_keys) <= 1, "Shared params must have at most one type_map!" return model_config, shared_links diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 19fbe8cebd..c58dfd3a18 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3388,6 +3388,7 @@ def linear_ener_model_args() -> Argument: 'If "mean", the weights are set to be 1 / len(models). ' 'If "sum", the weights are set to be 1.' ) + doc_shared_dict = "The definition of the shared parameters used in the `models` within linear model." models_args = model_args(exclude_hybrid=True) models_args.name = "models" models_args.fold_subdoc = True @@ -3405,6 +3406,9 @@ def linear_ener_model_args() -> Argument: optional=False, doc=doc_weights, ), + Argument( + "shared_dict", dict, optional=True, default={}, doc=doc_shared_dict + ), ], doc=doc_only_tf_supported, ) diff --git a/source/tests/pt/model/test_linear_model_shared_dict.py b/source/tests/pt/model/test_linear_model_shared_dict.py new file mode 100644 index 0000000000..605dff9d8e --- /dev/null +++ b/source/tests/pt/model/test_linear_model_shared_dict.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from unittest.mock import ( + patch, +) + +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + + +class TestLinearEnergySharedDict(unittest.TestCase): + def assert_dpa1_descriptor_shared(self, descriptor0, descriptor1) -> None: + self.assertIs(descriptor1.type_embedding, descriptor0.type_embedding) + self.assertGreater(len(descriptor0.se_atten._modules), 0) + for module_name, module in descriptor0.se_atten._modules.items(): + self.assertIs(descriptor1.se_atten._modules[module_name], module) + + def make_dpa1_descriptor(self, seed: int) -> dict: + return { + "type": "dpa1", + "sel": 4, + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [4, 8, 16], + "axis_neuron": 4, + "seed": seed, + } + + def make_se_e2_a_descriptor(self, sel: str) -> dict: + return { + "type": "se_e2_a", + "rcut": 6.0, + "sel": sel, + } + + def make_fitting_net(self, seed: int) -> dict: + return { + "neuron": [8, 8, 8], + "resnet_dt": True, + "seed": seed, + } + + def test_shared_dict_descriptor_and_type_map(self) -> None: + config = { + "type": "linear_ener", + "shared_dict": { + "type_map_all": ["O", "H"], + "dpa1_descriptor": { + "type": "dpa1", + "rcut": 6.0, + "rcut_smth": 0.5, + "sel": 4, + "neuron": [4, 8, 16], + "axis_neuron": 4, + "seed": 1, + }, + }, + "models": [ + { + "type_map": "type_map_all", + "descriptor": "dpa1_descriptor", + "fitting_net": { + "neuron": [8, 8, 8], + "resnet_dt": True, + "seed": 1, + }, + }, + { + "type_map": "type_map_all", + "descriptor": "dpa1_descriptor", + "fitting_net": { + "neuron": [8, 8, 8], + "resnet_dt": True, + "seed": 2, + }, + }, + ], + "weights": "mean", + } + + model = get_model(config) + + self.assertEqual(model.get_type_map(), ["O", "H"]) + self.assertIsNotNone(model.shared_links) + self.assertIn("dpa1_descriptor", model.shared_links) + self.assertEqual(len(model.atomic_model.models), 2) + descriptor0 = model.atomic_model.models[0].descriptor + descriptor1 = model.atomic_model.models[1].descriptor + self.assert_dpa1_descriptor_shared(descriptor0, descriptor1) + + def test_shared_dict_descriptor_with_top_level_type_map(self) -> None: + config = { + "type": "linear_ener", + "type_map": ["O", "H"], + "shared_dict": { + "dpa1_descriptor": { + "type": "dpa1", + "rcut": 6.0, + "rcut_smth": 0.5, + "sel": 4, + "neuron": [4, 8, 16], + "axis_neuron": 4, + "seed": 1, + }, + }, + "models": [ + { + "descriptor": "dpa1_descriptor", + "fitting_net": { + "neuron": [8, 8, 8], + "resnet_dt": True, + "seed": 1, + }, + }, + { + "descriptor": "dpa1_descriptor", + "fitting_net": { + "neuron": [8, 8, 8], + "resnet_dt": True, + "seed": 2, + }, + }, + ], + "weights": "mean", + } + + model = get_model(config) + + self.assertEqual(model.get_type_map(), ["O", "H"]) + descriptor0 = model.atomic_model.models[0].descriptor + descriptor1 = model.atomic_model.models[1].descriptor + self.assert_dpa1_descriptor_shared(descriptor0, descriptor1) + + def test_shared_dict_hybrid_descriptor_component(self) -> None: + config = { + "type": "linear_ener", + "type_map": ["O", "H"], + "shared_dict": { + "dpa1_descriptor": self.make_dpa1_descriptor(seed=1), + }, + "models": [ + { + "descriptor": { + "type": "hybrid", + "list": [ + "dpa1_descriptor", + self.make_dpa1_descriptor(seed=2), + ], + }, + "fitting_net": self.make_fitting_net(seed=1), + }, + { + "descriptor": { + "type": "hybrid", + "list": [ + "dpa1_descriptor", + self.make_dpa1_descriptor(seed=3), + ], + }, + "fitting_net": self.make_fitting_net(seed=2), + }, + ], + "weights": "mean", + } + + model = get_model(config) + + self.assertIn("dpa1_descriptor", model.shared_links) + descriptor0 = model.atomic_model.models[0].descriptor.descrpt_list[0] + descriptor1 = model.atomic_model.models[1].descriptor.descrpt_list[0] + self.assert_dpa1_descriptor_shared(descriptor0, descriptor1) + + @patch("deepmd.pt.utils.update_sel.UpdateSel.get_nbor_stat") + def test_shared_dict_update_sel_round_trip(self, sel_mock) -> None: + sel_mock.return_value = 0.25, [10, 20] + config = { + "type": "linear_ener", + "shared_dict": { + "type_map_all": ["O", "H"], + "shared_descriptor": { + "type": "se_e2_a", + "rcut": 6.0, + "sel": "auto", + }, + }, + "models": [ + { + "type_map": "type_map_all", + "descriptor": { + "type": "hybrid", + "list": [ + "shared_descriptor", + { + "type": "se_e2_a", + "rcut": 6.0, + "sel": "auto:1.5", + }, + ], + }, + "fitting_net": { + "neuron": [8, 8, 8], + "resnet_dt": True, + "seed": 1, + }, + }, + ], + "weights": "mean", + } + + updated, min_nbor_dist = BaseModel.update_sel(None, None, config) + + self.assertEqual(min_nbor_dist, 0.25) + self.assertEqual(updated["type_map"], ["O", "H"]) + self.assertEqual( + updated["models"][0]["descriptor"]["list"][0], "shared_descriptor" + ) + self.assertEqual(updated["shared_dict"]["shared_descriptor"]["sel"], [12, 24]) + self.assertEqual(updated["models"][0]["descriptor"]["list"][1]["sel"], [16, 32]) + + @patch("deepmd.pt.utils.update_sel.UpdateSel.get_nbor_stat") + def test_shared_dict_update_sel_string_and_inline_descriptors( + self, sel_mock + ) -> None: + sel_mock.return_value = 0.25, [10, 20] + config = { + "type": "linear_ener", + "type_map": ["O", "H"], + "shared_dict": { + "shared_descriptor": self.make_se_e2_a_descriptor(sel="auto"), + }, + "models": [ + { + "descriptor": "shared_descriptor", + "fitting_net": self.make_fitting_net(seed=1), + }, + { + "descriptor": self.make_se_e2_a_descriptor(sel="auto:1.5"), + "fitting_net": self.make_fitting_net(seed=2), + }, + ], + "weights": "mean", + } + + updated, min_nbor_dist = BaseModel.update_sel(None, None, config) + + self.assertEqual(min_nbor_dist, 0.25) + self.assertEqual(updated["models"][0]["descriptor"], "shared_descriptor") + self.assertEqual(updated["shared_dict"]["shared_descriptor"]["sel"], [12, 24]) + self.assertEqual(updated["models"][1]["descriptor"]["sel"], [16, 32]) + + def test_shared_dict_fitting_net(self) -> None: + config = { + "type": "linear_ener", + "type_map": ["O", "H"], + "shared_dict": { + "shared_fit": { + "neuron": [8, 8, 8], + "resnet_dt": True, + "seed": 1, + }, + }, + "models": [ + { + "descriptor": { + "type": "dpa1", + "sel": 4, + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [4, 8, 16], + "axis_neuron": 4, + "seed": 1, + }, + "fitting_net": "shared_fit", + }, + { + "descriptor": { + "type": "dpa1", + "sel": 4, + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [4, 8, 16], + "axis_neuron": 4, + "seed": 2, + }, + "fitting_net": "shared_fit", + }, + ], + "weights": "mean", + } + + model = get_model(config) + + self.assertIsNotNone(model.shared_links) + self.assertIn("shared_fit", model.shared_links) + fitting0 = model.atomic_model.models[0].fitting_net + fitting1 = model.atomic_model.models[1].fitting_net + self.assertGreater(len(fitting0._modules), 0) + for module_name, module in fitting0._modules.items(): + self.assertIs(fitting1._modules[module_name], module) + + def test_shared_dict_requires_sub_model_type_map_without_top_level(self) -> None: + config = { + "type": "linear_ener", + "shared_dict": { + "shared_descriptor": self.make_dpa1_descriptor(seed=1), + }, + "models": [ + { + "type_map": ["O", "H"], + "descriptor": "shared_descriptor", + "fitting_net": self.make_fitting_net(seed=1), + }, + { + "descriptor": "shared_descriptor", + "fitting_net": self.make_fitting_net(seed=2), + }, + ], + "weights": "mean", + } + + with self.assertRaisesRegex( + ValueError, "Linear sub-model 1 must define type_map" + ): + get_model(config) + + def test_shared_dict_rejects_inconsistent_sub_model_type_map(self) -> None: + config = { + "type": "linear_ener", + "shared_dict": { + "shared_descriptor": self.make_dpa1_descriptor(seed=1), + }, + "models": [ + { + "type_map": ["O", "H"], + "descriptor": "shared_descriptor", + "fitting_net": self.make_fitting_net(seed=1), + }, + { + "type_map": ["H", "O"], + "descriptor": "shared_descriptor", + "fitting_net": self.make_fitting_net(seed=2), + }, + ], + "weights": "mean", + } + + with self.assertRaisesRegex( + ValueError, + "Linear sub-model 1 type_map differs from sub-model 0", + ): + get_model(config)