-
Notifications
You must be signed in to change notification settings - Fork 623
feat(pt): support shared_dict in linear energy model #5548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
276ad86
9127c12
7a7cd38
cc6fa8c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,8 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| import logging | ||
| from copy import ( | ||
| deepcopy, | ||
| ) | ||
| from typing import ( | ||
| Any, | ||
| ) | ||
|
|
@@ -14,6 +18,9 @@ | |
| from deepmd.pt.model.model.model import ( | ||
| BaseModel, | ||
| ) | ||
| from deepmd.pt.utils.multi_task import ( | ||
| preprocess_shared_params, | ||
| ) | ||
| from deepmd.utils.data_system import ( | ||
| DeepmdDataSystem, | ||
| ) | ||
|
|
@@ -25,6 +32,8 @@ | |
| make_model, | ||
| ) | ||
|
|
||
| log = logging.getLogger(__name__) | ||
|
|
||
| DPLinearModel_ = make_model(LinearEnergyAtomicModel) | ||
|
|
||
|
|
||
|
|
@@ -39,6 +48,108 @@ def __init__( | |
| ) -> None: | ||
| super().__init__(*args, **kwargs) | ||
|
|
||
| def share_params( | ||
| self, | ||
| shared_links: dict[str, Any], | ||
| model_key_prob_map: dict[str, float] | None = None, | ||
| data_stat_protect: float = 1e-2, | ||
| resume: bool = False, | ||
| ) -> None: | ||
| """Share parameters between linear sub-models. | ||
|
|
||
| ``shared_links`` follows the same structure as the multi-task | ||
| preprocessor. Linear sub-model keys are named ``model_0``, ``model_1``, | ||
| ... by ``get_linear_model``. | ||
| """ | ||
|
|
||
| def get_sub_model(model_key: str): # noqa: ANN202 | ||
| if not model_key.startswith("model_"): | ||
| raise RuntimeError(f"Unknown linear model key {model_key}!") | ||
| model_index = int(model_key.removeprefix("model_")) | ||
| return self.atomic_model.models[model_index] | ||
|
|
||
| def get_descriptor_class(model_key: str, shared_type: str): # noqa: ANN202 | ||
| sub_model = get_sub_model(model_key) | ||
| if shared_type == "descriptor": | ||
| return sub_model.descriptor | ||
| if "hybrid" in shared_type: | ||
| hybrid_index = int(shared_type.split("_")[-1]) | ||
| return sub_model.descriptor.descrpt_list[hybrid_index] | ||
| raise RuntimeError(f"Unknown class_type {shared_type}!") | ||
|
|
||
| for shared_item in shared_links: | ||
| shared_base = shared_links[shared_item]["links"][0] | ||
| class_type_base = shared_base["shared_type"] | ||
| model_key_base = shared_base["model_key"] | ||
| shared_level_base = int(shared_base["shared_level"]) | ||
| previous_shared_level = shared_level_base | ||
| if "descriptor" in class_type_base: | ||
| base_class = get_descriptor_class(model_key_base, class_type_base) | ||
| for link_item in shared_links[shared_item]["links"][1:]: | ||
| class_type_link = link_item["shared_type"] | ||
| model_key_link = link_item["model_key"] | ||
| shared_level_link = int(link_item["shared_level"]) | ||
| if shared_level_link < previous_shared_level: | ||
| raise ValueError( | ||
| "The shared_links must be sorted by shared_level!" | ||
| ) | ||
| previous_shared_level = shared_level_link | ||
| if "descriptor" not in class_type_link: | ||
| raise ValueError( | ||
| f"Class type mismatched: {class_type_base} vs {class_type_link}!" | ||
| ) | ||
| link_class = get_descriptor_class(model_key_link, class_type_link) | ||
| link_class.share_params( | ||
| base_class, shared_level_link, resume=resume | ||
| ) | ||
| log.warning( | ||
| "Shared params of %s.%s and %s.%s!", | ||
| model_key_base, | ||
| class_type_base, | ||
| model_key_link, | ||
| class_type_link, | ||
| ) | ||
| else: | ||
| base_model = get_sub_model(model_key_base) | ||
| if hasattr(base_model, class_type_base): | ||
| base_class = getattr(base_model, class_type_base) | ||
| for link_item in shared_links[shared_item]["links"][1:]: | ||
| class_type_link = link_item["shared_type"] | ||
| model_key_link = link_item["model_key"] | ||
| shared_level_link = int(link_item["shared_level"]) | ||
| if shared_level_link < previous_shared_level: | ||
| raise ValueError( | ||
| "The shared_links must be sorted by shared_level!" | ||
| ) | ||
| previous_shared_level = shared_level_link | ||
| if class_type_base != class_type_link: | ||
| raise ValueError( | ||
| f"Class type mismatched: {class_type_base} vs {class_type_link}!" | ||
| ) | ||
| link_model = get_sub_model(model_key_link) | ||
| link_class = getattr(link_model, class_type_link) | ||
| if model_key_prob_map is None: | ||
| frac_prob = 1.0 | ||
| else: | ||
| frac_prob = ( | ||
| model_key_prob_map[model_key_link] | ||
| / model_key_prob_map[model_key_base] | ||
| ) | ||
| link_class.share_params( | ||
| base_class, | ||
| shared_level_link, | ||
| model_prob=frac_prob, | ||
| protection=data_stat_protect, | ||
| resume=resume, | ||
| ) | ||
| log.warning( | ||
| "Shared params of %s.%s and %s.%s!", | ||
| model_key_base, | ||
| class_type_base, | ||
| model_key_link, | ||
| class_type_link, | ||
| ) | ||
|
|
||
| def translated_output_def(self) -> dict[str, OutputVariableDef]: | ||
| out_def_data = self.model_output_def().get_data() | ||
| output_def = { | ||
|
|
@@ -159,14 +270,71 @@ def update_sel( | |
| float | ||
| The minimum distance between two atoms | ||
| """ | ||
| local_jdata_cpy = local_jdata.copy() | ||
| local_jdata_cpy = deepcopy(local_jdata) | ||
| original_models = deepcopy(local_jdata_cpy["models"]) | ||
| has_shared_dict = "shared_dict" in local_jdata_cpy | ||
| if has_shared_dict: | ||
| shared_config = { | ||
| "model_dict": { | ||
| f"model_{idx}": sub_model | ||
| for idx, sub_model in enumerate(local_jdata_cpy["models"]) | ||
| }, | ||
| "shared_dict": local_jdata_cpy.get("shared_dict", {}), | ||
| } | ||
| if "type_map" in local_jdata_cpy: | ||
| shared_config["type_map"] = deepcopy(local_jdata_cpy["type_map"]) | ||
| shared_config, _ = preprocess_shared_params( | ||
| shared_config, | ||
| require_shared_type_map=False, | ||
| ) | ||
| local_jdata_cpy["models"] = list(shared_config["model_dict"].values()) | ||
| if "type_map" not in local_jdata_cpy: | ||
| local_jdata_cpy["type_map"] = deepcopy( | ||
| local_jdata_cpy["models"][0]["type_map"] | ||
| ) | ||
| type_map = local_jdata_cpy["type_map"] | ||
| min_nbor_dist = None | ||
| for idx, sub_model in enumerate(local_jdata_cpy["models"]): | ||
| if "tab_file" not in sub_model: | ||
| sub_model, temp_min = DPModelCommon.update_sel( | ||
| train_data, type_map, local_jdata["models"][idx] | ||
| sub_type_map = sub_model.get("type_map", type_map) | ||
| local_jdata_cpy["models"][idx], temp_min = DPModelCommon.update_sel( | ||
| train_data, sub_type_map, sub_model | ||
| ) | ||
| if min_nbor_dist is None or temp_min <= min_nbor_dist: | ||
| min_nbor_dist = temp_min | ||
| return local_jdata_cpy, min_nbor_dist | ||
| if not has_shared_dict: | ||
| return local_jdata_cpy, min_nbor_dist | ||
|
|
||
| def get_shared_key(shared_ref: str) -> str: | ||
| return shared_ref.split(":", maxsplit=1)[0] | ||
|
|
||
| ret_jdata = deepcopy(local_jdata) | ||
| ret_jdata["models"] = original_models | ||
| if "type_map" not in ret_jdata: | ||
| ret_jdata["type_map"] = deepcopy(type_map) | ||
| for idx, original_sub_model in enumerate(original_models): | ||
| if "tab_file" in original_sub_model: | ||
| continue | ||
| updated_sub_model = local_jdata_cpy["models"][idx] | ||
| descriptor_ref = original_sub_model.get("descriptor") | ||
| if isinstance(descriptor_ref, str): | ||
| ret_jdata["shared_dict"][get_shared_key(descriptor_ref)] = ( | ||
| updated_sub_model["descriptor"] | ||
| ) | ||
| elif ( | ||
| isinstance(descriptor_ref, dict) | ||
| and descriptor_ref.get("type") == "hybrid" | ||
| ): | ||
| updated_descriptor = updated_sub_model["descriptor"] | ||
| for hybrid_idx, hybrid_ref in enumerate(descriptor_ref["list"]): | ||
| if isinstance(hybrid_ref, str): | ||
| ret_jdata["shared_dict"][get_shared_key(hybrid_ref)] = ( | ||
| updated_descriptor["list"][hybrid_idx] | ||
| ) | ||
| else: | ||
| ret_jdata["models"][idx]["descriptor"]["list"][hybrid_idx] = ( | ||
| updated_descriptor["list"][hybrid_idx] | ||
| ) | ||
| else: | ||
| ret_jdata["models"][idx]["descriptor"] = updated_sub_model["descriptor"] | ||
|
Comment on lines
+320
to
+339
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The descriptor write-back has three branches: a string ref ( |
||
| return ret_jdata, min_nbor_dist | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new
type_mapvalidation raises in two cases — a sub-model missingtype_map(L180-182) and a sub-model whosetype_mapdiffers from sub-model 0 (L186-189) — but neither raise is covered. The tests only hit the happy path where all sub-models agree. CLAUDE.md: "When adding a new feature or API, provide tests that exercise every reachable code path." Consider twoassertRaisescases. Non-blocking.