Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 50 additions & 6 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from deepmd.pt.model.task.sezm_ener import (
SeZMEnergyFittingNet,
)
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.utils.spin import (
Spin,
)
Expand Down Expand Up @@ -155,19 +158,56 @@ def get_spin_model(model_params: dict) -> SpinModel:
def get_linear_model(model_params: dict) -> LinearEnergyModel:
model_params = copy.deepcopy(model_params)
weights = model_params.get("weights", "mean")
shared_links = None
if "shared_dict" in model_params:
shared_config = {
"model_dict": {
f"model_{idx}": sub_model
for idx, sub_model in enumerate(model_params["models"])
},
"shared_dict": model_params.get("shared_dict", {}),
}
if "type_map" in model_params:
shared_config["type_map"] = copy.deepcopy(model_params["type_map"])
shared_config, shared_links = preprocess_shared_params(
shared_config,
require_shared_type_map=False,
)
model_params["models"] = list(shared_config["model_dict"].values())
if "type_map" not in model_params:
for idx, sub_model_params in enumerate(model_params["models"]):
if "type_map" not in sub_model_params:
raise ValueError(
f"Linear sub-model {idx} must define type_map when "
"linear_ener has no top-level type_map."
)
first_type_map = model_params["models"][0]["type_map"]
for idx, sub_model_params in enumerate(model_params["models"][1:], start=1):
if sub_model_params["type_map"] != first_type_map:
raise ValueError(
f"Linear sub-model {idx} type_map differs from sub-model 0. "
"All type_map values must be identical when linear_ener "
Comment on lines +179 to +189

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new type_map validation raises in two cases — a sub-model missing type_map (L180-182) and a sub-model whose type_map differs from sub-model 0 (L186-189) — but neither raise is covered. The tests only hit the happy path where all sub-models agree. CLAUDE.md: "When adding a new feature or API, provide tests that exercise every reachable code path." Consider two assertRaises cases. Non-blocking.

"has no top-level type_map."
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
model_params["type_map"] = copy.deepcopy(first_type_map)

list_of_models = []
ntypes = len(model_params["type_map"])
for sub_model_params in model_params["models"]:
if "type_map" not in sub_model_params:
sub_model_params["type_map"] = model_params["type_map"]
if "descriptor" in sub_model_params:
# descriptor
sub_model_params["descriptor"]["ntypes"] = ntypes
sub_ntypes = len(sub_model_params["type_map"])
sub_model_params["descriptor"]["ntypes"] = sub_ntypes
descriptor, fitting, _ = _get_standard_model_components(
sub_model_params, ntypes
sub_model_params, sub_ntypes
)
list_of_models.append(
DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])
DPAtomicModel(
descriptor,
fitting,
type_map=copy.deepcopy(sub_model_params["type_map"]),
)
)

else: # must be pairtab
Expand All @@ -179,19 +219,23 @@ def get_linear_model(model_params: dict) -> LinearEnergyModel:
sub_model_params["tab_file"],
sub_model_params["rcut"],
sub_model_params["sel"],
type_map=model_params["type_map"],
type_map=copy.deepcopy(sub_model_params["type_map"]),
)
)

atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])
return LinearEnergyModel(
model = LinearEnergyModel(
models=list_of_models,
type_map=model_params["type_map"],
weights=weights,
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)
model.shared_links = shared_links
if shared_links:
model.share_params(shared_links, resume=True)
return model


def get_zbl_model(model_params: dict) -> DPZBLModel:
Expand Down
176 changes: 172 additions & 4 deletions deepmd/pt/model/model/dp_linear_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from copy import (
deepcopy,
)
from typing import (
Any,
)
Expand All @@ -14,6 +18,9 @@
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand All @@ -25,6 +32,8 @@
make_model,
)

log = logging.getLogger(__name__)

DPLinearModel_ = make_model(LinearEnergyAtomicModel)


Expand All @@ -39,6 +48,108 @@ def __init__(
) -> None:
super().__init__(*args, **kwargs)

def share_params(
self,
shared_links: dict[str, Any],
model_key_prob_map: dict[str, float] | None = None,
data_stat_protect: float = 1e-2,
resume: bool = False,
) -> None:
"""Share parameters between linear sub-models.

``shared_links`` follows the same structure as the multi-task
preprocessor. Linear sub-model keys are named ``model_0``, ``model_1``,
... by ``get_linear_model``.
"""

