feat(pt): add ema shadow model #5420
Conversation
Pull request overview
This PR adds an exponential moving average (EMA) “shadow” model to the PyTorch training stack, including EMA checkpointing, restart support, and optional EMA-based full validation outputs.
Changes:
- Introduce `ModelEMA` with serialization, shadow-weight application, and filename helpers for EMA checkpoints/logs.
- Extend the PyTorch trainer to maintain/update EMA state, save periodic EMA checkpoints, and optionally run full validation on EMA weights with separate logs/best checkpoints.
- Update argcheck schema/validation and add unit tests covering EMA rotation, restart restore, and EMA full-validation outputs.
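The EMA recurrence at the heart of the PR is a one-line update; a minimal pure-Python sketch of the rule the shadow weights follow (plain floats stand in for parameter tensors, and `ema_update` is an illustrative name, not the PR's API):

```python
def ema_update(shadow, params, decay):
    """One EMA step: shadow <- decay * shadow + (1 - decay) * params.

    Equivalent in effect to torch's in-place shadow.lerp_(param, 1 - decay).
    """
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]

shadow = [0.0, 0.0]
for _ in range(3):
    params = [1.0, 2.0]  # pretend the trainer produced these weights
    shadow = ema_update(shadow, params, decay=0.9)

# After 3 steps toward constant targets the shadow has covered 1 - 0.9**3 = 0.271 of the gap.
print([round(s, 3) for s in shadow])  # → [0.271, 0.542]
```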
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `deepmd/pt/train/ema.py` | New EMA implementation (shadow params, state dict, apply-shadow context). |
| `deepmd/pt/train/training.py` | Trainer integration: EMA lifecycle, EMA validation, checkpoint format/rotation refactor. |
| `deepmd/pt/train/validation.py` | Generalize `FullValidator` to support alternate state stores, prefixes, log paths, and eval contexts (used by EMA validation). |
| `deepmd/utils/argcheck.py` | Add config knobs (`enable_ema`, `ema_decay`, `ema_ckpt_keep`, `ema_full_validation`) and cross-section validation. |
| `source/tests/pt/test_training.py` | Add EMA-focused tests (rotation, restart persistence, EMA full validation output separation). |
| `source/tests/pt/test_validation.py` | Update tests to match `FullValidator` arg rename (`train_infos` → `state_store`). |
📝 Walkthrough

Adds Exponential Moving Average (EMA) support.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Trainer
    participant Model
    participant ModelEMA
    participant Checkpointer
    participant FullValidator
    Trainer->>ModelEMA: __init__(model, decay)
    Trainer->>Trainer: start training loop
    loop Per training step
        Trainer->>Model: forward/backward
        Trainer->>ModelEMA: update(model)
        ModelEMA->>ModelEMA: lerp_ shadow params (in-place)
    end
    alt Save regular + EMA checkpoints
        Trainer->>Checkpointer: save_model(use_ema_weights=False,...)
        Trainer->>ModelEMA: state_dict()
        Trainer->>Checkpointer: save_ema_model(...)
    end
    alt EMA full-validation
        Trainer->>ModelEMA: apply_shadow(model)
        ModelEMA->>Model: copy shadow params (device/dtype matched)
        Trainer->>FullValidator: evaluate_all_systems(model)
        FullValidator->>FullValidator: compute metrics, persist to state_store
        ModelEMA->>Model: restore original params
        Trainer->>Checkpointer: save EMA best checkpoint
    end
    alt Resume from checkpoint
        Trainer->>Checkpointer: load checkpoint
        Trainer->>ModelEMA: load_state_dict(ema_state)
        Trainer->>Trainer: continue training
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 4 passed
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)
188-205: ⚠️ Potential issue | 🟡 Minor — Validate `ema_decay` range.

A decay outside `[0, 1]` (e.g., negative or > 1) would cause unstable or unbounded shadow updates via `lerp_`. Consider asserting bounds here (or in `argcheck`) to fail fast with a clear message.

🛡️ Proposed defensive check
```diff
  self.ema_decay = float(training_params.get("ema_decay", 0.999))
+ if self.enable_ema and not (0.0 <= self.ema_decay <= 1.0):
+     raise ValueError(
+         f"training.ema_decay must lie in [0, 1], got {self.ema_decay}."
+     )
  self.ema_ckpt_keep = int(training_params.get("ema_ckpt_keep", 3))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/training.py` around lines 188 - 205, The ema_decay value from training_params is not validated for range; add a defensive check after computing self.ema_decay (e.g., immediately after self.ema_decay = float(training_params.get("ema_decay", 0.999))) that asserts 0.0 <= self.ema_decay <= 1.0 and raises a ValueError with a clear message (mentioning training.ema_decay or self.ema_decay) if out of bounds so invalid decay values are rejected early; keep this check alongside the existing zero_stage/enable_ema validations.
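Why the bounds matter: for decay `d` in `[0, 1]` the update `s ← d·s + (1−d)·p` is a convex combination and stays bounded, while `d > 1` amplifies the error each step. A quick numeric illustration (values arbitrary):

```python
def ema_step(shadow, param, decay):
    return decay * shadow + (1.0 - decay) * param

# In-range decay converges toward the (constant) target...
s_ok = 0.0
for _ in range(50):
    s_ok = ema_step(s_ok, 1.0, decay=0.9)

# ...while decay > 1 multiplies the error by 1.5x per step and blows up.
s_bad = 0.0
for _ in range(50):
    s_bad = ema_step(s_bad, 1.0, decay=1.5)

print(round(s_ok, 3))    # → 0.995
print(abs(s_bad) > 1e6)  # → True
```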
🧹 Nitpick comments (4)
deepmd/pt/train/ema.py (2)
56-66: Type hint inconsistent with supported inputs.

`_named_model_parameters` and `apply_shadow`/`update` all accept `torch.nn.Module | dict[str, torch.nn.Module]`, and `Trainer.__init__` passes `self.model` (which is a dict in multi-task mode) into `ModelEMA(...)`. The constructor's type hint should match to avoid misleading IDE/type checks.

♻️ Proposed type hint update
```diff
  def __init__(
      self,
-     model: torch.nn.Module,
+     model: torch.nn.Module | dict[str, torch.nn.Module],
      decay: float,
      state: dict[str, Any] | None = None,
  ) -> None:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/ema.py` around lines 56 - 66, The constructor signature for ModelEMA.__init__ currently types the model parameter as torch.nn.Module but the class methods (_named_model_parameters, apply_shadow, update) accept torch.nn.Module | dict[str, torch.nn.Module] and Trainer may pass a dict in multi-task mode; update the type hint of the model parameter in __init__ to accept torch.nn.Module | dict[str, torch.nn.Module] (and adjust any related local annotations like shadow_params if necessary) so IDE/type-checkers reflect supported inputs and avoid false positives when Trainer passes a dict.
68-99: Optionally skip frozen (`requires_grad=False`) parameters.

Frozen parameters don't update during training, so keeping EMA shadows of them only consumes memory without benefit. Consider filtering them out unless there's a specific reason to track them (e.g., BN stats are buffers, not parameters, so they aren't affected either way).
♻️ Proposed filter
```diff
  return [
      (name, param)
      for name, param in model.named_parameters()
-     if torch.is_floating_point(param)
+     if torch.is_floating_point(param) and param.requires_grad
  ]
```

(and the symmetric change in the dict branch)
Note: if you intend to support parameters that toggle `requires_grad` during training, keep the current behavior.
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/ema.py` around lines 68 - 99, The current _named_model_parameters returns all floating-point parameters, including frozen ones; update it to optionally skip parameters with requires_grad==False by adding a boolean flag on the EMA class (e.g., track_frozen or track_requires_grad) and filter out params where not param.requires_grad in both the dict branch and the single-module branch; ensure _clone_model_parameters continues to call _named_model_parameters so the cloning respects the new flag and document the flag's default (True to preserve current behavior or False to save memory) and its use in the EMA constructor.

deepmd/pt/train/training.py (1)
1807-1870: Nit: move `include_optimizer=False` branch out of the `zero_stage == 1` consolidate path.

Minor readability observation — under `zero_stage == 1`, `consolidate_state_dict(to=0)` is a collective call; gating it behind `include_optimizer` correctly skips it for EMA checkpoints, so rank>0 calls don't participate in a consolidation they'd otherwise need to attend. Worth a brief comment noting that this is intentional (the EMA path is non-collective by design under ZeRO-1), since at a glance it may look like a missing collective.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/training.py` around lines 1807 - 1870, The include_optimizer=False branch should be handled outside the zero_stage == 1 consolidate path and documented: in _collect_checkpoint_states, ensure you only call self.optimizer.consolidate_state_dict(to=0) and build optim_state for ZeRO-1 when include_optimizer is True (so non-optimizer/EMA checkpoints skip the collective), and move the include_optimizer=False branch out of the consolidate block; add a short comment near zero_stage == 1 explaining that skipping consolidation for EMA checkpoints is intentional because consolidate_state_dict is a collective call and must not be invoked when include_optimizer is False.

deepmd/pt/train/validation.py (1)
227-227: Use an explicit `is None` check for `model_eval_context`.

`model_eval_context or nullcontext` will silently substitute `nullcontext` for any falsy callable, which is typically not intended for callable parameters. Prefer an explicit identity check:

♻️ Proposed change
```diff
- self.model_eval_context = model_eval_context or nullcontext
+ self.model_eval_context = (
+     model_eval_context if model_eval_context is not None else nullcontext
+ )
```
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` at line 227, The current assignment uses a truthiness check which replaces any falsy callable; change the assignment of self.model_eval_context so it only substitutes nullcontext when model_eval_context is exactly None (e.g., use an explicit identity check or a conditional expression based on "is None") rather than using "or", so that valid but falsy callables are preserved; update the assignment where self.model_eval_context is set and keep the nullcontext symbol as the fallback.
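The pitfall behind this suggestion can be shown in isolation (a generic example, not the PR's code): any valid context-manager factory that happens to be falsy is silently dropped by the `or` fallback, while the `is None` form keeps it:

```python
from contextlib import nullcontext

class FalsyContextFactory:
    """A perfectly usable context-manager factory that happens to be falsy."""
    def __bool__(self):
        return False
    def __call__(self):
        return nullcontext("custom")

factory = FalsyContextFactory()

chosen_or = factory or nullcontext                           # truthiness: drops the factory
chosen_is = factory if factory is not None else nullcontext  # identity: keeps it

print(chosen_or is nullcontext)  # → True  (the custom factory was lost)
print(chosen_is is factory)      # → True
```

Mocks configured with a false truth value are a common real-world instance of this in tests.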
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/train/ema.py`:
- Around line 118-134: load_state_dict currently unconditionally overwrites
self.decay from EMA_DECAY_KEY which silently ignores any decay set from the
config (e.g. training.ema_decay); change load_state_dict to compare the
checkpoint decay (state[EMA_DECAY_KEY]) with the current self.decay and if they
differ emit a clear warning (using the module's existing logger or Python
logging) and keep the config-provided self.decay (do not overwrite), otherwise
fall back to setting self.decay from the checkpoint if self.decay is not already
set; refer to load_state_dict, EMA_DECAY_KEY and self.decay to locate the
change.
In `@deepmd/utils/argcheck.py`:
- Around line 4033-4057: The schema currently allows enable_ema even when
zero_stage >= 2; add a cross-field validation in the training normalize/extra
check (the function/method that performs training_extra_check or normalize
within deepmd/utils/argcheck.py) to raise a clear validation error when
training.enable_ema is true and training.zero_stage is >= 2; locate the Argument
block for "enable_ema"/"zero_stage" and implement the same rejection logic as in
deepmd/pt/train/training.py (check the normalized config object, and
raise/return an error message like "enable_ema is not supported with ZeRO stage
2 or 3") so invalid configs fail during normalization rather than at trainer
build time.
In `@source/tests/pt/test_training.py`:
- Around line 912-1014: Add a 60s timeout around the real training/validation
runs in the EMA tests so CI cannot hang: import concurrent.futures and replace
direct calls to trainer.run() in test_ema_checkpoint_rotation,
test_restart_restores_ema_state (if any trainer.run there) and
test_ema_full_validation_writes_separate_outputs with submitting trainer.run to
a ThreadPoolExecutor and calling future.result(timeout=60), raising/failing the
test if a TimeoutError occurs; ensure the helper uses the existing get_trainer
and trainer.run symbols and cleans up the executor so tests fail fast on hangs.
- Around line 947-958: The EMA checkpoint cleanup currently orders files by
Path.stat().st_mtime which can mis-rotate when multiple saves share a timestamp;
update the retention logic used by the EMA-save routine (the code that lists
files for ema_save_ckpt/ checkpoint_ema and deletes the oldest) to parse the
monotonic step/index from the filename (e.g., extract the integer from the
suffix pattern like "-{step}.pt") and sort by that integer so deletions remove
the lowest step numbers; if parsing fails for a file, fall back to st_mtime as a
secondary key, then enforce the existing keep-N retention behavior.
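The suggested fix — order checkpoints by the step parsed from the filename, with mtime only as a tiebreaker — can be sketched as follows (the `-{step}.pt` suffix pattern comes from the comment above; `rotation_key` is a hypothetical helper, not the PR's code):

```python
import re

def rotation_key(name, mtime):
    """Sort key for checkpoint retention: (step, mtime).

    Parses the trailing integer from 'prefix-{step}.pt'; files that
    don't match fall back to step -1 so mtime decides among them.
    """
    m = re.search(r"-(\d+)\.pt$", name)
    step = int(m.group(1)) if m else -1
    return (step, mtime)

# Same mtime for every file: the step number alone must order them correctly.
files = [("ema-10.pt", 100.0), ("ema-2.pt", 100.0), ("ema-30.pt", 100.0)]
keep = 2
ordered = sorted(files, key=lambda f: rotation_key(*f))
to_delete = ordered[:-keep] if len(ordered) > keep else []
print([name for name, _ in to_delete])  # → ['ema-2.pt']
```

A pure lexicographic or mtime sort would risk deleting `ema-30.pt` here; parsing the integer makes the rotation deterministic.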
📒 Files selected for processing (6)
- deepmd/pt/train/ema.py
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
Actionable comments posted: 1
🧹 Nitpick comments (2)
deepmd/pt/train/training.py (2)
1646-1655: Minor: missing log line for EMA checkpoint save (asymmetry with the regular save above).

Lines 1641–1645 log "Saved model to …" and write the `checkpoint` pointer file for the regular ckpt, but the EMA branch only emits the symlink with no log. Adding a parallel `log.info(f"Saved EMA model to {self.latest_ema_model}")` (rank-0 guarded) would make EMA progress visible in the logs and align the two paths.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/training.py` around lines 1646 - 1655, The EMA branch lacks a log entry after saving the EMA checkpoint; after calling self.save_ema_model(self.latest_ema_model, lr=cur_lr, step=_step_id) add a rank-0 guarded log statement (e.g. log.info or self.logger.info consistent with the file) like "Saved EMA model to {self.latest_ema_model}" so EMA saves are visible in logs; keep the guard matching the existing check (self.rank == 0 or dist.get_rank() == 0) and place the log before/after the symlink_prefix_files(self.latest_ema_model.stem, self.ema_save_ckpt) call.
1053-1058: Avoid hardcoding `"best_ema.ckpt"`; derive it from `BEST_CKPT_PREFIX` (or via `get_ema_checkpoint_prefix`) for consistency.

The non-EMA full validator picks up `BEST_CKPT_PREFIX` from `FullValidator`'s defaults, but here the EMA variant is hardcoded. If the project ever changes the canonical best-ckpt prefix, the EMA path will silently drift. Consider:

```diff
- best_checkpoint_prefix="best_ema.ckpt",
+ best_checkpoint_prefix=get_ema_checkpoint_prefix(BEST_CKPT_PREFIX),
```

(and import `BEST_CKPT_PREFIX` from `deepmd.pt.train.validation`). This also keeps the suffix derivation rule (`_ema`) in one place.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/training.py` around lines 1053 - 1058, The code hardcodes "best_ema.ckpt" for the EMA validator's best_checkpoint_prefix; replace that literal by deriving the prefix from the canonical constant or helper (use BEST_CKPT_PREFIX or call get_ema_checkpoint_prefix()) so the EMA prefix stays in sync with FullValidator's default; update the call where best_checkpoint_prefix is passed (the block using model_eval_context/self.model_ema.apply_shadow) and add the necessary import from deepmd.pt.train.validation for BEST_CKPT_PREFIX or get_ema_checkpoint_prefix.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/train/training.py`:
- Around line 1735-1744: The EMA checkpoint saved after training can be stale
because model_change_out_bias mutates self.model but self.model_ema still holds
pre-adjustment biases; before calling save_ema_model (used with
_collect_checkpoint_states(use_ema_weights=True)) refresh the EMA shadow's
out-bias entries from the just-modified self.model so the EMA snapshot reflects
the bias correction: detect change_bias_after_training (or check whether
model_change_out_bias ran), and perform a one-shot copy of the relevant out-bias
parameter tensors from self.model to self.model_ema (not a decayed update) right
before save_ema_model is invoked for latest_ema_model/ema_save_ckpt;
alternatively, if you prefer not to modify the EMA shadow, skip saving the EMA
checkpoint and emit a warning when change_bias_after_training is enabled.
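The distinction the prompt draws — a one-shot copy rather than a decayed update — matters numerically: with `decay=0.999` a single EMA step moves the shadow only 0.1% of the way toward the corrected bias. A small dict-based sketch (names hypothetical, plain floats stand in for tensors):

```python
def ema_update(shadow, params, decay):
    """Regular decayed update: barely moves when decay is close to 1."""
    return {k: decay * shadow[k] + (1.0 - decay) * params[k] for k in shadow}

def refresh_keys(shadow, params, keys):
    """One-shot copy: overwrite selected shadow entries with live values."""
    out = dict(shadow)
    for k in keys:
        out[k] = params[k]
    return out

shadow = {"weight": 1.0, "out_bias": 0.0}
model = {"weight": 1.2, "out_bias": 5.0}  # out_bias just adjusted after training

decayed = ema_update(shadow, model, decay=0.999)
copied = refresh_keys(shadow, model, keys=["out_bias"])

print(round(decayed["out_bias"], 3))  # → 0.005 (stale: barely moved)
print(copied["out_bias"])             # → 5.0 (reflects the bias correction)
```

Note the copy only touches the bias keys; the rest of the shadow keeps its averaged values.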
📒 Files selected for processing (4)
- deepmd/pt/train/ema.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/train/ema.py
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #5420      +/-   ##
==========================================
- Coverage   82.42%   82.36%   -0.06%
==========================================
  Files         824      825       +1
  Lines       87510    87718     +208
  Branches     4197     4197
==========================================
+ Hits        72127    72250     +123
- Misses      14108    14193      +85
  Partials     1275     1275
```

☔ View full report in Codecov by Sentry.
for more information, see https://pre-commit.ci
Actionable comments posted: 1
📒 Files selected for processing (6)
- deepmd/pt/train/ema.py
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
✅ Files skipped from review due to trivial changes (2)
- source/tests/pt/test_validation.py
- deepmd/pt/train/ema.py
```python
if not full_validation_enabled:
    return
if float(validating.get("full_val_start", 0.0)) == 1.0:
    return
if ema_full_validation_enabled and not training_params.get("enable_ema", False):
    raise ValueError(
        "validating.ema_full_validation requires `training.enable_ema=true`."
    )
```
Keep the EMA prerequisite check ahead of the full_val_start early return.
With the current order, `validating.full_validation=true`, `validating.full_val_start=1.0`, and `validating.ema_full_validation=true` together skip the new `training.enable_ema` guard entirely. That lets a self-contradictory EMA config normalize successfully instead of being rejected consistently.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/utils/argcheck.py` around lines 4422 - 4429, The EMA prerequisite
check (ema_full_validation_enabled and training_params.get("enable_ema", False))
must be evaluated before the early return that checks
validating.get("full_val_start", 0.0) == 1.0; move the block that raises
ValueError for validating.ema_full_validation to occur immediately after
verifying full_validation_enabled (i.e., before the full_val_start return) so
that when ema_full_validation_enabled is true and training.enable_ema is false
the code consistently raises the ValueError instead of being skipped by the
full_val_start early return.
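The ordering bug can be reproduced with a stripped-down stand-in for the normalizer (config keys taken from the snippet above; the `guard_first` flag exists only here, to contrast the two orderings):

```python
def check(validating, training_params, guard_first):
    full_validation_enabled = validating.get("full_validation", False)
    ema_full = validating.get("ema_full_validation", False)
    if not full_validation_enabled:
        return "ok"

    def ema_guard():
        if ema_full and not training_params.get("enable_ema", False):
            raise ValueError(
                "validating.ema_full_validation requires training.enable_ema=true."
            )

    if guard_first:
        ema_guard()  # suggested order: guard before the early return
    if float(validating.get("full_val_start", 0.0)) == 1.0:
        return "ok"
    if not guard_first:
        ema_guard()  # current order: skipped when full_val_start == 1.0
    return "ok"

cfg = {"full_validation": True, "full_val_start": 1.0, "ema_full_validation": True}
print(check(cfg, {}, guard_first=False))  # → ok  (the contradiction slips through)
try:
    check(cfg, {}, guard_first=True)
    print("no error")
except ValueError:
    print("rejected")  # → rejected
```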
Summary by CodeRabbit
New Features
Configuration
Validation
Tests
Chores