Skip to content

perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks#5457

Open
anyangml wants to merge 13 commits into
deepmodeling:masterfrom
anyangml:fix/compile-multitask
Open

perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks#5457
anyangml wants to merge 13 commits into
deepmodeling:masterfrom
anyangml:fix/compile-multitask

Conversation

@anyangml
Copy link
Copy Markdown
Collaborator

@anyangml anyangml commented May 26, 2026

make dataset embedding and energy bias as input not buffer for compile, this allows multitask training share compiled model thus resolve OOM and NCCL timeout issue. Since the empty_cache and del are removed, no GC complaints.

Regression Test
lcurve

Summary by CodeRabbit

  • New Features

    • Multi-task training groups models by structure and caches/reuses compiled computation graphs, with per-task buffer handling to support shared fitting nets.
  • Bug Fixes

    • Checkpoint loading now skips extraneous per-task buffer entries so only original model parameters are restored.
    • Training aggregation coerces tensor-like loss/metric values to floats for accurate reporting.
  • Tests

    • Added regression test ensuring compiled and eager outputs match per task for shared-fitting, different-descriptor setups.

Review Change Stack

Copilot AI review requested due to automatic review settings May 26, 2026 02:03
@dosubot dosubot Bot added the bug label May 26, 2026
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 26, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR promotes per-task fitting/atomic-model buffers into FX symbolic inputs, groups tasks by a structure key to reuse compiled forward_lower graphs, updates the compiled wrapper to supply task buffers at runtime, filters task-buffer keys when loading .pt checkpoints, and coerces loss/metric tensors to floats for aggregation.

Changes

Task-structure-aware torch.compile optimization

Layer / File(s) Summary
Task buffer detection & structure key
deepmd/pt_expt/train/training.py
Adds buffer-name constants, _detect_task_buffers, and updates _get_model_structure_key to return a tuple capturing descriptor identity and fitting-net child identity.
_trace_and_compile buffer promotion
deepmd/pt_expt/train/training.py
Extends _trace_and_compile to accept task_buffers, temporarily patch fitting-net/atomic-model _buffers so promoted buffers become FX placeholders, include them as extra symbolic inputs to make_fx, and return (compiled_module, task_buf_order).
_CompiledModel runtime ordering & forward
deepmd/pt_expt/train/training.py
_CompiledModel.__init__ accepts task_buf_order; forward() fetches current-task buffer tensors (atomic-model am/ prefixed and fitting-net buffers) and passes them as variadic args into the compiled forward_lower.
Structure-key compilation cache in _compile_model
deepmd/pt_expt/train/training.py
Pre-pass groups tasks by structure_key, detects per-task promoted buffers per group, reuses cached (compiled_lower, task_buf_order) when available or compiles per-structure and caches; constructs _CompiledModel(..., task_buf_order).
Training aggregation float coercion
deepmd/pt_expt/train/training.py
Adds _to_float() and coerces tensor-like loss/metric entries to Python floats during single-task and multi-task train/validation aggregation (excluding l2_ metrics).
Skip task buffer entries during checkpoint cleanup
deepmd/pt_expt/infer/deep_eval.py
DeepEval._load_pt now filters out ._task_ marked keys (in addition to compiled-forward-lower keys) when preparing a checkpoint state_dict for loading.
Regression test: shared fitting, different descriptor
source/tests/pt_expt/test_training.py
Adds TestCompiledSharedFittingDifferentDescriptor to exercise multi-task compilation reuse, structure-key differences, eager→compiled weight sync, and compiled-vs-eager numeric equivalence per task.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5423: Both PRs modify deepmd/pt_expt/infer/deep_eval.py—specifically DeepEval._load_pt's state-dict cleanup/filtered-key loading logic—so the main PR's new "._task_" key omission is directly related.
  • deepmodeling/deepmd-kit#5397: Related multi-task torch.compile/FX compilation work that this PR extends with per-task buffer promotion and structure-key reuse.

Suggested labels

enhancement