def get_sub_model(model_key: str): # noqa: ANN202
if not model_key.startswith("model_"):
raise RuntimeError(f"Unknown linear model key {model_key}!")
model_index = int(model_key.removeprefix("model_"))
return self.atomic_model.models[model_index]

def get_descriptor_class(model_key: str, shared_type: str): # noqa: ANN202
sub_model = get_sub_model(model_key)
if shared_type == "descriptor":
return sub_model.descriptor
if "hybrid" in shared_type:
hybrid_index = int(shared_type.split("_")[-1])
return sub_model.descriptor.descrpt_list[hybrid_index]
raise RuntimeError(f"Unknown class_type {shared_type}!")

for shared_item in shared_links:
shared_base = shared_links[shared_item]["links"][0]
class_type_base = shared_base["shared_type"]
model_key_base = shared_base["model_key"]
shared_level_base = int(shared_base["shared_level"])
previous_shared_level = shared_level_base
if "descriptor" in class_type_base:
base_class = get_descriptor_class(model_key_base, class_type_base)
for link_item in shared_links[shared_item]["links"][1:]:
class_type_link = link_item["shared_type"]
model_key_link = link_item["model_key"]
shared_level_link = int(link_item["shared_level"])
if shared_level_link < previous_shared_level:
raise ValueError(
"The shared_links must be sorted by shared_level!"
)
previous_shared_level = shared_level_link
if "descriptor" not in class_type_link:
raise ValueError(
f"Class type mismatched: {class_type_base} vs {class_type_link}!"
)
link_class = get_descriptor_class(model_key_link, class_type_link)
link_class.share_params(
base_class, shared_level_link, resume=resume
)
log.warning(
"Shared params of %s.%s and %s.%s!",
model_key_base,
class_type_base,
model_key_link,
class_type_link,
)
else:
base_model = get_sub_model(model_key_base)
if hasattr(base_model, class_type_base):
base_class = getattr(base_model, class_type_base)
for link_item in shared_links[shared_item]["links"][1:]:
class_type_link = link_item["shared_type"]
model_key_link = link_item["model_key"]
shared_level_link = int(link_item["shared_level"])
if shared_level_link < previous_shared_level:
raise ValueError(
"The shared_links must be sorted by shared_level!"
)
previous_shared_level = shared_level_link
if class_type_base != class_type_link:
raise ValueError(
f"Class type mismatched: {class_type_base} vs {class_type_link}!"
)
link_model = get_sub_model(model_key_link)
link_class = getattr(link_model, class_type_link)
if model_key_prob_map is None:
frac_prob = 1.0
else:
frac_prob = (
model_key_prob_map[model_key_link]
/ model_key_prob_map[model_key_base]
)
link_class.share_params(
base_class,
shared_level_link,
model_prob=frac_prob,
protection=data_stat_protect,
resume=resume,
)
log.warning(
"Shared params of %s.%s and %s.%s!",
model_key_base,
class_type_base,
model_key_link,
class_type_link,
)

