perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks#5457
perf(pt_expt): share compiled forward_lower across multi-task shared-fitting tasks#5457anyangml wants to merge 13 commits into
Conversation
for more information, see https://pre-commit.ci
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis PR promotes per-task fitting/atomic-model buffers into FX symbolic inputs, groups tasks by a structure key to reuse compiled forward_lower graphs, updates the compiled wrapper to supply task buffers at runtime, filters task-buffer keys when loading .pt checkpoints, and coerces loss/metric tensors to floats for aggregation. ChangesTask-structure-aware torch.compile optimization
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Pull request overview
This PR adjusts the pt_expt torch.compile path for multi-task training to reduce redundant compiled graphs (and associated memory/oom issues) by promoting per-task fitting-net buffers to explicit compiled-graph inputs and reusing compiled graphs across tasks when the model structure is shared.
Changes:
- Promote task-specific fitting-net buffers (
bias_atom_e,case_embd) into FX placeholders so one compiled graph can be reused with different per-task buffer values. - Add per-structure caching in the compile pipeline to avoid recompiling the same shared structure for each task.
- Make training-time logging robust by converting tensor scalars to Python floats before formatting/aggregation.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
deepmd/pt_expt/train/training.py |
Adds task-buffer promotion + compiled-graph reuse caching for multi-task compile; adjusts logging scalar handling. |
deepmd/pt_expt/infer/deep_eval.py |
Updates .pt checkpoint loading to ignore newly introduced _CompiledModel per-task buffer copies. |
Comments suppressed due to low confidence (1)
deepmd/pt_expt/train/training.py:1072
- There are existing pt_expt tests covering multi-task + torch.compile, but the new compiled-graph reuse path should be covered by a test that exercises a config where only some components are shared (e.g., fitting_net shared via shared_dict, descriptor not shared). That case would validate the structure-key logic and prevent accidental graph reuse across non-identical forward_lower graphs.
descriptor = model.get_descriptor()
if isinstance(descriptor, DescrptDPA1DP):
n_attn = descriptor.get_numb_attn_layer()
if n_attn > 0:
log.warning(
"Compiling DPA1/se_atten_v2 with %d attention "
"layer(s) (task=%s): the compiled forces/grads "
"are slightly hardware-sensitive (multi-thread "
"reduction order), and may not match the eager "
"path bit-for-bit. Use 'enable_compile: false' "
"or 'attn_layer: 0' for fully reproducible runs.",
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
🧹 Nitpick comments (3)
deepmd/pt_expt/train/training.py (3)
319-319: ⚡ Quick winAdd explicit
strict=Trueto zip call.The
zip()on line 319 iterates overtask_buf_orderandtask_buf_vals, which are guaranteed to have the same length by construction (lines 290-293). Addingstrict=Truedocuments this invariant and provides a runtime assertion if the construction logic ever changes.- for name, val in zip(task_buf_order, task_buf_vals): + for name, val in zip(task_buf_order, task_buf_vals, strict=True):As per coding guidelines, run
ruff check .before committing to catch linting issues.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` at line 319, The zip over task_buf_order and task_buf_vals in the loop using "for name, val in zip(task_buf_order, task_buf_vals):" should assert the equal-length invariant by adding strict=True; update that zip call to zip(task_buf_order, task_buf_vals, strict=True) so a runtime error surfaces if lengths diverge, then run ruff check . before committing to ensure linting passes.
314-334: ⚡ Quick winPotential issue with buffer restoration logic.
Lines 320 and 334 save and restore buffer entries, but if
originals[name]isNone(buffer didn't exist), line 334 sets_fitting._buffers[name] = Noneinstead of deleting the entry. This could leaveNoneentries in the buffer registry that weren't present before patching.Consider using conditional restoration:
for name, orig in originals.items(): if orig is not None: _fitting._buffers[name] = orig else: _fitting._buffers.pop(name, None)However, if the buffers are guaranteed to exist (since
_get_task_buffersonly extracts existing buffers), this may not be an issue in practice.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` around lines 314 - 334, The restoration currently writes None back into _fitting._buffers for entries that did not exist before, so change the finally-block that iterates originals (the dict populated from task_buf_order/task_buf_vals) to restore by reassigning when orig is not None and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the dictionary named originals and the finally block that resets _fitting._buffers and replace the unconditional assignment with a conditional restore/remove to avoid leaving None entries after model.forward_lower returns.
92-108: ⚡ Quick winClarify the child name check logic.
Line 103 compares child module names against
_TASK_SPECIFIC_BUFFER_NAMES, which contains buffer names ("bias_atom_e","case_embd"). Child modules fromnamed_children()typically have names like"nets","layers", etc., not buffer names. This check will almost always beTrue, making it effectively a no-op.If the intent is to skip the first child when it's task-specific, the logic may need adjustment. Otherwise, consider removing the check or adding a comment explaining why it's safe to use the first child's
id()directly.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` around lines 92 - 108, The code in _get_model_structure_key uses named_children() names compared against _TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names like "bias_atom_e"), but child module names come from named_children() and won't match buffer names, so the filter is effectively a no-op; fix by computing the set of task-specific buffer names from the fitting net (e.g., buffers = {n for n,_ in fitting.named_buffers()}) and then skip any child whose name appears in that buffer set (replace the current name check), or if the original intent was to just take the first non-task-specific child drop the faulty comparison and simply return id of the first child from fitting.named_children(); update _get_model_structure_key accordingly and keep reference to fitting, named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the code to change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 319: The zip over task_buf_order and task_buf_vals in the loop using "for
name, val in zip(task_buf_order, task_buf_vals):" should assert the equal-length
invariant by adding strict=True; update that zip call to zip(task_buf_order,
task_buf_vals, strict=True) so a runtime error surfaces if lengths diverge, then
run ruff check . before committing to ensure linting passes.
- Around line 314-334: The restoration currently writes None back into
_fitting._buffers for entries that did not exist before, so change the
finally-block that iterates originals (the dict populated from
task_buf_order/task_buf_vals) to restore by reassigning when orig is not None
and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the
dictionary named originals and the finally block that resets _fitting._buffers
and replace the unconditional assignment with a conditional restore/remove to
avoid leaving None entries after model.forward_lower returns.
- Around line 92-108: The code in _get_model_structure_key uses named_children()
names compared against _TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names
like "bias_atom_e"), but child module names come from named_children() and won't
match buffer names, so the filter is effectively a no-op; fix by computing the
set of task-specific buffer names from the fitting net (e.g., buffers = {n for
n,_ in fitting.named_buffers()}) and then skip any child whose name appears in
that buffer set (replace the current name check), or if the original intent was
to just take the first non-task-specific child drop the faulty comparison and
simply return id of the first child from fitting.named_children(); update
_get_model_structure_key accordingly and keep reference to fitting,
named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the
code to change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: bac28a64-1300-4bb0-9d04-cf61734721ce
📒 Files selected for processing (2)
deepmd/pt_expt/infer/deep_eval.pydeepmd/pt_expt/train/training.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5457 +/- ##
========================================
Coverage 81.34% 81.34%
========================================
Files 868 868
Lines 96358 96504 +146
Branches 4233 4235 +2
========================================
+ Hits 78383 78504 +121
- Misses 16675 16700 +25
Partials 1300 1300 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 491-500: If self._task_buf_order is set but accessing buffers
fails, don't silently set task_buf_vals = (); instead, in the except
AttributeError branch for original_model.get_fitting_net()/getattr(...) raise a
clear RuntimeError (or ValueError) that mentions the missing fitting net or
buffer names and refers to self._task_buf_order so callers know why
compiled_forward_lower would fail; keep the existing else path that sets
task_buf_vals = () only when _task_buf_order is empty, and ensure the raised
message names the expected buffer attributes and the method
original_model.get_fitting_net to help locate the root cause.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: add25e51-694f-4a58-8e05-9d3e0d5909e9
📒 Files selected for processing (1)
deepmd/pt_expt/train/training.py
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 343-345: The zip call iterating over task_buf_order and
task_buf_vals should be made strict to satisfy Ruff B905: change
zip(task_buf_order, task_buf_vals) to zip(task_buf_order, task_buf_vals,
strict=True) in the block that checks name.startswith(_AM_PREFIX). Also remove
or rename the unused variable model_pred (found as model_pred in this file) to
_model_pred (or delete it) to resolve RUF059 so no unused binding remains.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 62caa6e7-1434-49b6-8f30-bfab6dd88904
📒 Files selected for processing (1)
deepmd/pt_expt/train/training.py
njzjz-bot
left a comment
There was a problem hiding this comment.
I think there is still a correctness issue in the compiled graph cache key for multitask models.
_get_model_structure_key() currently uses only the id of the first non-task-specific child under fitting_net as the structure key. That can incorrectly treat two task models as sharing the same compiled forward_lower graph when they share only the fitting net but have different/non-shared descriptors or other atomic-model state.
forward_lower is not just the fitting net path: it also includes descriptor/atomic-model computation. When the compiled graph traced for task/model 1 is reused for task/model 2, _CompiledModel.forward() only supplies the extra task buffers (bias, case_embd, out_bias, out_std, etc.). It does not pass descriptor or other atomic-model parameters as dynamic inputs. So if the descriptor is not actually shared, or has different parameters/outputs, task/model 2 may run the graph captured from task/model 1, producing incorrect predictions and gradients.
I would suggest making the cache key include the identity/structure of all shared components that participate in forward_lower — at least the descriptor/atomic-model path in addition to fitting — or, more conservatively, only reusing a compiled graph when both descriptor and fitting components are confirmed to be the same shared objects. Otherwise each task should compile separately.
A regression test would be helpful: construct two tasks that share fitting_net but use distinguishable non-shared descriptors, enable compile, and verify both tasks match the uncompiled results/gradients.
Minor follow-up: _CompiledModel.forward() currently swallows AttributeError while collecting task buffers and then passes an empty tuple, which will likely fail later with a less clear compiled-argument mismatch. It would be better to raise a direct RuntimeError with the missing task-buffer name/context.
Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)
Is that a valid use case? |
|
Yes, I think it is a valid use case unless the model configuration or My concern is not limited to the current test setup. The compiled callable is So I would either:
If partial sharing of fitting without descriptor is impossible by construction, then an assertion/check documenting that invariant would also resolve this concern. Otherwise I think compiling separately is safer than silently reusing a graph across non-equivalent task models. Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5) |
wanghan-iapcm
left a comment
There was a problem hiding this comment.
Code review
Found 2 issues:
- Cache key under-specifies what is baked into the compiled graph.
_get_model_structure_keyreturnsid()of the fitting net's first non-task-specific child, but onlybias_atom_e/case_embd/out_bias/out_stdare promoted to FX placeholders — descriptor parameters/buffers (attention weights, type-embeddingdavg/dstd, exclude-mask, etc.) remain baked-in constants in the traced graph. Two tasks that sharefitting_netbut have different descriptors (or differ inntypes/dim_case_embd/sel/rcut) produce the same structure key and silently reuse task 0's compiled graph, yielding wrong predictions and gradients. Same concern previously raised by njzjz-bot and Copilot; unresolved on the current head.
deepmd-kit/deepmd/pt_expt/train/training.py
Lines 109 to 130 in 9ce8d3e
deepmd-kit/deepmd/pt_expt/train/training.py
Lines 1126 to 1188 in 9ce8d3e
Suggested fix: include the descriptor (and any other non-fitting components participating in forward_lower) in the key, e.g. (id(model.get_descriptor()), id(fitting_first_child)), or build the key from tuple(id(p) for _, p in model.named_parameters()) + tuple(id(b) for n, b in model.named_buffers() if n.rsplit(".", 1)[-1] not in _TASK_SPECIFIC_BUFFER_NAMES + _ATOMIC_MODEL_TASK_BUFFER_NAMES). A regression test where two tasks share fitting but have distinguishable descriptors and assert compiled outputs match eager would catch this.
- The promoted-buffer set is incomplete. Only
bias_atom_e,case_embd,out_bias,out_stdare promoted to FX placeholders, but fitting nets carry other per-task statistics buffers that are silently baked in as task-0 constants — notablyfparam_avg/fparam_inv_stdandaparam_avg/aparam_inv_std(set per task bymake_stat_inputfrom each task's data distribution), and the descriptor'sEnvMatdavg/dstdif descriptor stats also vary per task. For multi-task configs that usefparam/aparamwith shared fitting but task-local stats, all tasks would end up running with task 0's normalization, producing incorrect inputs to the fitting MLP.
deepmd-kit/deepmd/pt_expt/train/training.py
Lines 72 to 86 in 9ce8d3e
Suggested fix: either (a) expand _TASK_SPECIFIC_BUFFER_NAMES to include fparam_avg, fparam_inv_std, aparam_avg, aparam_inv_std and have _get_task_buffers enumerate them, or (b) auto-detect per-task buffers by diffing buffer identities across tasks in _compile_model rather than maintaining a hardcoded allow-list that will drift as descriptors and fittings evolve.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 1145: The for-loop uses an unused loop variable `sk` causing Ruff B007;
change the iteration to loop over the group values directly by replacing `for
sk, group_keys in _groups.items():` with iterating `_groups.values()` (i.e. `for
group_keys in _groups.values():`) so `sk` is removed and the loop body continues
to use `group_keys` unchanged.
In `@source/tests/pt_expt/test_training.py`:
- Around line 1301-1527: The test
TestCompiledSharedFittingDifferentDescriptor.test_compiled_matches_eager_per_task
is unbounded and needs the repo-standard 60s timeout; add a 60s test timeout
(e.g. annotate the test method or the test class with pytest.mark.timeout(60)
and import pytest) so the compiled/eager training regression cannot hang CI
longer than allowed, ensuring the decorator is applied to
test_compiled_matches_eager_per_task (or the containing
TestCompiledSharedFittingDifferentDescriptor class).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 33c2d228-ada9-4027-b0f6-2a1f979e9ce9
📒 Files selected for processing (2)
deepmd/pt_expt/train/training.pysource/tests/pt_expt/test_training.py
njzjz-bot
left a comment
There was a problem hiding this comment.
Re-reviewed the current head (8c25a2289). The main correctness concern from the previous review looks addressed:
_get_model_structure_key()now includes the descriptor identity, so shared fitting with different descriptors no longer collides.- fitting-net task buffers are auto-detected by buffer identity within each reusable structure group, and
out_bias/out_stdare promoted fromatomic_model. - the new regression test covers the shared-fitting / different-descriptor case.
- required checks are green.
I do not see a remaining correctness blocker from the original cache-key issue. The only things I would still clean up before merge are non-blocking maintenance items already pointed out by CodeRabbit: the unused sk loop variable, adding the timeout mark to the new compile regression test, and making the missing promoted-buffer path in _CompiledModel.forward() fail with a direct RuntimeError rather than falling through to an argument mismatch.
— OpenClaw 2026.5.12 (f066dd2) (model: custom-chat-jinzhezeng-group/gpt-5.5)
njzjz-bot
left a comment
There was a problem hiding this comment.
Re-reviewed current head 7499321b07e0f3b650a4d18f76bc373bb31a7e61. The previous fixes are mostly good, and the required checks are green, but I think the latest cleanup reintroduces a cache-key under-specification for partially shared descriptors.
_get_model_structure_key() now uses the first descriptor parameter tensor id instead of the descriptor module id. That is safe only if sharing one descriptor parameter implies the whole descriptor path participating in forward_lower is equivalent. However several pt_expt descriptors support partial descriptor sharing, e.g. shared_level == 1 shares only type_embedding for DPA1/DPA2/DPA3/SE_T_TEBD while leaving the main descriptor block (se_atten, repinit/repformers, repflows, se_ttebd, etc.) task-local. In that case the first desc.named_parameters() entry may come from the shared type embedding, so two tasks can get the same (descriptor_id, fitting_child_id) key even though the rest of descriptor computation is different and still baked into the traced compiled graph.
So the key should not collapse on just the first shared descriptor parameter. Safer options:
- include identities of all descriptor parameters/buffers/modules that are not promoted to FX placeholders; or
- use a conservative descriptor identity such as
id(model.get_descriptor())unless the whole descriptor object/path is known to be shared; or - explicitly encode the descriptor sharing level/components in the cache key.
A regression test for shared fitting_net + descriptor shared at level 1 (shared type embedding only, distinguishable descriptor block) would catch this. The current different-descriptor test covers the non-shared descriptor case, but not this partial-sharing collision.
Non-blocking note: the docstring/comment above _get_model_structure_key() still says distinct descriptors get distinct keys, which is no longer always true after switching to the first parameter id.
Authored by OpenClaw (version: 2026.5.12 f066dd2, model: custom-chat-jinzhezeng-group/gpt-5.5)
njzjz-bot
left a comment
There was a problem hiding this comment.
Reviewed the current PR version. The main idea looks good to me: sharing the compiled forward_lower by a structure key, while passing per-task buffers (bias_atom_e/case_embd-like fitting buffers and atomic-model out_bias/out_std) as explicit inputs, should address the OOM/NCCL-timeout motivation without baking task-local statistics into the compiled graph. I also like that the PR added a regression test for shared fitting with different descriptors, and all CI checks are green.
One minor inconsistency I noticed: the comment says partial descriptor sharing “raises an explicit error”, but the implementation currently logs a warning and continues. I think that is acceptable if we intentionally keep this as a non-blocking unsupported-combination warning; otherwise the comment should be adjusted.
Given the passing CI and the regression coverage, I’m okay with merging this.
— OpenClaw unknown (model: custom-chat-jinzhezeng-group/gpt-5.5)
make dataset embedding and energy bias as input not buffer for compile, this allows multitask training share compiled model thus resolve OOM and NCCL timeout issue. Since the empty_cache and del are removed, no GC complaints.
Regression Test

Summary by CodeRabbit
New Features
Bug Fixes
Tests