
feat(pt): add ema shadow model #5420

Open

OutisLi wants to merge 5 commits into deepmodeling:master from OutisLi:pr/ema

Conversation

@OutisLi
Collaborator

@OutisLi OutisLi commented Apr 25, 2026

Summary by CodeRabbit

  • New Features

    • Added Exponential Moving Average (EMA) support: maintain EMA shadows, emit EMA-only checkpoints, resume/restore EMA state, and evaluate using EMA weights for dedicated validation and best-model selection.
  • Configuration

    • New options to enable EMA, set decay and checkpoint retention, and opt into EMA full-validation.
  • Validation

    • Full-validation made configurable (log path, prefixes, state keys, eval context) and supports separate EMA evaluation/logs.
  • Tests

    • Added EMA-focused tests for checkpointing, resume/restore, config validation, and full-validation artifact separation.
  • Chores

    • Checkpoint saving/retention refactored for deterministic pruning and consistent EMA checkpoint handling.

Copilot AI review requested due to automatic review settings April 25, 2026 02:25
@dosubot dosubot Bot added the new feature label Apr 25, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR adds an exponential moving average (EMA) “shadow” model to the PyTorch training stack, including EMA checkpointing, restart support, and optional EMA-based full validation outputs.

Changes:

  • Introduce ModelEMA with serialization, shadow-weight application, and filename helpers for EMA checkpoints/logs.
  • Extend the PyTorch trainer to maintain/update EMA state, save periodic EMA checkpoints, and optionally run full validation on EMA weights with separate logs/best checkpoints.
  • Update argcheck schema/validation and add unit tests covering EMA rotation, restart restore, and EMA full-validation outputs.
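The EMA mechanics described above follow a common shadow-weight pattern; here is a minimal, self-contained sketch (illustrative names only; the actual ModelEMA in deepmd/pt/train/ema.py may differ in detail):

```python
from contextlib import contextmanager

import torch


class EMASketch:
    """Shadow-weight EMA: shadow <- decay * shadow + (1 - decay) * param."""

    def __init__(self, model: torch.nn.Module, decay: float) -> None:
        self.decay = decay
        # Track only floating-point parameters, detached from the graph.
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if torch.is_floating_point(param)
        }

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        for name, param in model.named_parameters():
            if name in self.shadow:
                # In-place linear interpolation toward the live weights.
                self.shadow[name].lerp_(param.detach(), 1.0 - self.decay)

    @contextmanager
    def apply_shadow(self, model: torch.nn.Module):
        """Temporarily swap shadow weights into the model, then restore."""
        backup = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if name in self.shadow
        }
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in self.shadow:
                    param.copy_(self.shadow[name])
        try:
            yield model
        finally:
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in backup:
                        param.copy_(backup[name])
```

With decay=0.9, one update moves each shadow weight 10% of the way toward the live weight; apply_shadow lets validation run on the averaged weights without touching the optimizer state.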

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

Summary per file

  • deepmd/pt/train/ema.py: New EMA implementation (shadow params, state dict, apply-shadow context).
  • deepmd/pt/train/training.py: Trainer integration: EMA lifecycle, EMA validation, checkpoint format/rotation refactor.
  • deepmd/pt/train/validation.py: Generalize FullValidator to support alternate state stores, prefixes, log paths, and eval contexts (used by EMA validation).
  • deepmd/utils/argcheck.py: Add config knobs (enable_ema, ema_decay, ema_ckpt_keep, ema_full_validation) and cross-section validation.
  • source/tests/pt/test_training.py: Add EMA-focused tests (rotation, restart persistence, EMA full-validation output separation).
  • source/tests/pt/test_validation.py: Update tests to match the FullValidator arg rename (train_infos → state_store).


@coderabbitai
Contributor

coderabbitai Bot commented Apr 25, 2026

📝 Walkthrough

Adds Exponential Moving Average (EMA) support: new ModelEMA utility, training integration for per-step EMA updates and EMA-aware checkpointing/validation, config/schema additions for EMA, and tests covering EMA behavior, restart, and artifact separation.

Changes

EMA Core Module (deepmd/pt/train/ema.py)
New module introducing ModelEMA, helpers get_ema_checkpoint_prefix and get_ema_validation_log_path, deterministic float-parameter collection, in-place EMA updates, strict state_dict/load_state_dict semantics, and an apply_shadow context manager.

Training Loop Integration (deepmd/pt/train/training.py)
Trainer now manages the EMA lifecycle (init, per-step update, resume/load), embeds EMA state in checkpoints conditionally, adds save_ema_model, extends the save_model API with EMA flags, centralizes checkpoint writing/pruning, and wires the EMA full-validation flow.

Validation Framework (deepmd/pt/train/validation.py)
FullValidator generalized to accept state_store, configurable keys/prefixes, an optional full_val_file, model_eval_context, dynamic best-checkpoint glob/pattern builders, and a suppression option for best-save logging.

Configuration Validation (deepmd/utils/argcheck.py)
Adds training.enable_ema, training.ema_decay, training.ema_ckpt_keep, and validating.ema_full_validation; validates ranges and cross-field constraints (e.g., disallow EMA when zero_stage >= 2).

Tests: Training (source/tests/pt/test_training.py)
New EMA-focused tests covering checkpoint rotation/cleanup, EMA persistence across restart (shadow params + validation state), bias-update propagation into EMA checkpoints, and EMA full-validation artifact separation and logs.

Tests: Validation (source/tests/pt/test_validation.py)
Updated call sites to construct FullValidator with state_store instead of train_infos (no behavior change).

Sequence Diagram

sequenceDiagram
    participant Trainer
    participant Model
    participant ModelEMA
    participant Checkpointer
    participant FullValidator

    Trainer->>ModelEMA: __init__(model, decay)
    Trainer->>Trainer: start training loop

    loop Per training step
        Trainer->>Model: forward/backward
        Trainer->>ModelEMA: update(model)
        ModelEMA->>ModelEMA: lerp_ shadow params (in-place)
    end

    alt Save regular + EMA checkpoints
        Trainer->>Checkpointer: save_model(use_ema_weights=False,...)
        Trainer->>ModelEMA: state_dict()
        Trainer->>Checkpointer: save_ema_model(...)
    end

    alt EMA full-validation
        Trainer->>ModelEMA: apply_shadow(model)
        ModelEMA->>Model: copy shadow params (device/dtype matched)
        Trainer->>FullValidator: evaluate_all_systems(model)
        FullValidator->>FullValidator: compute metrics, persist to state_store
        ModelEMA->>Model: restore original params
        Trainer->>Checkpointer: save EMA best checkpoint
    end

    alt Resume from checkpoint
        Trainer->>Checkpointer: load checkpoint
        Trainer->>ModelEMA: load_state_dict(ema_state)
        Trainer->>Trainer: continue training
    end
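The `lerp_` step in the loop above is the standard EMA recurrence shadow <- decay * shadow + (1 - decay) * param; a torch-free sketch of the arithmetic with illustrative numbers:

```python
def ema_step(shadow: float, param: float, decay: float) -> float:
    """One EMA update; equivalent to torch's shadow.lerp_(param, 1 - decay)."""
    return decay * shadow + (1.0 - decay) * param


shadow = 0.0
for _ in range(3):
    shadow = ema_step(shadow, param=1.0, decay=0.9)

# Toward a constant param p, n steps give shadow = p * (1 - decay**n).
print(round(shadow, 6))  # → 0.271
```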

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 56.14%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title 'feat(pt): add ema shadow model' accurately and concisely describes the primary change, which introduces a new EMA (Exponential Moving Average) utility module with the ModelEMA class and supporting functionality.
  • Linked Issues Check ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check ✅ Passed: check skipped because no linked issues were found for this pull request.


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)

188-205: ⚠️ Potential issue | 🟡 Minor

Validate ema_decay range.

A decay outside [0, 1] (e.g., negative or > 1) would cause unstable or unbounded shadow updates via lerp_. Consider asserting bounds here (or in argcheck) to fail fast with a clear message.

🛡️ Proposed defensive check
         self.ema_decay = float(training_params.get("ema_decay", 0.999))
+        if self.enable_ema and not (0.0 <= self.ema_decay <= 1.0):
+            raise ValueError(
+                f"training.ema_decay must lie in [0, 1], got {self.ema_decay}."
+            )
         self.ema_ckpt_keep = int(training_params.get("ema_ckpt_keep", 3))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/train/training.py` around lines 188 - 205, The ema_decay value from
training_params is not validated for range; add a defensive check after
computing self.ema_decay (e.g., immediately after self.ema_decay =
float(training_params.get("ema_decay", 0.999))) that asserts 0.0 <=
self.ema_decay <= 1.0 and raises a ValueError with a clear message (mentioning
training.ema_decay or self.ema_decay) if out of bounds so invalid decay values
are rejected early; keep this check alongside the existing zero_stage/enable_ema
validations.
🧹 Nitpick comments (4)
deepmd/pt/train/ema.py (2)

56-66: Type hint inconsistent with supported inputs.

_named_model_parameters and apply_shadow/update all accept torch.nn.Module | dict[str, torch.nn.Module], and Trainer.__init__ passes self.model (which is a dict in multi-task mode) into ModelEMA(...). The constructor's type hint should match to avoid misleading IDE/type checks.

♻️ Proposed type hint update
     def __init__(
         self,
-        model: torch.nn.Module,
+        model: torch.nn.Module | dict[str, torch.nn.Module],
         decay: float,
         state: dict[str, Any] | None = None,
     ) -> None:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/train/ema.py` around lines 56 - 66, The constructor signature for
ModelEMA.__init__ currently types the model parameter as torch.nn.Module but the
class methods (_named_model_parameters, apply_shadow, update) accept
torch.nn.Module | dict[str, torch.nn.Module] and Trainer may pass a dict in
multi-task mode; update the type hint of the model parameter in __init__ to
accept torch.nn.Module | dict[str, torch.nn.Module] (and adjust any related
local annotations like shadow_params if necessary) so IDE/type-checkers reflect
supported inputs and avoid false positives when Trainer passes a dict.

68-99: Optionally skip frozen (requires_grad=False) parameters.

Frozen parameters don't update during training, so keeping EMA shadows of them only consumes memory without benefit. Consider filtering them out unless there's a specific reason to track them (e.g., BN stats are buffers, not parameters, so they aren't affected either way).

♻️ Proposed filter
         return [
             (name, param)
             for name, param in model.named_parameters()
-            if torch.is_floating_point(param)
+            if torch.is_floating_point(param) and param.requires_grad
         ]

(and the symmetric change in the dict branch)

Note: if you intend to support parameters that toggle requires_grad during training, keep the current behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/train/ema.py` around lines 68 - 99, The current
_named_model_parameters returns all floating-point parameters, including frozen
ones; update it to optionally skip parameters with requires_grad==False by
adding a boolean flag on the EMA class (e.g., track_frozen or
track_requires_grad) and filter out params where not param.requires_grad in both
the dict branch and the single-module branch; ensure _clone_model_parameters
continues to call _named_model_parameters so the cloning respects the new flag
and document the flag’s default (True to preserve current behavior or False to
save memory) and its use in the EMA constructor.
deepmd/pt/train/training.py (1)

1807-1870: Nit: move include_optimizer=False branch out of zero_stage==1 consolidate path.

Minor readability observation — under zero_stage == 1, consolidate_state_dict(to=0) is a collective call; gating it behind include_optimizer correctly skips it for EMA checkpoints, so rank>0 calls don't participate in a consolidation they'd otherwise need to attend. Worth a brief comment noting that this is intentional (EMA path is non-collective by design under ZeRO-1), since at a glance it may look like a missing collective.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/train/training.py` around lines 1807 - 1870, The
include_optimizer=False branch should be handled outside the zero_stage == 1
consolidate path and documented: in _collect_checkpoint_states, ensure you only
call self.optimizer.consolidate_state_dict(to=0) and build optim_state for
ZeRO-1 when include_optimizer is True (so non-optimizer/EMA checkpoints skip the
collective), and move the include_optimizer=False branch out of the consolidate
block; add a short comment near zero_stage == 1 explaining that skipping
consolidation for EMA checkpoints is intentional because consolidate_state_dict
is a collective call and must not be invoked when include_optimizer is False.
deepmd/pt/train/validation.py (1)

227-227: Use explicit is None check for model_eval_context.

model_eval_context or nullcontext will silently substitute nullcontext for any falsy callable, which is typically not intended for callable parameters. Prefer an explicit identity check:

♻️ Proposed change
-        self.model_eval_context = model_eval_context or nullcontext
+        self.model_eval_context = (
+            model_eval_context if model_eval_context is not None else nullcontext
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/train/validation.py` at line 227, The current assignment uses a
truthiness check which replaces any falsy callable; change the assignment of
self.model_eval_context so it only substitutes nullcontext when
model_eval_context is exactly None (e.g., use an explicit identity check or a
conditional expression based on "is None") rather than using "or", so that valid
but falsy callables are preserved; update the assignment where
self.model_eval_context is set and keep the nullcontext symbol as the fallback.
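The failure mode behind this comment is easy to reproduce: `or` discards any falsy value, including a perfectly valid callable whose `__bool__` returns False (e.g. a test double), whereas an `is None` check preserves it. A stdlib-only illustration:

```python
from contextlib import nullcontext


class FalsyFactory:
    """A valid context-manager factory that happens to be falsy."""

    def __bool__(self) -> bool:
        return False

    def __call__(self):
        return nullcontext()


ctx = FalsyFactory()

picked_by_or = ctx or nullcontext  # silently swaps in the fallback
picked_by_is = ctx if ctx is not None else nullcontext

print(picked_by_or is nullcontext)  # → True: the valid callable was dropped
print(picked_by_is is ctx)          # → True: identity check keeps it
```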
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/pt/train/ema.py`:
- Around line 118-134: load_state_dict currently unconditionally overwrites
self.decay from EMA_DECAY_KEY which silently ignores any decay set from the
config (e.g. training.ema_decay); change load_state_dict to compare the
checkpoint decay (state[EMA_DECAY_KEY]) with the current self.decay and if they
differ emit a clear warning (using the module's existing logger or Python
logging) and keep the config-provided self.decay (do not overwrite), otherwise
fall back to setting self.decay from the checkpoint if self.decay is not already
set; refer to load_state_dict, EMA_DECAY_KEY and self.decay to locate the
change.

In `@deepmd/utils/argcheck.py`:
- Around line 4033-4057: The schema currently allows enable_ema even when
zero_stage >= 2; add a cross-field validation in the training normalize/extra
check (the function/method that performs training_extra_check or normalize
within deepmd/utils/argcheck.py) to raise a clear validation error when
training.enable_ema is true and training.zero_stage is >= 2; locate the Argument
block for "enable_ema"/"zero_stage" and implement the same rejection logic as in
deepmd/pt/train/training.py (check the normalized config object, and
raise/return an error message like "enable_ema is not supported with ZeRO stage
2 or 3") so invalid configs fail during normalization rather than at trainer
build time.

In `@source/tests/pt/test_training.py`:
- Around line 912-1014: Add a 60s timeout around the real training/validation
runs in the EMA tests so CI cannot hang: import concurrent.futures and replace
direct calls to trainer.run() in test_ema_checkpoint_rotation,
test_restart_restores_ema_state (if any trainer.run there) and
test_ema_full_validation_writes_separate_outputs with submitting trainer.run to
a ThreadPoolExecutor and calling future.result(timeout=60), raising/failing the
test if a TimeoutError occurs; ensure the helper uses the existing get_trainer
and trainer.run symbols and cleans up the executor so tests fail fast on hangs.
- Around line 947-958: The EMA checkpoint cleanup currently orders files by
Path.stat().st_mtime which can mis-rotate when multiple saves share a timestamp;
update the retention logic used by the EMA-save routine (the code that lists
files for ema_save_ckpt/ checkpoint_ema and deletes the oldest) to parse the
monotonic step/index from the filename (e.g., extract the integer from the
suffix pattern like "-{step}.pt") and sort by that integer so deletions remove
the lowest step numbers; if parsing fails for a file, fall back to st_mtime as a
secondary key, then enforce the existing keep-N retention behavior.
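The retention fix suggested above (sort by the step parsed from the filename, fall back to mtime) can be sketched as follows; the '<prefix>-<step>.pt' pattern is assumed from the comment:

```python
import re
from pathlib import Path


def prune_checkpoints(ckpt_dir: Path, prefix: str, keep: int) -> list[Path]:
    """Delete all but the newest `keep` checkpoints, ordered by the step
    embedded in names like '<prefix>-<step>.pt'; unparsable names sort
    oldest, with mtime as a secondary key."""
    step_re = re.compile(rf"^{re.escape(prefix)}-(\d+)\.pt$")

    def sort_key(path: Path) -> tuple[int, float]:
        match = step_re.match(path.name)
        step = int(match.group(1)) if match else -1
        return (step, path.stat().st_mtime)

    files = sorted(ckpt_dir.glob(f"{prefix}-*.pt"), key=sort_key)
    doomed = files[:-keep] if keep > 0 else files
    for path in doomed:
        path.unlink()
    return doomed
```

Sorting numerically matters here: a lexicographic sort would place "-10" before "-2" and, combined with colliding timestamps, delete the wrong file.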

---

Outside diff comments:
In `@deepmd/pt/train/training.py`:
- Around line 188-205: The ema_decay value from training_params is not validated
for range; add a defensive check after computing self.ema_decay (e.g.,
immediately after self.ema_decay = float(training_params.get("ema_decay",
0.999))) that asserts 0.0 <= self.ema_decay <= 1.0 and raises a ValueError with
a clear message (mentioning training.ema_decay or self.ema_decay) if out of
bounds so invalid decay values are rejected early; keep this check alongside the
existing zero_stage/enable_ema validations.

---

Nitpick comments:
In `@deepmd/pt/train/ema.py`:
- Around line 56-66: The constructor signature for ModelEMA.__init__ currently
types the model parameter as torch.nn.Module but the class methods
(_named_model_parameters, apply_shadow, update) accept torch.nn.Module |
dict[str, torch.nn.Module] and Trainer may pass a dict in multi-task mode;
update the type hint of the model parameter in __init__ to accept
torch.nn.Module | dict[str, torch.nn.Module] (and adjust any related local
annotations like shadow_params if necessary) so IDE/type-checkers reflect
supported inputs and avoid false positives when Trainer passes a dict.
- Around line 68-99: The current _named_model_parameters returns all
floating-point parameters, including frozen ones; update it to optionally skip
parameters with requires_grad==False by adding a boolean flag on the EMA class
(e.g., track_frozen or track_requires_grad) and filter out params where not
param.requires_grad in both the dict branch and the single-module branch; ensure
_clone_model_parameters continues to call _named_model_parameters so the cloning
respects the new flag and document the flag’s default (True to preserve current
behavior or False to save memory) and its use in the EMA constructor.

In `@deepmd/pt/train/training.py`:
- Around line 1807-1870: The include_optimizer=False branch should be handled
outside the zero_stage == 1 consolidate path and documented: in
_collect_checkpoint_states, ensure you only call
self.optimizer.consolidate_state_dict(to=0) and build optim_state for ZeRO-1
when include_optimizer is True (so non-optimizer/EMA checkpoints skip the
collective), and move the include_optimizer=False branch out of the consolidate
block; add a short comment near zero_stage == 1 explaining that skipping
consolidation for EMA checkpoints is intentional because consolidate_state_dict
is a collective call and must not be invoked when include_optimizer is False.

In `@deepmd/pt/train/validation.py`:
- Line 227: The current assignment uses a truthiness check which replaces any
falsy callable; change the assignment of self.model_eval_context so it only
substitutes nullcontext when model_eval_context is exactly None (e.g., use an
explicit identity check or a conditional expression based on "is None") rather
than using "or", so that valid but falsy callables are preserved; update the
assignment where self.model_eval_context is set and keep the nullcontext symbol
as the fallback.
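The first inline prompt above (reconciling a checkpoint's EMA decay with the configured one in load_state_dict) boils down to a small resolution rule; a sketch with hypothetical names:

```python
import logging

log = logging.getLogger(__name__)


def resolve_decay(config_decay, ckpt_decay):
    """Prefer the config-provided decay; warn (rather than silently
    overwrite) when the checkpoint disagrees."""
    if config_decay is None:
        return ckpt_decay  # no configured value: adopt the checkpoint's
    if config_decay != ckpt_decay:
        log.warning(
            "Checkpoint EMA decay %s differs from configured %s; "
            "keeping the configured value.",
            ckpt_decay,
            config_decay,
        )
    return config_decay
```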

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 88f418d0-b9c0-414c-b927-078b70586509

📥 Commits

Reviewing files that changed from the base of the PR and between 5c22e17 and 41b70c4.

📒 Files selected for processing (6)
  • deepmd/pt/train/ema.py
  • deepmd/pt/train/training.py
  • deepmd/pt/train/validation.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_training.py
  • source/tests/pt/test_validation.py

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
deepmd/pt/train/training.py (2)

1646-1655: Minor: missing log line for EMA checkpoint save (asymmetry with the regular save above).

Lines 1641–1645 log "Saved model to …" and write the checkpoint pointer file for the regular ckpt, but the EMA branch only emits the symlink with no log. Adding a parallel log.info(f"Saved EMA model to {self.latest_ema_model}") (rank-0 guarded) would make EMA progress visible in the logs and align the two paths.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/train/training.py` around lines 1646 - 1655, The EMA branch lacks a
log entry after saving the EMA checkpoint; after calling
self.save_ema_model(self.latest_ema_model, lr=cur_lr, step=_step_id) add a
rank-0 guarded log statement (e.g. log.info or self.logger.info consistent with
the file) like "Saved EMA model to {self.latest_ema_model}" so EMA saves are
visible in logs; keep the guard matching the existing check (self.rank == 0 or
dist.get_rank() == 0) and place the log before/after the
symlink_prefix_files(self.latest_ema_model.stem, self.ema_save_ckpt) call.

1053-1058: Avoid hardcoding "best_ema.ckpt"; derive it from BEST_CKPT_PREFIX (or via get_ema_checkpoint_prefix) for consistency.

The non-EMA full validator picks up BEST_CKPT_PREFIX from FullValidator's defaults (snippet 2), but here the EMA variant is hardcoded. If the project ever changes the canonical best-ckpt prefix, the EMA path will silently drift. Consider:

-            best_checkpoint_prefix="best_ema.ckpt",
+            best_checkpoint_prefix=get_ema_checkpoint_prefix(BEST_CKPT_PREFIX),

(and import BEST_CKPT_PREFIX from deepmd.pt.train.validation). This also keeps the suffix derivation rule (_ema) in one place.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/train/training.py` around lines 1053 - 1058, The code hardcodes
"best_ema.ckpt" for the EMA validator's best_checkpoint_prefix; replace that
literal by deriving the prefix from the canonical constant or helper (use
BEST_CKPT_PREFIX or call get_ema_checkpoint_prefix()) so the EMA prefix stays in
sync with FullValidator's default; update the call where best_checkpoint_prefix
is passed (the block using model_eval_context/self.model_ema.apply_shadow) and
add the necessary import from deepmd.pt.train.validation for BEST_CKPT_PREFIX or
get_ema_checkpoint_prefix.
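The derivation rule the comment wants centralized could look like the following sketch; the real helper lives in deepmd/pt/train/ema.py and its exact behavior isn't shown in this PR page, so the name and logic here are illustrative:

```python
def ema_checkpoint_prefix(prefix: str) -> str:
    """Append the `_ema` marker to a checkpoint prefix, keeping any
    extension: 'best.ckpt' -> 'best_ema.ckpt', 'model' -> 'model_ema'."""
    stem, dot, ext = prefix.partition(".")
    return f"{stem}_ema{dot}{ext}"
```

With such a helper, passing best_checkpoint_prefix=ema_checkpoint_prefix(BEST_CKPT_PREFIX) keeps the `_ema` suffix rule in one place, so a later rename of the canonical prefix propagates automatically.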
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/pt/train/training.py`:
- Around line 1735-1744: The EMA checkpoint saved after training can be stale
because model_change_out_bias mutates self.model but self.model_ema still holds
pre-adjustment biases; before calling save_ema_model (used with
_collect_checkpoint_states(use_ema_weights=True)) refresh the EMA shadow's
out-bias entries from the just-modified self.model so the EMA snapshot reflects
the bias correction: detect change_bias_after_training (or check whether
model_change_out_bias ran), and perform a one-shot copy of the relevant out-bias
parameter tensors from self.model to self.model_ema (not a decayed update) right
before save_ema_model is invoked for latest_ema_model/ema_save_ckpt;
alternatively, if you prefer not to modify the EMA shadow, skip saving the EMA
checkpoint and emit a warning when change_bias_after_training is enabled.

---

Nitpick comments:
In `@deepmd/pt/train/training.py`:
- Around line 1646-1655: The EMA branch lacks a log entry after saving the EMA
checkpoint; after calling self.save_ema_model(self.latest_ema_model, lr=cur_lr,
step=_step_id) add a rank-0 guarded log statement (e.g. log.info or
self.logger.info consistent with the file) like "Saved EMA model to
{self.latest_ema_model}" so EMA saves are visible in logs; keep the guard
matching the existing check (self.rank == 0 or dist.get_rank() == 0) and place
the log before/after the symlink_prefix_files(self.latest_ema_model.stem,
self.ema_save_ckpt) call.
- Around line 1053-1058: The code hardcodes "best_ema.ckpt" for the EMA
validator's best_checkpoint_prefix; replace that literal by deriving the prefix
from the canonical constant or helper (use BEST_CKPT_PREFIX or call
get_ema_checkpoint_prefix()) so the EMA prefix stays in sync with
FullValidator's default; update the call where best_checkpoint_prefix is passed
(the block using model_eval_context/self.model_ema.apply_shadow) and add the
necessary import from deepmd.pt.train.validation for BEST_CKPT_PREFIX or
get_ema_checkpoint_prefix.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 00134712-10e0-4c17-8c7b-a47612a2d1c1

📥 Commits

Reviewing files that changed from the base of the PR and between 41b70c4 and 96b2817.

📒 Files selected for processing (4)
  • deepmd/pt/train/ema.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_training.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/train/ema.py

@codecov

codecov Bot commented Apr 25, 2026

Codecov Report

❌ Patch coverage is 86.61417% with 34 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.36%. Comparing base (c38dc00) to head (8d155de).

Files with missing lines:
  • deepmd/pt/train/training.py: 82.50% patch coverage, 21 lines missing ⚠️
  • deepmd/pt/train/ema.py: 87.20% patch coverage, 11 lines missing ⚠️
  • deepmd/utils/argcheck.py: 88.23% patch coverage, 2 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5420      +/-   ##
==========================================
- Coverage   82.42%   82.36%   -0.06%     
==========================================
  Files         824      825       +1     
  Lines       87510    87718     +208     
  Branches     4197     4197              
==========================================
+ Hits        72127    72250     +123     
- Misses      14108    14193      +85     
  Partials     1275     1275              


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/utils/argcheck.py`:
- Around line 4422-4429: The EMA prerequisite check (ema_full_validation_enabled
and training_params.get("enable_ema", False)) must be evaluated before the early
return that checks validating.get("full_val_start", 0.0) == 1.0; move the block
that raises ValueError for validating.ema_full_validation to occur immediately
after verifying full_validation_enabled (i.e., before the full_val_start return)
so that when ema_full_validation_enabled is true and training.enable_ema is
false the code consistently raises the ValueError instead of being skipped by
the full_val_start early return.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 5a4d0308-c76d-4c6b-9ae6-f5de943a18bd

📥 Commits

Reviewing files that changed from the base of the PR and between 9300fac and 8d155de.

📒 Files selected for processing (6)
  • deepmd/pt/train/ema.py
  • deepmd/pt/train/training.py
  • deepmd/pt/train/validation.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_training.py
  • source/tests/pt/test_validation.py
✅ Files skipped from review due to trivial changes (2)
  • source/tests/pt/test_validation.py
  • deepmd/pt/train/ema.py

Comment on lines +4422 to +4429
if not full_validation_enabled:
    return
if float(validating.get("full_val_start", 0.0)) == 1.0:
    return
if ema_full_validation_enabled and not training_params.get("enable_ema", False):
    raise ValueError(
        "validating.ema_full_validation requires `training.enable_ema=true`."
    )
Contributor


⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Keep the EMA prerequisite check ahead of the full_val_start early return.

With the current order, validating.full_validation=true, validating.full_val_start=1.0, and validating.ema_full_validation=true skips the new training.enable_ema guard entirely. That lets a self-contradictory EMA config normalize successfully instead of being rejected consistently.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/utils/argcheck.py` around lines 4422 - 4429, The EMA prerequisite
check (ema_full_validation_enabled and training_params.get("enable_ema", False))
must be evaluated before the early return that checks
validating.get("full_val_start", 0.0) == 1.0; move the block that raises
ValueError for validating.ema_full_validation to occur immediately after
verifying full_validation_enabled (i.e., before the full_val_start return) so
that when ema_full_validation_enabled is true and training.enable_ema is false
the code consistently raises the ValueError instead of being skipped by the
full_val_start early return.
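The corrected ordering can be sketched as a standalone checker (dict lookups stand in for the real schema objects): the EMA prerequisite runs right after the full_validation gate, before any other early return:

```python
def check_full_validation(validating: dict, training: dict) -> None:
    """Cross-field config checks; the EMA guard precedes every other
    early return so a contradictory config always fails."""
    if not validating.get("full_validation", False):
        return
    if validating.get("ema_full_validation", False) and not training.get(
        "enable_ema", False
    ):
        raise ValueError(
            "validating.ema_full_validation requires `training.enable_ema=true`."
        )
    if float(validating.get("full_val_start", 0.0)) == 1.0:
        return
    # ... remaining full-validation checks go here ...
```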