def translated_output_def(self) -> dict[str, OutputVariableDef]:
out_def_data = self.model_output_def().get_data()
output_def = {
Expand Down Expand Up @@ -159,14 +270,71 @@ def update_sel(
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy = deepcopy(local_jdata)
original_models = deepcopy(local_jdata_cpy["models"])
has_shared_dict = "shared_dict" in local_jdata_cpy
if has_shared_dict:
shared_config = {
"model_dict": {
f"model_{idx}": sub_model
for idx, sub_model in enumerate(local_jdata_cpy["models"])
},
"shared_dict": local_jdata_cpy.get("shared_dict", {}),
}
if "type_map" in local_jdata_cpy:
shared_config["type_map"] = deepcopy(local_jdata_cpy["type_map"])
shared_config, _ = preprocess_shared_params(
shared_config,
require_shared_type_map=False,
)
local_jdata_cpy["models"] = list(shared_config["model_dict"].values())
if "type_map" not in local_jdata_cpy:
local_jdata_cpy["type_map"] = deepcopy(
local_jdata_cpy["models"][0]["type_map"]
)
type_map = local_jdata_cpy["type_map"]
min_nbor_dist = None
for idx, sub_model in enumerate(local_jdata_cpy["models"]):
if "tab_file" not in sub_model:
sub_model, temp_min = DPModelCommon.update_sel(
train_data, type_map, local_jdata["models"][idx]
sub_type_map = sub_model.get("type_map", type_map)
local_jdata_cpy["models"][idx], temp_min = DPModelCommon.update_sel(
train_data, sub_type_map, sub_model
)
if min_nbor_dist is None or temp_min <= min_nbor_dist:
min_nbor_dist = temp_min
return local_jdata_cpy, min_nbor_dist
if not has_shared_dict:
return local_jdata_cpy, min_nbor_dist

def get_shared_key(shared_ref: str) -> str:
return shared_ref.split(":", maxsplit=1)[0]

ret_jdata = deepcopy(local_jdata)
ret_jdata["models"] = original_models
if "type_map" not in ret_jdata:
ret_jdata["type_map"] = deepcopy(type_map)
for idx, original_sub_model in enumerate(original_models):
if "tab_file" in original_sub_model:
continue
updated_sub_model = local_jdata_cpy["models"][idx]
descriptor_ref = original_sub_model.get("descriptor")
if isinstance(descriptor_ref, str):
ret_jdata["shared_dict"][get_shared_key(descriptor_ref)] = (
updated_sub_model["descriptor"]
)
elif (
isinstance(descriptor_ref, dict)
and descriptor_ref.get("type") == "hybrid"
):
updated_descriptor = updated_sub_model["descriptor"]
for hybrid_idx, hybrid_ref in enumerate(descriptor_ref["list"]):
if isinstance(hybrid_ref, str):
ret_jdata["shared_dict"][get_shared_key(hybrid_ref)] = (
updated_descriptor["list"][hybrid_idx]
)
else:
ret_jdata["models"][idx]["descriptor"]["list"][hybrid_idx] = (
updated_descriptor["list"][hybrid_idx]
)
else:
ret_jdata["models"][idx]["descriptor"] = updated_sub_model["descriptor"]
Comment on lines +320 to +339

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The descriptor write-back has three branches: a string ref (isinstance(descriptor_ref, str), L320), a hybrid dict (L324), and an inline dict (the final else, L339). test_shared_dict_update_sel_round_trip only uses a hybrid descriptor, so only the middle branch runs; the string-ref and inline-dict branches are untested. CLAUDE.md: "UTs should cover all code paths, including both branches of boolean conditions." Non-blocking.

return ret_jdata, min_nbor_dist
20 changes: 16 additions & 4 deletions deepmd/pt/utils/multi_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,17 @@ def _cascade_top_level_defaults(model_config: dict[str, Any]) -> None:

def preprocess_shared_params(
model_config: dict[str, Any],
require_shared_type_map: bool = True,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Preprocess the model params for multitask model, and generate the links dict for further sharing.
"""Preprocess shared model params and generate links for parameter sharing.

Args:
model_config: Model params of multitask model.
model_config: Model params containing ``model_dict`` and optional
``shared_dict``.
require_shared_type_map: Whether exactly one ``type_map`` entry must be
referenced from ``shared_dict``. Multi-task training keeps this
requirement; linear models may inherit an explicit top-level
``type_map`` instead.

Returns
-------
Expand Down Expand Up @@ -167,7 +173,10 @@ def replace_one_item(
item_params = model_params_item[item_key]
if isinstance(item_params, str):
replace_one_item(model_params_item, item_key, item_params)
elif item_params.get("type", "") == "hybrid":
elif (
isinstance(item_params, dict)
and item_params.get("type", "") == "hybrid"
):
for ii, hybrid_item in enumerate(item_params["list"]):
if isinstance(hybrid_item, str):
replace_one_item(
Expand All @@ -187,7 +196,10 @@ def replace_one_item(
)
# little trick to make spin models in the front to be the base models,
# because its type embeddings are more general.
assert len(type_map_keys) == 1, "Multitask model must have only one type_map!"
if require_shared_type_map:
assert len(type_map_keys) == 1, "Multitask model must have only one type_map!"
else:
assert len(type_map_keys) <= 1, "Shared params must have at most one type_map!"
return model_config, shared_links


Expand Down
4 changes: 4 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3388,6 +3388,7 @@ def linear_ener_model_args() -> Argument:
'If "mean", the weights are set to be 1 / len(models). '
'If "sum", the weights are set to be 1.'
)
doc_shared_dict = "The definition of the shared parameters used in the `models` within linear model."
models_args = model_args(exclude_hybrid=True)
models_args.name = "models"
models_args.fold_subdoc = True
Expand All @@ -3405,6 +3406,9 @@ def linear_ener_model_args() -> Argument:
optional=False,
doc=doc_weights,
),
Argument(
"shared_dict", dict, optional=True, default={}, doc=doc_shared_dict
),
],
doc=doc_only_tf_supported,
)
Expand Down
Loading
Loading