Suggested reviewers

  • njzjz
  • iProzd
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 57.14% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main performance optimization: sharing compiled forward_lower graphs across multi-task models with shared fitting layers, which is the primary objective of the changeset.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Comment thread deepmd/pt_expt/train/training.py Fixed
Comment thread deepmd/pt_expt/train/training.py Fixed
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adjusts the pt_expt torch.compile path for multi-task training to reduce redundant compiled graphs (and associated memory/oom issues) by promoting per-task fitting-net buffers to explicit compiled-graph inputs and reusing compiled graphs across tasks when the model structure is shared.

Changes:

  • Promote task-specific fitting-net buffers (bias_atom_e, case_embd) into FX placeholders so one compiled graph can be reused with different per-task buffer values.
  • Add per-structure caching in the compile pipeline to avoid recompiling the same shared structure for each task.
  • Make training-time logging robust by converting tensor scalars to Python floats before formatting/aggregation.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
deepmd/pt_expt/train/training.py Adds task-buffer promotion + compiled-graph reuse caching for multi-task compile; adjusts logging scalar handling.
deepmd/pt_expt/infer/deep_eval.py Updates .pt checkpoint loading to ignore newly introduced _CompiledModel per-task buffer copies.
Comments suppressed due to low confidence (1)

deepmd/pt_expt/train/training.py:1072

  • There are existing pt_expt tests covering multi-task + torch.compile, but the new compiled-graph reuse path should be covered by a test that exercises a config where only some components are shared (e.g., fitting_net shared via shared_dict, descriptor not shared). That case would validate the structure-key logic and prevent accidental graph reuse across non-identical forward_lower graphs.
            descriptor = model.get_descriptor()
            if isinstance(descriptor, DescrptDPA1DP):
                n_attn = descriptor.get_numb_attn_layer()
                if n_attn > 0:
                    log.warning(
                        "Compiling DPA1/se_atten_v2 with %d attention "
                        "layer(s) (task=%s): the compiled forces/grads "
                        "are slightly hardware-sensitive (multi-thread "
                        "reduction order), and may not match the eager "
                        "path bit-for-bit.  Use 'enable_compile: false' "
                        "or 'attn_layer: 0' for fully reproducible runs.",

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread deepmd/pt_expt/train/training.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
deepmd/pt_expt/train/training.py (3)

319-319: ⚡ Quick win

Add explicit strict=True to zip call.

The zip() on line 319 iterates over task_buf_order and task_buf_vals, which are guaranteed to have the same length by construction (lines 290-293). Adding strict=True documents this invariant and provides a runtime assertion if the construction logic ever changes.

-            for name, val in zip(task_buf_order, task_buf_vals):
+            for name, val in zip(task_buf_order, task_buf_vals, strict=True):

As per coding guidelines, run ruff check . before committing to catch linting issues.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` at line 319, The zip over task_buf_order
and task_buf_vals in the loop using "for name, val in zip(task_buf_order,
task_buf_vals):" should assert the equal-length invariant by adding strict=True;
update that zip call to zip(task_buf_order, task_buf_vals, strict=True) so a
runtime error surfaces if lengths diverge, then run ruff check . before
committing to ensure linting passes.

314-334: ⚡ Quick win

Potential issue with buffer restoration logic.

Lines 320 and 334 save and restore buffer entries, but if originals[name] is None (buffer didn't exist), line 334 sets _fitting._buffers[name] = None instead of deleting the entry. This could leave None entries in the buffer registry that weren't present before patching.

Consider using conditional restoration:

for name, orig in originals.items():
    if orig is not None:
        _fitting._buffers[name] = orig
    else:
        _fitting._buffers.pop(name, None)

However, if the buffers are guaranteed to exist (since _get_task_buffers only extracts existing buffers), this may not be an issue in practice.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 314 - 334, The restoration
currently writes None back into _fitting._buffers for entries that did not exist
before, so change the finally-block that iterates originals (the dict populated
from task_buf_order/task_buf_vals) to restore by reassigning when orig is not
None and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the
dictionary named originals and the finally block that resets _fitting._buffers
and replace the unconditional assignment with a conditional restore/remove to
avoid leaving None entries after model.forward_lower returns.

92-108: ⚡ Quick win

Clarify the child name check logic.

Line 103 compares child module names against _TASK_SPECIFIC_BUFFER_NAMES, which contains buffer names ("bias_atom_e", "case_embd"). Child modules from named_children() typically have names like "nets", "layers", etc., not buffer names. This check will almost always be True, making it effectively a no-op.

If the intent is to skip the first child when it's task-specific, the logic may need adjustment. Otherwise, consider removing the check or adding a comment explaining why it's safe to use the first child's id() directly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 92 - 108, The code in
_get_model_structure_key uses named_children() names compared against
_TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names like "bias_atom_e"), but
child module names come from named_children() and won't match buffer names, so
the filter is effectively a no-op; fix by computing the set of task-specific
buffer names from the fitting net (e.g., buffers = {n for n,_ in
fitting.named_buffers()}) and then skip any child whose name appears in that
buffer set (replace the current name check), or if the original intent was to
just take the first non-task-specific child drop the faulty comparison and
simply return id of the first child from fitting.named_children(); update
_get_model_structure_key accordingly and keep reference to fitting,
named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the
code to change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 319: The zip over task_buf_order and task_buf_vals in the loop using "for
name, val in zip(task_buf_order, task_buf_vals):" should assert the equal-length
invariant by adding strict=True; update that zip call to zip(task_buf_order,
task_buf_vals, strict=True) so a runtime error surfaces if lengths diverge, then
run ruff check . before committing to ensure linting passes.
- Around line 314-334: The restoration currently writes None back into
_fitting._buffers for entries that did not exist before, so change the
finally-block that iterates originals (the dict populated from
task_buf_order/task_buf_vals) to restore by reassigning when orig is not None
and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the
dictionary named originals and the finally block that resets _fitting._buffers
and replace the unconditional assignment with a conditional restore/remove to
avoid leaving None entries after model.forward_lower returns.
- Around line 92-108: The code in _get_model_structure_key uses named_children()
names compared against _TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names
like "bias_atom_e"), but child module names come from named_children() and won't
match buffer names, so the filter is effectively a no-op; fix by computing the
set of task-specific buffer names from the fitting net (e.g., buffers = {n for
n,_ in fitting.named_buffers()}) and then skip any child whose name appears in
that buffer set (replace the current name check), or if the original intent was
to just take the first non-task-specific child drop the faulty comparison and
simply return id of the first child from fitting.named_children(); update
_get_model_structure_key accordingly and keep reference to fitting,
named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the
code to change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: bac28a64-1300-4bb0-9d04-cf61734721ce

📥 Commits

Reviewing files that changed from the base of the PR and between f39a081 and 4cee0bf.

📒 Files selected for processing (2)
  • deepmd/pt_expt/infer/deep_eval.py
  • deepmd/pt_expt/train/training.py

@codecov
Copy link
Copy Markdown

codecov Bot commented May 26, 2026

Codecov Report

❌ Patch coverage is 83.75000% with 26 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.34%. Comparing base (016141f) to head (c86fe9d).
⚠️ Report is 2 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/pt_expt/train/training.py 84.07% 25 Missing ⚠️
deepmd/pt_expt/infer/deep_eval.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##           master    #5457    +/-   ##
========================================
  Coverage   81.34%   81.34%            
========================================
  Files         868      868            
  Lines       96358    96504   +146     
  Branches     4233     4235     +2     
========================================
+ Hits        78383    78504   +121     
- Misses      16675    16700    +25     
  Partials     1300     1300            

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 491-500: If self._task_buf_order is set but accessing buffers
fails, don't silently set task_buf_vals = (); instead, in the except
AttributeError branch for original_model.get_fitting_net()/getattr(...) raise a
clear RuntimeError (or ValueError) that mentions the missing fitting net or
buffer names and refers to self._task_buf_order so callers know why
compiled_forward_lower would fail; keep the existing else path that sets
task_buf_vals = () only when _task_buf_order is empty, and ensure the raised
message names the expected buffer attributes and the method
original_model.get_fitting_net to help locate the root cause.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: add25e51-694f-4a58-8e05-9d3e0d5909e9

📥 Commits

Reviewing files that changed from the base of the PR and between 4cee0bf and f3e29fe.

📒 Files selected for processing (1)
  • deepmd/pt_expt/train/training.py

Comment thread deepmd/pt_expt/train/training.py
Comment thread deepmd/pt_expt/train/training.py Fixed
Comment thread deepmd/pt_expt/train/training.py Fixed
@anyangml anyangml requested review from njzjz and wanghan-iapcm May 26, 2026 09:27
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 343-345: The zip call iterating over task_buf_order and
task_buf_vals should be made strict to satisfy Ruff B905: change
zip(task_buf_order, task_buf_vals) to zip(task_buf_order, task_buf_vals,
strict=True) in the block that checks name.startswith(_AM_PREFIX). Also remove
or rename the unused variable model_pred (found as model_pred in this file) to
_model_pred (or delete it) to resolve RUF059 so no unused binding remains.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 62caa6e7-1434-49b6-8f30-bfab6dd88904

📥 Commits

Reviewing files that changed from the base of the PR and between f3e29fe and 9ce8d3e.

📒 Files selected for processing (1)
  • deepmd/pt_expt/train/training.py

Comment thread deepmd/pt_expt/train/training.py
Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is still a correctness issue in the compiled graph cache key for multitask models.

_get_model_structure_key() currently uses only the id of the first non-task-specific child under fitting_net as the structure key. That can incorrectly treat two task models as sharing the same compiled forward_lower graph when they share only the fitting net but have different/non-shared descriptors or other atomic-model state.

forward_lower is not just the fitting net path: it also includes descriptor/atomic-model computation. When the compiled graph traced for task/model 1 is reused for task/model 2, _CompiledModel.forward() only supplies the extra task buffers (bias, case_embd, out_bias, out_std, etc.). It does not pass descriptor or other atomic-model parameters as dynamic inputs. So if the descriptor is not actually shared, or has different parameters/outputs, task/model 2 may run the graph captured from task/model 1, producing incorrect predictions and gradients.

I would suggest making the cache key include the identity/structure of all shared components that participate in forward_lower — at least the descriptor/atomic-model path in addition to fitting — or, more conservatively, only reusing a compiled graph when both descriptor and fitting components are confirmed to be the same shared objects. Otherwise each task should compile separately.

A regression test would be helpful: construct two tasks that share fitting_net but use distinguishable non-shared descriptors, enable compile, and verify both tasks match the uncompiled results/gradients.

Minor follow-up: _CompiledModel.forward() currently swallows AttributeError while collecting task buffers and then passes an empty tuple, which will likely fail later with a less clear compiled-argument mismatch. It would be better to raise a direct RuntimeError with the missing task-buffer name/context.

Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)

@anyangml
Copy link
Copy Markdown
Collaborator Author

I think there is still a correctness issue in the compiled graph cache key for multitask models.

_get_model_structure_key() currently uses only the id of the first non-task-specific child under fitting_net as the structure key. That can incorrectly treat two task models as sharing the same compiled forward_lower graph when they share only the fitting net but have different/non-shared descriptors or other atomic-model state.

forward_lower is not just the fitting net path: it also includes descriptor/atomic-model computation. When the compiled graph traced for task/model 1 is reused for task/model 2, _CompiledModel.forward() only supplies the extra task buffers (bias, case_embd, out_bias, out_std, etc.). It does not pass descriptor or other atomic-model parameters as dynamic inputs. So if the descriptor is not actually shared, or has different parameters/outputs, task/model 2 may run the graph captured from task/model 1, producing incorrect predictions and gradients.

I would suggest making the cache key include the identity/structure of all shared components that participate in forward_lower — at least the descriptor/atomic-model path in addition to fitting — or, more conservatively, only reusing a compiled graph when both descriptor and fitting components are confirmed to be the same shared objects. Otherwise each task should compile separately.

A regression test would be helpful: construct two tasks that share fitting_net but use distinguishable non-shared descriptors, enable compile, and verify both tasks match the uncompiled results/gradients.

Minor follow-up: _CompiledModel.forward() currently swallows AttributeError while collecting task buffers and then passes an empty tuple, which will likely fail later with a less clear compiled-argument mismatch. It would be better to raise a direct RuntimeError with the missing task-buffer name/context.

Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)

Is that a valid use case?

@njzjz-bot
Copy link
Copy Markdown
Contributor

Yes, I think it is a valid use case unless the model configuration or share_params logic explicitly forbids it.

My concern is not limited to the current test setup. The compiled callable is forward_lower, so the traced graph covers the descriptor/atomic-model path before the fitting net as well. If two task models share only fitting_net but keep different descriptors (or any other non-shared state participating in forward_lower), a cache key based only on a fitting-net child can collide. In that case task 2 would reuse the graph captured from task 1, while _CompiledModel.forward() only varies the promoted task buffers; it does not make descriptor parameters/buffers dynamic inputs.

So I would either:

  1. make the reuse condition conservative: reuse only when the whole forward_lower-participating model structure is shared, e.g. descriptor/atomic model and fitting are the same shared objects; or
  2. include the identities of all relevant parameters/buffers/modules in the cache key, excluding only the intentionally promoted per-task buffers.

If partial sharing of fitting without descriptor is impossible by construction, then an assertion/check documenting that invariant would also resolve this concern. Otherwise I think compiling separately is safer than silently reusing a graph across non-equivalent task models.

Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)

Copy link
Copy Markdown
Collaborator

@wanghan-iapcm wanghan-iapcm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code review

Found 2 issues:

  1. Cache key under-specifies what is baked into the compiled graph. _get_model_structure_key returns id() of the fitting net's first non-task-specific child, but only bias_atom_e/case_embd/out_bias/out_std are promoted to FX placeholders — descriptor parameters/buffers (attention weights, type-embedding davg/dstd, exclude-mask, etc.) remain baked-in constants in the traced graph. Two tasks that share fitting_net but have different descriptors (or differ in ntypes / dim_case_embd / sel / rcut) produce the same structure key and silently reuse task 0's compiled graph, yielding wrong predictions and gradients. Same concern previously raised by njzjz-bot and Copilot; unresolved on the current head.

def _get_model_structure_key(model: torch.nn.Module) -> int:
"""Return an id that is identical for all tasks that share a fitting net.
After ``share_params``, the fitting net's child sub-modules are literally
the same Python objects across tasks. The first non-task-specific child's
``id()`` is therefore the same for all shared tasks and unique across
unrelated models.
"""
try:
fitting = model.get_fitting_net()
for name, child in fitting.named_children():
if name not in _TASK_SPECIFIC_BUFFER_NAMES:
return id(child)
except AttributeError:
pass
return id(model)
# ---------------------------------------------------------------------------
# Helper: loss factory (reused from pt)

structure_key = _get_model_structure_key(model)
task_bufs = _get_task_buffers(model)
if structure_key in _compiled_by_structure:
# Shared structure: reuse the already-compiled graph.
compiled_lower, task_buf_order = _compiled_by_structure[structure_key]
log.info(
"Reusing compiled graph for task=%s (shared model structure).",
task_key,
)
else:
inp, _ = self.get_data(is_train=True, task_key=task_key)
coord = inp["coord"].detach()
atype = inp["atype"].detach()
box = inp.get("box")
if box is not None:
box = box.detach()
nframes, nloc = atype.shape[:2]
coord_3d = coord.reshape(nframes, nloc, 3)
box_flat = box.reshape(nframes, 9) if box is not None else None
if box_flat is not None:
coord_norm = normalize_coord(
coord_3d, box_flat.reshape(nframes, 3, 3)
)
else:
coord_norm = coord_3d
ext_coord, ext_atype, mapping = extend_coord_with_ghosts(
coord_norm, atype, box_flat, model.get_rcut()
)
nlist_t = build_neighbor_list(
ext_coord,
ext_atype,
nloc,
model.get_rcut(),
model.get_sel(),
distinguish_types=False,
)
ext_coord = ext_coord.reshape(nframes, -1, 3)
fparam = inp.get("fparam")
aparam = inp.get("aparam")
charge_spin = inp.get("charge_spin")
compiled_lower, task_buf_order = _trace_and_compile(
model,
ext_coord,
ext_atype,
nlist_t,
mapping,
fparam,
aparam,
charge_spin=charge_spin,
task_buffers=task_bufs if task_bufs else None,
compile_opts=compile_opts,
)
_compiled_by_structure[structure_key] = (compiled_lower, task_buf_order)
wrapper_mod.model[task_key] = _CompiledModel(
model, compiled_lower, task_buf_order, task_bufs

Suggested fix: include the descriptor (and any other non-fitting components participating in forward_lower) in the key, e.g. (id(model.get_descriptor()), id(fitting_first_child)), or build the key from tuple(id(p) for _, p in model.named_parameters()) + tuple(id(b) for n, b in model.named_buffers() if n.rsplit(".", 1)[-1] not in _TASK_SPECIFIC_BUFFER_NAMES + _ATOMIC_MODEL_TASK_BUFFER_NAMES). A regression test where two tasks share fitting but have distinguishable descriptors and assert compiled outputs match eager would catch this.

  1. The promoted-buffer set is incomplete. Only bias_atom_e, case_embd, out_bias, out_std are promoted to FX placeholders, but fitting nets carry other per-task statistics buffers that are silently baked in as task-0 constants — notably fparam_avg/fparam_inv_std and aparam_avg/aparam_inv_std (set per task by make_stat_input from each task's data distribution), and the descriptor's EnvMat davg/dstd if descriptor stats also vary per task. For multi-task configs that use fparam/aparam with shared fitting but task-local stats, all tasks would end up running with task 0's normalization, producing incorrect inputs to the fitting MLP.

# Buffer names in the fitting net that differ per task after share_params;
# everything else in the fitting net is the same Python object across tasks.
_TASK_SPECIFIC_BUFFER_NAMES: tuple[str, ...] = ("bias_atom_e", "case_embd")
# Buffer names in atomic_model that are per-task (energy/output statistics).
# These live one level above the fitting net and are not reached by
# fitting-net share_params, so they must also be promoted to FX placeholders.
_ATOMIC_MODEL_TASK_BUFFER_NAMES: tuple[str, ...] = ("out_bias", "out_std")
# Prefix used in task_buf_order keys to distinguish atomic_model buffers
# from fitting-net buffers.
_AM_PREFIX = "am/"

Suggested fix: either (a) expand _TASK_SPECIFIC_BUFFER_NAMES to include fparam_avg, fparam_inv_std, aparam_avg, aparam_inv_std and have _get_task_buffers enumerate them, or (b) auto-detect per-task buffers by diffing buffer identities across tasks in _compile_model rather than maintaining a hardcoded allow-list that will drift as descriptors and fittings evolve.

@wanghan-iapcm wanghan-iapcm changed the title Fix: compile multitask perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks May 27, 2026
Comment thread deepmd/pt_expt/train/training.py Dismissed
Comment thread deepmd/pt_expt/train/training.py Dismissed
Comment thread deepmd/pt_expt/train/training.py Dismissed
Comment thread deepmd/pt_expt/train/training.py Fixed
Comment thread deepmd/pt_expt/train/training.py Fixed
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 1145: The for-loop uses an unused loop variable `sk` causing Ruff B007;
change the iteration to loop over the group values directly by replacing `for
sk, group_keys in _groups.items():` with iterating `_groups.values()` (i.e. `for
group_keys in _groups.values():`) so `sk` is removed and the loop body continues
to use `group_keys` unchanged.

In `@source/tests/pt_expt/test_training.py`:
- Around line 1301-1527: The test
TestCompiledSharedFittingDifferentDescriptor.test_compiled_matches_eager_per_task
is unbounded and needs the repo-standard 60s timeout; add a 60s test timeout
(e.g. annotate the test method or the test class with pytest.mark.timeout(60)
and import pytest) so the compiled/eager training regression cannot hang CI
longer than allowed, ensuring the decorator is applied to
test_compiled_matches_eager_per_task (or the containing
TestCompiledSharedFittingDifferentDescriptor class).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 33c2d228-ada9-4027-b0f6-2a1f979e9ce9

📥 Commits

Reviewing files that changed from the base of the PR and between 9ce8d3e and 8c25a22.

📒 Files selected for processing (2)
  • deepmd/pt_expt/train/training.py
  • source/tests/pt_expt/test_training.py

Comment thread deepmd/pt_expt/train/training.py Outdated
Comment thread source/tests/pt_expt/test_training.py
Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-reviewed the current head (8c25a2289). The main correctness concern from the previous review looks addressed:

  • _get_model_structure_key() now includes the descriptor identity, so shared fitting with different descriptors no longer collides.
  • fitting-net task buffers are auto-detected by buffer identity within each reusable structure group, and out_bias/out_std are promoted from atomic_model.
  • the new regression test covers the shared-fitting / different-descriptor case.
  • required checks are green.

I do not see a remaining correctness blocker from the original cache-key issue. The only things I would still clean up before merge are non-blocking maintenance items already pointed out by CodeRabbit: the unused sk loop variable, adding the timeout mark to the new compile regression test, and making the missing promoted-buffer path in _CompiledModel.forward() fail with a direct RuntimeError rather than falling through to an argument mismatch.

— OpenClaw 2026.5.12 (f066dd2) (model: custom-chat-jinzhezeng-group/gpt-5.5)

Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-reviewed current head 7499321b07e0f3b650a4d18f76bc373bb31a7e61. The previous fixes are mostly good, and the required checks are green, but I think the latest cleanup reintroduces a cache-key under-specification for partially shared descriptors.

_get_model_structure_key() now uses the first descriptor parameter tensor id instead of the descriptor module id. That is safe only if sharing one descriptor parameter implies the whole descriptor path participating in forward_lower is equivalent. However several pt_expt descriptors support partial descriptor sharing, e.g. shared_level == 1 shares only type_embedding for DPA1/DPA2/DPA3/SE_T_TEBD while leaving the main descriptor block (se_atten, repinit/repformers, repflows, se_ttebd, etc.) task-local. In that case the first desc.named_parameters() entry may come from the shared type embedding, so two tasks can get the same (descriptor_id, fitting_child_id) key even though the rest of descriptor computation is different and still baked into the traced compiled graph.

So the key should not collapse on just the first shared descriptor parameter. Safer options:

  1. include identities of all descriptor parameters/buffers/modules that are not promoted to FX placeholders; or
  2. use a conservative descriptor identity such as id(model.get_descriptor()) unless the whole descriptor object/path is known to be shared; or
  3. explicitly encode the descriptor sharing level/components in the cache key.

A regression test for shared fitting_net + descriptor shared at level 1 (shared type embedding only, distinguishable descriptor block) would catch this. The current different-descriptor test covers the non-shared descriptor case, but not this partial-sharing collision.

Non-blocking note: the docstring/comment above _get_model_structure_key() still says distinct descriptors get distinct keys, which is no longer always true after switching to the first parameter id.

Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)

Comment thread deepmd/pt_expt/train/training.py Fixed
@anyangml anyangml requested a review from njzjz-bot June 1, 2026 09:36
Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed the current PR version. The main idea looks good to me: sharing the compiled forward_lower by a structure key, while passing per-task buffers (bias_atom_e/case_embd-like fitting buffers and atomic-model out_bias/out_std) as explicit inputs, should address the OOM/NCCL-timeout motivation without baking task-local statistics into the compiled graph. I also like that the PR added a regression test for shared fitting with different descriptors, and all CI checks are green.

One minor inconsistency I noticed: the comment says partial descriptor sharing “raises an explicit error”, but the implementation currently logs a warning and continues. I think that is acceptable if we intentionally keep this as a non-blocking unsupported-combination warning; otherwise the comment should be adjusted.

Given the passing CI and the regression coverage, I’m okay with merging this.

— OpenClaw unknown (model: custom-chat-jinzhezeng-group/gpt-5.5)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants