feat(pt_expt): support .pt training checkpoints in DeepEval #5423
njzjz merged 9 commits into deepmodeling:master
Conversation
`dp --pt-expt test -m foo.pt` previously rejected `.pt` files (only `.pt2` / `.pte` were supported), and `dp --pt test -m foo.pt` on a pt_expt-trained checkpoint silently loaded random weights because the state-dict layout (dpmodel `.w`/`.b` keys) doesn't match the legacy pt backend's expectations.

- `Backend.detect_backend_by_model` sniffs `.pt` content so files with `.w`/`.b` keys (pt_expt) route to the pt_expt DeepEval and files with `.matrix`/`.bias` keys (pt) keep routing to pt.
- `pt_expt.DeepEval._load_pt` reconstructs the model from `_extra_state["model_params"]`, loads the state-dict via `ModelWrapper`, and exposes an eager `forward_common_lower` runner with the same signature as the AOTI/exported module, so the existing `eval()` path is unchanged. Covers spin-aware and non-spin variants; a multi-task `.pt` selects a head and remaps keys.
- `pt_expt.get_model` gains `get_spin_model` (mirroring dpmodel) so spin checkpoints can be reconstructed from `model_params`.
- Tests cover dispatch sniffing; single-task / multi-task / spin / spin-multi-task `.pt` parity vs. eager forward; fparam / aparam; and `.pt` vs `.pte` cross-format consistency at 1e-10.
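The key-suffix sniffing rule described above can be sketched in a few lines. This is an illustrative sketch, not the actual DeepMD-kit code: the function name and the exact key layouts are assumptions based on the description.

```python
def classify_pt_state_dict(keys):
    """Guess the backend from state-dict key suffixes (illustrative only).

    pt_expt (dpmodel) parameters end in ".w"/".b"; the legacy pt backend
    uses ".matrix"/".bias". Return None when neither (or both) match, so
    the caller can fall back to plain suffix-based dispatch.
    """
    has_expt = any(k.endswith((".w", ".b")) for k in keys)
    has_pt = any(k.endswith((".matrix", ".bias")) for k in keys)
    if has_expt and not has_pt:
        return "pt-expt"
    if has_pt and not has_expt:
        return "pt"
    return None


print(classify_pt_state_dict(["model.Default.descriptor.layer0.w"]))  # pt-expt
print(classify_pt_state_dict(["model.descriptor.layers.0.matrix"]))   # pt
print(classify_pt_state_dict(["something.else"]))                     # None
```

Returning `None` for the ambiguous case (neither or both patterns present) is what keeps the sniffing backwards compatible: suffix dispatch still decides for files the heuristic can't claim.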
📝 Walkthrough
Sequence Diagram(s)

sequenceDiagram
actor User
participant BackendDetection as Backend Detection
participant TorchLoad as torch.load(...)
participant DeepEval as DeepEval Loader
participant ModelBuilder as Model Construction
participant Inference as Inference Engine
User->>BackendDetection: detect_backend_by_model("model.pt")
BackendDetection->>TorchLoad: attempt torch.load(checkpoint, weights_only=True)
TorchLoad-->>BackendDetection: checkpoint dict / error
BackendDetection->>BackendDetection: inspect state_dict key suffixes
BackendDetection-->>User: return backend ("pt" or "pt-expt")
User->>DeepEval: DeepEval(model_file="model.pt", head=?)
DeepEval->>TorchLoad: torch.load(checkpoint, weights_only=True)
TorchLoad-->>DeepEval: checkpoint dict
DeepEval->>DeepEval: select head, remap state_dict keys
DeepEval->>ModelBuilder: get_model(config) (standard or spin)
ModelBuilder-->>DeepEval: model instance
DeepEval->>DeepEval: load weights into model, build exported_module
User->>Inference: exported_module(forward args)
Inference-->>User: energy, forces, virial, (atomic/spin outputs)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🧹 Nitpick comments (5)
deepmd/pt_expt/infer/deep_eval.py (2)
252-260: `str.replace` is unbounded; hoist `prefix` out of the loop. Two minor robustness/readability issues in the head-renaming loop:

- `key.replace(prefix, "model.Default.")` rewrites every occurrence of `prefix` in the key. If a head name happens to appear deeper in a key path (or in any state-dict key derived from a user-supplied identifier), keys silently get double-rewritten. Slice the prefix instead.
- `prefix = f"model.{head}."` is recomputed on every iteration.

♻️ Proposed fix

```diff
  # Restrict state_dict to the chosen head and rename to "Default".
  head_state = {"_extra_state": state_dict["_extra_state"]}
+ prefix = f"model.{head}."
  for key, value in state_dict.items():
-     prefix = f"model.{head}."
      if key.startswith(prefix):
-         head_state[key.replace(prefix, "model.Default.")] = (
+         new_key = "model.Default." + key[len(prefix) :]
+         head_state[new_key] = (
              value.clone() if torch.is_tensor(value) else value
          )
  state_dict = head_state
```
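The replace-vs-slice pitfall flagged above is easy to reproduce in isolation. The key names here are hypothetical, chosen only so the head name recurs deeper in the path:

```python
prefix = "model.head."

# A key in which the head name also appears deeper in the path.
key = "model.head.submodule.model.head.weight"

# Unbounded str.replace rewrites BOTH occurrences:
bad = key.replace(prefix, "model.Default.")
print(bad)   # model.Default.submodule.model.Default.weight

# Slicing off only the leading prefix rewrites just the head segment:
good = "model.Default." + key[len(prefix):]
print(good)  # model.Default.submodule.model.head.weight
```

`key.replace(prefix, "model.Default.", 1)` would also work, since the loop already guarantees the key starts with the prefix.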
223-225: Inconsistent `DEVICE` import; it should come from `deepmd.pt_expt.utils.env`.

This file already imports `DEVICE` from `deepmd.pt_expt.utils.env` (lines 813, 982). Pulling it from `deepmd.pt.utils.env` here is inconsistent and creates an unnecessary dependency from pt_expt on pt. If the two backends ever diverge on device defaults this becomes a subtle bug.

♻️ Proposed fix

```diff
- from deepmd.pt.utils.env import (
+ from deepmd.pt_expt.utils.env import (
      DEVICE,
  )
```

source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (2)
475-495: Prefer `strict=True` in cross-format consistency zips.

`dp.eval(...)` returns a fixed-arity tuple matching the request defs, and the name lists in these consistency loops (here, lines 567-580, and 659-672) hard-code 7 entries to mirror the spin-with-atomic case. With `strict=False`, if `dp.eval` ever changes arity (e.g., a new output is added) the loop silently truncates and consistency for new fields is no longer asserted. `strict=True` would force the tests to be updated.

```diff
  for name, a, b in zip(
      (
          "energy",
          ...
          "mask_mag",
      ),
      out_pt,
      out_pte,
-     strict=False,
+     strict=True,
  ):
```

(Same applies at lines 567-580 and 659-672.)
401-407: `os.rmdir` will fail on residual files; prefer `shutil.rmtree`.

`_make_spin_files` only puts `.pt` and `.pte` into `tmpdir`, so today this works. But if a future change writes any auxiliary file (e.g., a sidecar `.json` from `deserialize_to_file`), `os.rmdir` will raise `OSError` and leak the directory. `shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)` handles both this and any partial-creation cleanup uniformly.

♻️ Proposed fix

```diff
+ import shutil
+
  @classmethod
  def tearDownClass(cls) -> None:
-     for ext in (".pt", ".pte"):
-         path = cls.files[ext]
-         if os.path.exists(path):
-             os.unlink(path)
-     os.rmdir(cls.files["tmpdir"])
+     shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)
```

deepmd/backend/backend.py (1)
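The behavioural difference between the two cleanup calls, demonstrated with a throwaway directory (stdlib only, independent of the test suite):

```python
import os
import shutil
import tempfile

tmpdir = tempfile.mkdtemp()
# Simulate an unexpected sidecar file left behind by a test.
with open(os.path.join(tmpdir, "sidecar.json"), "w") as f:
    f.write("{}")

try:
    os.rmdir(tmpdir)  # only removes EMPTY directories
except OSError as err:
    print("os.rmdir raised:", err)

# Removes the directory and everything inside it; ignore_errors=True
# also swallows races and permission issues during teardown.
shutil.rmtree(tmpdir, ignore_errors=True)
print(os.path.exists(tmpdir))  # False
```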
122-124: Simplify backend lookup.

`Backend.get_backends()` already returns a `dict[str, type[Backend]]`, so the linear scan can be a single dict lookup:

♻️ Proposed refactor

```diff
- for key, backend in Backend.get_backends().items():
-     if key == target_name:
-         return backend
+ backend = Backend.get_backends().get(target_name)
+ if backend is not None:
+     return backend
```
Actionable comments (2):

In `deepmd/backend/backend.py` (lines 108-127): the `torch.load` call in `detect_backend_by_model` uses `weights_only=False`, which can execute pickled code and cause RCE. Change it to `torch.load(filename, map_location="cpu", weights_only=True)` so only tensor weights are loaded when sniffing the backend, keeping the existing try/except fallback to suffix-based detection. The later real model load in `deepmd/pt_expt/infer/deep_eval.py` (which uses `weights_only=False`) is unchanged.

In `source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py` (lines 251-309): the test builds both heads with identical random seeds, so weight selection isn't actually validated. Have `_build_model_and_params` accept a seed parameter (forwarded into `DescrptSeA` and `EnergyFittingNet`), call it twice with distinct seeds when creating `cls.model_a`/`params_a` and `cls.model_b`/`params_b` in `setUpClass`, and add an explicit distinct-outputs assertion (like `test_distinct_heads_produce_distinct_outputs`) verifying the two heads produce different energies/forces for the same input.
📒 Files selected for processing (4)

- deepmd/backend/backend.py
- deepmd/pt_expt/infer/deep_eval.py
- deepmd/pt_expt/model/get_model.py
- source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py
Codecov Report

@@ Coverage Diff @@
##           master    #5423      +/-   ##
==========================================
+ Coverage   82.36%   82.42%   +0.06%
==========================================
  Files         824      824
  Lines       87109    87509     +400
  Branches     4197     4198       +1
==========================================
+ Hits        71743    72126     +383
- Misses      14091    14107      +16
- Partials     1275     1276       +1
Real training-produced `.pt` checkpoints have `model.{head}.original_model.X`
for the trained weights and `model.{head}.compiled_forward_lower.*`
for the compiled-graph constants. Previously `_load_pt` did a strict
`load_state_dict` against a plain `get_model(model_params)` and failed.
Fix: strip the `original_model.` infix and drop all
`compiled_forward_lower.*` keys before loading. Works for both
single-task and multi-task layouts. Tests synthesise the wrapped
layout directly to avoid a real `torch.compile` invocation.
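The key cleanup described in this fix, stripping the `original_model.` infix and dropping `compiled_forward_lower.*` entries, amounts to a plain dict transform. The sketch below is illustrative: the key shapes follow the commit message, while the function name and exact matching rules are assumptions (the real loader is `pt_expt.DeepEval._load_pt`):

```python
def clean_wrapped_state_dict(state_dict, head="Default"):
    """Strip compile-wrapper artifacts from a training checkpoint's keys.

    - drops "model.{head}.compiled_forward_lower.*" (compiled-graph constants)
    - removes the first "original_model." infix from trained-weight keys
    """
    cleaned = {}
    for key, value in state_dict.items():
        if f"model.{head}.compiled_forward_lower." in key:
            continue  # compiled-graph constant, not a trained weight
        cleaned[key.replace("original_model.", "", 1)] = value
    return cleaned


raw = {
    "model.Default.original_model.descriptor.layer0.w": "W",
    "model.Default.compiled_forward_lower.const0": "C",
}
print(clean_wrapped_state_dict(raw))
# {'model.Default.descriptor.layer0.w': 'W'}
```

After this transform, the remaining keys line up with what a plain `get_model(model_params)` expects, which is why the strict `load_state_dict` described above stops failing.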
Actionable comments posted: 1
🧹 Nitpick comments (2)
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (1)
580-585: Use `shutil.rmtree` for the tempdir cleanup.

`os.rmdir` only succeeds if the directory is empty; if any future change adds an extra artifact (e.g. a `.lock` file from torch save, or a partial write on a failing test), `tearDownClass` will raise and mask the actual test failure. Switching to `shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)` makes cleanup robust without changing behavior in the happy path.

deepmd/pt_expt/infer/deep_eval.py (1)
251-261: Constrain head-prefix replacement to the leading occurrence.

`key.replace(prefix, "model.Default.")` rewrites every occurrence of `model.{head}.` in the key, not just the leading one. The loop already gates on `startswith(prefix)`, so this is harmless for current key shapes, but it's a defensive landmine if a head name (e.g. `"head"`) ever appears later in the key (e.g. nested module names). Safer to slice or pin `count=1`:

♻️ Proposed fix

```diff
- head_state = {"_extra_state": state_dict["_extra_state"]}
- for key, value in state_dict.items():
-     prefix = f"model.{head}."
-     if key.startswith(prefix):
-         head_state[key.replace(prefix, "model.Default.")] = (
-             value.clone() if torch.is_tensor(value) else value
-         )
+ prefix = f"model.{head}."
+ head_state = {"_extra_state": state_dict["_extra_state"]}
+ for key, value in state_dict.items():
+     if key.startswith(prefix):
+         new_key = "model.Default." + key[len(prefix):]
+         head_state[new_key] = (
+             value.clone() if torch.is_tensor(value) else value
+         )
```

This also moves `prefix` out of the per-iteration body.
Inline comments:

In `source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py` (line 426): the tuple unpack from `dp.eval` yields unused variables that trigger RUF059. Prefix the unused names with an underscore (e.g. `v` → `_v`, `av` → `_av`, `mm` → `_mm`, `ae` → `_ae`), or assert against them where the test should exercise those values (e.g. in `test_eval_pbc_atomic_matches_reference`, keep or assert on `av` when `atomic=True`). Adjust all four occurrences so ruff no longer reports unused-variable warnings.

📒 Files selected for processing (2)

- deepmd/pt_expt/infer/deep_eval.py
- source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py
The exported `.pte` and eager `.pt` paths produce identical energy / force / virial / atom_energy / force_mag / mask_mag outputs for spin models, but per-atom virial diverges. The reduced virial (which is the sum of per-atom virials including the virtual-atom contribution) still matches, so the divergence is in the per-extended-atom split, not the totals. Pin this as a known limitation; revisit once the export-time spin atom-virial path is reconciled with the eager path.
🧹 Nitpick comments (1)
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (1)
579-585: `os.rmdir` is fragile for the spin scratch dir.

`tearDownClass` only unlinks `.pt`/`.pte` then calls `os.rmdir(tmpdir)`. If `deserialize_to_file` ever drops a sidecar (cache, journal, `.pte`-as-directory layout, etc.) the `os.rmdir` call raises `OSError` and leaks the temp tree. `shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)` is both shorter and robust to layout changes.

♻️ Proposed refactor

```diff
+import shutil
 @@
  @classmethod
  def tearDownClass(cls) -> None:
-     for ext in (".pt", ".pte"):
-         path = cls.files[ext]
-         if os.path.exists(path):
-             os.unlink(path)
-     os.rmdir(cls.files["tmpdir"])
+     shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)
```
📒 Files selected for processing (1)

- source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds first-class support for loading .pt training checkpoints for inference in the pt_expt backend, including backend auto-detection for shared .pt suffixes.
Changes:

- Add `.pt` sniffing in `Backend.detect_backend_by_model` to route `.pt` files to `pt` vs `pt-expt` based on state-dict key naming.
- Implement `.pt` checkpoint loading in `pt_expt.DeepEval` (including multitask head selection, compiled-wrapper key cleanup, and eager runner shims).
- Add a comprehensive pt_expt inference test suite covering routing, spin, multitask, aparam/fparam behavior, and `.pt` ↔ `.pte` consistency.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py | Adds end-to-end tests for .pt dispatch + pt_expt .pt inference correctness across single/multi-task and spin variants. |
| deepmd/pt_expt/model/get_model.py | Adds get_spin_model and updates get_model to construct spin models correctly from config. |
| deepmd/pt_expt/infer/deep_eval.py | Extends pt_expt inference to accept .pt checkpoints and reconstruct eager runners compatible with existing eval paths. |
| deepmd/backend/backend.py | Implements .pt content sniffing to disambiguate backend routing between pt and pt-expt. |
`Backend.detect_backend_by_model` and `pt_expt.DeepEval._load_pt` deserialised `.pt` files with `weights_only=False`, which allows arbitrary code execution from a malicious checkpoint. The training resume path (training.py:712) already uses `weights_only=True`; align the two new sites with that convention. Reported by chatgpt-codex-connector on PR deepmodeling#5423.
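Why `weights_only=False` is dangerous: without that flag, `torch.load` falls back to full Python pickle, and unpickling can execute arbitrary code. The risk is visible with the stdlib `pickle` module alone, no torch required; `Malicious` here is a deliberately hostile toy class, not anything from the PR:

```python
import pickle


class Malicious:
    # __reduce__ tells pickle how to reconstruct the object. A hostile
    # checkpoint can point it at ANY callable (think os.system instead
    # of the harmless print used here).
    def __reduce__(self):
        return (print, ("arbitrary code ran during unpickling!",))


payload = pickle.dumps(Malicious())
# Merely loading the bytes runs the attacker-chosen callable:
pickle.loads(payload)
```

`torch.load(..., weights_only=True)` avoids this by restricting deserialization to tensors and a safelist of plain containers, which is all backend sniffing needs.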
Actionable comments posted: 1
Inline comments:

In `deepmd/backend/backend.py` (lines 118-126): the current dispatch returns the `pt` backend whenever `has_pt_expt` is false, which misroutes when neither or both key patterns are present. Only pick a `target_name` and return a backend when exactly one pattern matches: `has_pt_expt and not has_pt` selects `"pt-expt"`; `has_pt and not has_pt_expt` selects `"pt"`; otherwise fall through so the suffix-based fallback can run.
📒 Files selected for processing (2)

- deepmd/backend/backend.py
- deepmd/pt_expt/infer/deep_eval.py

✅ Files skipped from review due to trivial changes (1)

- deepmd/pt_expt/infer/deep_eval.py
CodeRabbit flagged that `TestPtExptLoadPtMultiTask` built both heads with the same `GLOBAL_SEED`, so `test_select_head_matches_single_task_forward` would still pass if `_load_pt` accidentally loaded the wrong head's weights. Mirror the spin variant: pass distinct seeds (42/7) to `_build_model_and_params` for the two heads, and add `test_distinct_heads_produce_distinct_outputs` as a sanity guard. Also prefix unused unpack vars with `_` to satisfy RUF059.
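The same-seed hazard this commit fixes, in miniature: two "heads" initialised from the same seed are identical, so a head-selection test cannot observe a wrong-head load. Plain `random.Random` stands in here for the real `DescrptSeA`/`EnergyFittingNet` initialisers:

```python
import random


def init_weights(seed, n=4):
    """Toy stand-in for seeded parameter initialisation."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


# Same seed: heads are indistinguishable, so loading the wrong head's
# weights would still pass a parity check.
print(init_weights(42) == init_weights(42))  # True

# Distinct seeds (42/7, mirroring the fix): a wrong-head load now
# produces visibly different outputs and the test can catch it.
print(init_weights(42) == init_weights(7))   # False
```

This is also why the commit adds `test_distinct_heads_produce_distinct_outputs` as a sanity guard: it asserts the precondition that makes the parity test meaningful.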
Actionable comments posted: 1
🧹 Nitpick comments (1)
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (1)
559-564: Use a documented CPU override here.

`torch.set_default_device()` is documented to take a device/string argument, and the PyTorch docs recommend `with torch.device(device):` when the override is temporary. Passing `None` here is undocumented and makes this helper version-sensitive. (docs.pytorch.org)

♻️ Documented temporary-override form

```diff
- prev = torch.get_default_device()
- torch.set_default_device(None)
- try:
+ with torch.device("cpu"):
      deserialize_to_file(pte_path, data)
- finally:
-     torch.set_default_device(prev)
```
Inline comments:

In `source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py` (lines 253-260): the tensors `coord_t`, `atype_t`, and `cell_t` are created with a hardcoded `device=DEVICE`, which can mismatch the actual device of `self.model`. Derive the device from the model instead (e.g. `next(self.model.parameters()).device`) so the eager reference call `self.model.forward(...)` runs on the same device. Apply the same fix to the other multi-head reference blocks that build eager-reference tensors.

📒 Files selected for processing (1)

- source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py
`Backend.detect_backend_by_model` previously hard-coded the `.w`/`.b` vs `.matrix`/`.bias` heuristic and the `"pt-expt"` / `"pt"` target names — backend-specific knowledge leaking into the generic dispatcher. Replace with a generic specificity score: `Backend.match_filename` returns a positive int if the backend claims the file (default = 1 for any matching suffix), and the dispatcher picks the highest. pt_expt overrides `match_filename` to return 2 for `.pt` files whose state-dict uses dpmodel naming, so it out-claims pt's default suffix match for those files. Other backends inherit the default unchanged.
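The specificity-score dispatch described above, reduced to a runnable toy. The class names and the score values (default 1 for a suffix match, 2 for a content claim) mirror the commit message; everything else, including the simplified suffix lists and the `keys` parameter standing in for real state-dict sniffing, is an illustrative assumption:

```python
class Backend:
    suffixes = ()

    @classmethod
    def match_filename(cls, filename, keys=()):
        # Default claim: 1 for any matching suffix, 0 otherwise.
        return 1 if filename.endswith(cls.suffixes) else 0


class PTBackend(Backend):
    suffixes = (".pt", ".pth")


class PTExptBackend(Backend):
    suffixes = (".pt2", ".pte")

    @classmethod
    def match_filename(cls, filename, keys=()):
        # Out-claim pt's default suffix match (score 2 > 1) when a .pt
        # file's state dict uses dpmodel naming (".w"/".b").
        if filename.endswith(".pt") and any(k.endswith((".w", ".b")) for k in keys):
            return 2
        return super().match_filename(filename, keys)


def detect_backend(filename, keys=()):
    """Pick the backend with the highest positive specificity score."""
    backends = (PTBackend, PTExptBackend)
    scores = {b: b.match_filename(filename, keys) for b in backends}
    best = max(scores, key=scores.get)
    return best if scores[best] > 0 else None


print(detect_backend("model.pt", ["descriptor.layer0.w"]).__name__)  # PTExptBackend
print(detect_backend("model.pt", ["descriptor.matrix"]).__name__)    # PTBackend
```

The design point: the generic dispatcher compares only integers, so all backend-specific knowledge (which keys mean what) lives inside each backend's `match_filename` override.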
- Import `DEVICE` from `deepmd.pt_expt.utils.env` instead of the legacy `deepmd.pt.utils.env` so the loader uses the pt_expt device policy.
- Drop the unnecessary `.clone()` when re-keying tensors during multi-task head selection — `load_state_dict` does not mutate the input dict, so cloning every parameter just inflates memory/time on large multi-task checkpoints.
- Replace the cryptic `KeyError` on missing `_extra_state["model_params"]` with an actionable `ValueError` that names the expected structure and points the user at `dp --pt` / `.pte` / `.pt2` alternatives.
- Use `shutil.rmtree(..., ignore_errors=True)` for spin-fixture teardown so unexpected leftover files in the temp dir don't fail tests.
GitHub Advanced Security flagged `except Exception: pass` as an empty except with no explanatory comment (CodeQL "Empty except"). Tighten the try-block to only cover `torch.load`, document why a load failure must silently surrender the backend claim (so the dispatcher falls back to the default suffix match for the legacy pt backend), and replace the `pass` with an explicit `return 0`.
CodeRabbit flagged that the non-spin `.pt` tests build their reference tensors at `device=DEVICE` and then call `self.model.forward(...)`, but `_build_model_and_params` left the model on CPU. On CUDA/MPS runners that mismatch would fail before the assertions ran. Move the model to DEVICE in the helper, mirroring `_make_spin_files`.
Summary
- `dp --pt-expt test -m foo.pt` previously rejected `.pt` files (only `.pt2`/`.pte` were supported); `dp --pt test -m foo.pt` on a pt_expt-trained checkpoint silently loaded random weights because the dpmodel `.w`/`.b` naming doesn't match the legacy pt backend's `.matrix`/`.bias`.
- This PR makes `.pt` training checkpoints first-class for inference under the pt_expt backend.

Changes

- `Backend.detect_backend_by_model` sniffs `.pt` content and routes by parameter naming: `.w`/`.b` → pt-expt, `.matrix`/`.bias` → pt. A bogus `.pt` falls back to suffix dispatch (pt). Backwards compatible with all existing pt-trained `.pt` checkpoints.
- `pt_expt.DeepEval._load_pt` reconstructs the model from `_extra_state["model_params"]`, loads the state-dict via `ModelWrapper`, and exposes an eager `forward_common_lower` runner with the same signature as the AOTI/exported module so the existing `eval()` path is unchanged. Spin-aware (7-arg) and non-spin (6-arg) variants. A multi-task `.pt` selects a head and remaps keys. Populates `metadata` (default_fparam, dim_fparam/aparam, …) so eval helpers behave the same as the `.pt2`/`.pte` path.
- `pt_expt.get_model` gains `get_spin_model` (mirrors dpmodel) so spin checkpoints can be reconstructed from `model_params` (previously it silently returned a non-spin `EnergyModel`).
- The loader dispatches explicitly on `.pt2`/`.pte`/`.pt` and raises an actionable `ValueError` for anything else (was: implicit fallthrough to the `.pte` loader → cryptic torch error).

Tests

`source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py` (21 tests):

- Dispatch: pt_expt-style `.pt` routes to pt-expt; pt-style `.pt` routes to pt; bogus `.pt` falls back to suffix.
- Single-task `.pt`: metadata accessors, `DeepPot(.pt).eval(...)` parity vs direct forward at 1e-10, `.pth` rejection.
- Multi-task `.pt`: head selection parity, missing-head error, no-default-no-head error.
- Spin `.pt`: metadata flags, eager-reference parity, missing-spin-arg error.
- fparam `.pt`: default fparam matches explicit; varying fparam changes output.
- aparam `.pt`: aparam takes effect; missing aparam raises.
- Spin multi-task `.pt`: each head matches its own eager reference; distinct heads produce distinct outputs.
- `.pt` ↔ `.pte` consistency at 1e-10 for vanilla spin (atomic=True), default fparam (atomic=True), and aparam (atomic=True).

Test plan

- `pytest source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py -v`
- `pytest source/tests/pt_expt/infer/ -v` (regression: existing `.pt2`/`.pte` paths)
- `dp --pt-expt train`, then `dp --pt-expt test -m model.ckpt-N.pt` produces identical metrics to `dp --pt-expt test -m frozen.pt2`

Known limitations

- `_load_pt` handles such checkpoints, but a user can't currently produce one via `dp --pt-expt train`; tests construct them synthetically.
- `_load_pt`'s `exported_module` is a Python closure (eager), not a real `torch.nn.Module`. Sufficient for `dp test`, but `eval_descriptor`/`eval_typeebd`/`eval_fitting_last_layer` won't work from a `.pt` (only from `.pt2`/`.pte`).
- `.pt` ↔ `.pte` consistency is not separately asserted (same eager code path as PBC).