diff --git a/deepmd/pt/train/ema.py b/deepmd/pt/train/ema.py new file mode 100644 index 0000000000..71054d0abb --- /dev/null +++ b/deepmd/pt/train/ema.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later + +from __future__ import ( + annotations, +) + +import logging +from contextlib import ( + contextmanager, +) +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch + +if TYPE_CHECKING: + from collections.abc import ( + Iterator, + ) + +EMA_CHECKPOINT_KEY = "ema" +EMA_DECAY_KEY = "decay" +EMA_MODEL_STATE_KEY = "model" +EMA_VALIDATION_STATE_KEY = "validation_state" + +log = logging.getLogger(__name__) + + +def _append_suffix(path_like: str | Path, suffix: str) -> Path: + """Append a suffix before the final file suffix when present.""" + path = Path(path_like) + if path.suffix: + return path.with_name(f"{path.stem}{suffix}{path.suffix}") + return path.with_name(f"{path.name}{suffix}") + + +def get_ema_checkpoint_prefix(save_ckpt: str | Path) -> str: + """Derive the EMA checkpoint prefix from the regular checkpoint prefix.""" + return str(_append_suffix(save_ckpt, "_ema")) + + +def get_ema_validation_log_path(full_val_file: str | Path) -> Path: + """Derive the EMA validation log path from the regular validation log path.""" + return _append_suffix(full_val_file, "_ema") + + +class ModelEMA: + """Maintain an exponential moving average of model parameters.""" + + def __init__( + self, + model: torch.nn.Module | dict[str, torch.nn.Module], + decay: float, + state: dict[str, Any] | None = None, + ) -> None: + self.decay = float(decay) + self.shadow_params = self._clone_model_parameters(model) + self.validation_state: dict[str, Any] = {} + if state is not None: + self.load_state_dict(state) + + @staticmethod + def _named_model_parameters( + model: torch.nn.Module | dict[str, torch.nn.Module], + ) -> list[tuple[str, torch.nn.Parameter]]: + """Collect all floating-point model parameters in a deterministic order.""" + if isinstance(model, dict): + named_parameters = [] + for model_key in sorted(model): + named_parameters.extend( + [ + (f"{model_key}.{name}", param) + for name, param in model[model_key].named_parameters() + if torch.is_floating_point(param) + ] + ) + return named_parameters + return [ + (name, param) + for name, param in model.named_parameters() + if torch.is_floating_point(param) + ] + + def _clone_model_parameters( + self, + model: torch.nn.Module | dict[str, torch.nn.Module], + ) -> dict[str, torch.Tensor]: + """Clone model parameters to initialize the EMA shadow state.""" + with torch.no_grad(): + return { + name: param.detach().clone() + for name, param in self._named_model_parameters(model) + } + + def update(self, model: torch.nn.Module | dict[str, torch.nn.Module]) -> None: + """Update EMA shadow parameters from the current model parameters.""" + with torch.no_grad(): + for name, param in self._named_model_parameters(model): + self.shadow_params[name].lerp_(param.detach(), weight=1.0 - self.decay) + + def state_dict(self) -> dict[str, Any]: + """Serialize EMA state for restart.""" + return { + EMA_DECAY_KEY: self.decay, + EMA_MODEL_STATE_KEY: { + name: tensor.detach().cpu().clone() + for name, tensor in self.shadow_params.items() + }, + EMA_VALIDATION_STATE_KEY: deepcopy(self.validation_state), + } + + def load_state_dict(self, state: dict[str, Any]) -> None: + """Restore EMA shadow parameters and validator state.""" + if EMA_DECAY_KEY in state: + 
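+            # Prefer the decay recorded in the checkpoint; a restart then continues the original averaging trajectory.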
checkpoint_decay = float(state[EMA_DECAY_KEY]) + if checkpoint_decay != self.decay: + log.warning( + "Overriding training.ema_decay=%s with EMA checkpoint decay=%s.", + self.decay, + checkpoint_decay, + ) + self.decay = checkpoint_decay + model_state = state.get(EMA_MODEL_STATE_KEY, {}) + if not isinstance(model_state, dict): + raise TypeError("EMA checkpoint field `model` must be a dict.") + + current_keys = set(self.shadow_params) + loaded_keys = set(model_state) + missing_keys = sorted(current_keys - loaded_keys) + unexpected_keys = sorted(loaded_keys - current_keys) + if missing_keys or unexpected_keys: + raise KeyError( + "EMA checkpoint parameter keys do not match the current model. " + f"Missing keys: {missing_keys[:5]}, unexpected keys: {unexpected_keys[:5]}." + ) + + with torch.no_grad(): + for name, shadow_param in self.shadow_params.items(): + loaded_param = model_state[name] + if not isinstance(loaded_param, torch.Tensor): + raise TypeError( + f"EMA checkpoint tensor for {name!r} must be a torch.Tensor." + ) + if loaded_param.shape != shadow_param.shape: + raise ValueError( + "EMA checkpoint parameter shape does not match the current " + f"model for {name!r}: expected {tuple(shadow_param.shape)}, " + f"got {tuple(loaded_param.shape)}." + ) + shadow_param.copy_( + loaded_param.to( + device=shadow_param.device, + dtype=shadow_param.dtype, + ) + ) + + validation_state = state.get(EMA_VALIDATION_STATE_KEY, {}) + if validation_state is None: + validation_state = {} + if not isinstance(validation_state, dict): + raise TypeError("EMA checkpoint field `validation_state` must be a dict.") + self.validation_state = deepcopy(validation_state) + + @contextmanager + def apply_shadow( + self, + model: torch.nn.Module | dict[str, torch.nn.Module], + ) -> Iterator[None]: + """Temporarily replace model parameters with the EMA shadow state.""" + backups: dict[str, torch.Tensor] = {} + with torch.no_grad(): + for name, param in self._named_model_parameters(model): + backups[name] = param.detach().clone() + param.copy_( + self.shadow_params[name].to( + device=param.device, + dtype=param.dtype, + ) + ) + try: + yield + finally: + with torch.no_grad(): + for name, param in self._named_model_parameters(model): + param.copy_(backups[name]) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 91a4705207..d902ac34bb 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -9,6 +9,9 @@ Generator, Iterable, ) +from contextlib import ( + nullcontext, +) from copy import ( deepcopy, ) @@ -54,6 +57,12 @@ KFOptimizerWrapper, LKFOptimizer, ) +from deepmd.pt.train.ema import ( + EMA_CHECKPOINT_KEY, + ModelEMA, + get_ema_checkpoint_prefix, + get_ema_validation_log_path, +) from deepmd.pt.train.validation import ( FullValidator, resolve_full_validation_start_step, @@ -180,6 +189,10 @@ def __init__( self.save_ckpt = training_params.get("save_ckpt", "model.ckpt") self.save_freq = training_params.get("save_freq", 1000) self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5) + self.enable_ema = bool(training_params.get("enable_ema", False)) + self.ema_decay = float(training_params.get("ema_decay", 0.999)) + self.ema_ckpt_keep = int(training_params.get("ema_ckpt_keep", 3)) + self.ema_save_ckpt = get_ema_checkpoint_prefix(self.save_ckpt) self.display_in_training = training_params.get("disp_training", True) self.timing_in_training = training_params.get("time_training", True) self.change_bias_after_training = training_params.get( @@ -190,6 +203,10 @@ def __init__( raise 
ValueError( f"training.zero_stage must be 0, 1, 2, or 3, got {self.zero_stage}" ) + if self.enable_ema and self.zero_stage >= 2: + raise ValueError( + "training.enable_ema currently only supports training.zero_stage < 2." + ) if self.zero_stage > 0 and not self.is_distributed: self.zero_stage = 0 if self.zero_stage > 0 and self.change_bias_after_training: @@ -654,12 +671,18 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: # resuming and finetune optimizer_state_dict = None + ema_state_dict = None if resuming: log.info(f"Resuming from {resume_model}.") state_dict = torch.load( resume_model, map_location=DEVICE, weights_only=True ) if "model" in state_dict: + ema_state_dict = ( + state_dict.get(EMA_CHECKPOINT_KEY) + if finetune_model is None and self.restart_training + else None + ) optimizer_state_dict = ( state_dict["optimizer"] if finetune_model is None else None ) @@ -949,6 +972,14 @@ def single_model_finetune( last_epoch=self.start_step - 1, ) + self.model_ema = None + if self.enable_ema: + self.model_ema = ModelEMA( + self.model, + decay=self.ema_decay, + state=ema_state_dict, + ) + if self.zero_stage > 0 and self.rank == 0: if self.zero_stage == 1: log.info("Enabled DDP + ZeRO Stage-1 Optimizer State Sharding.") @@ -967,11 +998,18 @@ def single_model_finetune( self.enable_profiler = training_params.get("enable_profiler", False) self.profiling = training_params.get("profiling", False) self.profiling_file = training_params.get("profiling_file", "timeline.json") + self.full_validator = None + self.ema_full_validator = None + validating_params = config.get("validating") or {} self.full_validator = self._create_full_validator( validating_params=validating_params, validation_data=validation_data, ) + self.ema_full_validator = self._create_ema_full_validator( + validating_params=validating_params, + validation_data=validation_data, + ) # Log model parameter count if self.rank == 0: @@ -984,7 +1022,7 @@ def _create_full_validator( validation_data: DpLoaderSet | None, ) -> FullValidator | None: """Create the runtime full validator when it is active.""" - if not self._is_full_validation_requested(validating_params): + if not self._is_validation_requested(validating_params, "full_validation"): return None self._raise_if_full_validation_unsupported(validation_data) if validation_data is None: @@ -995,7 +1033,7 @@ def _create_full_validator( validating_params=validating_params, validation_data=validation_data, model=self.model, - train_infos=self._get_inner_module().train_infos, + state_store=self._get_inner_module().train_infos, num_steps=self.num_steps, rank=self.rank, zero_stage=self.zero_stage, @@ -1003,9 +1041,53 @@ def _create_full_validator( checkpoint_dir=Path(self.save_ckpt).parent, ) - def _is_full_validation_requested(self, validating_params: dict[str, Any]) -> bool: - """Check whether full validation can trigger during this training run.""" - if not validating_params.get("full_validation", False): + def _create_ema_full_validator( + self, + *, + validating_params: dict[str, Any], + validation_data: DpLoaderSet | None, + ) -> FullValidator | None: + """Create the runtime EMA full validator when it is active.""" + if not self._is_validation_requested( + validating_params, "full_validation" + ) or not validating_params.get("ema_full_validation", False): + return None + self._raise_if_full_validation_unsupported(validation_data) + if self.model_ema is None: + raise ValueError( + "validating.ema_full_validation requires `training.enable_ema=true`." 
+ ) + if validation_data is None: + raise RuntimeError( + "validation_data must be available after EMA full validation checks." + ) + ema_validating_params = dict(validating_params) + ema_validating_params["full_validation"] = True + return FullValidator( + validating_params=ema_validating_params, + validation_data=validation_data, + model=self.model, + state_store=self.model_ema.validation_state, + num_steps=self.num_steps, + rank=self.rank, + zero_stage=self.zero_stage, + restart_training=self.restart_training, + checkpoint_dir=Path(self.save_ckpt).parent, + full_val_file=get_ema_validation_log_path( + validating_params.get("full_val_file", "val.log") + ), + best_checkpoint_prefix="best_ema.ckpt", + emit_best_save_log=False, + model_eval_context=lambda: self.model_ema.apply_shadow(self.model), + ) + + def _is_validation_requested( + self, + validating_params: dict[str, Any], + flag_name: str, + ) -> bool: + """Check whether a full validation flow can trigger during this run.""" + if not validating_params.get(flag_name, False): return False start_step = resolve_full_validation_start_step( validating_params.get("full_val_start", 0.5), @@ -1319,6 +1401,9 @@ def fake_model() -> dict: else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") + if self.model_ema is not None: + self.model_ema.update(self.model) + if self.disp_avg: # Accumulate loss for averaging over display interval self.step_count_in_interval += 1 @@ -1558,6 +1643,13 @@ def log_loss_valid(_task_key: str = "Default") -> dict: lr=cur_lr, save_checkpoint=self.save_model, ) + if self.ema_full_validator is not None: + self.ema_full_validator.run( + step_id=_step_id, + display_step=display_step_id, + lr=cur_lr, + save_checkpoint=self.save_ema_model, + ) if ( ( @@ -1574,6 +1666,16 @@ def log_loss_valid(_task_key: str = "Default") -> dict: symlink_prefix_files(self.latest_model.stem, self.save_ckpt) with open("checkpoint", "w") as f: f.write(str(self.latest_model)) + if self.model_ema is not None: + self.latest_ema_model = Path( + self.ema_save_ckpt + f"-{display_step_id}.pt" + ) + self.save_ema_model(self.latest_ema_model, lr=cur_lr, step=_step_id) + if self.rank == 0 or dist.get_rank() == 0: + symlink_prefix_files( + self.latest_ema_model.stem, + self.ema_save_ckpt, + ) # tensorboard if self.enable_tensorboard and ( @@ -1653,11 +1755,24 @@ def log_loss_valid(_task_key: str = "Default") -> dict: symlink_prefix_files(self.latest_model.stem, self.save_ckpt) with open("checkpoint", "w") as f: f.write(str(self.latest_model)) + if self.model_ema is not None: + self.latest_ema_model = Path( + self.ema_save_ckpt + f"-{self.num_steps}.pt" + ) + self.save_ema_model( + self.latest_ema_model, + lr=cur_lr, + step=self.num_steps - 1, + ) + symlink_prefix_files(self.latest_ema_model.stem, self.ema_save_ckpt) if self.num_steps == 0 and self.zero_stage > 0: # ZeRO-1 / FSDP: all ranks participate in save_model (collective op) self.latest_model = Path(self.save_ckpt + "-0.pt") self.save_model(self.latest_model, lr=0, step=0) + if self.model_ema is not None: + self.latest_ema_model = Path(self.ema_save_ckpt + "-0.pt") + self.save_ema_model(self.latest_ema_model, lr=0, step=0) if ( self.rank == 0 or dist.get_rank() == 0 @@ -1667,10 +1782,15 @@ def log_loss_valid(_task_key: str = "Default") -> dict: # When num_steps is 0, the checkpoint is never saved in the loop self.latest_model = Path(self.save_ckpt + "-0.pt") self.save_model(self.latest_model, lr=0, step=0) + if self.model_ema is not None: + self.latest_ema_model = 
Path(self.ema_save_ckpt + "-0.pt") + self.save_ema_model(self.latest_ema_model, lr=0, step=0) log.info(f"Saved model to {self.latest_model}") symlink_prefix_files(self.latest_model.stem, self.save_ckpt) with open("checkpoint", "w") as f: f.write(str(self.latest_model)) + if self.model_ema is not None: + symlink_prefix_files(self.latest_ema_model.stem, self.ema_save_ckpt) if self.timing_in_training and self.timed_steps: msg = f"average training time: {self.total_train_time / self.timed_steps:.4f} s/batch" @@ -1707,48 +1827,172 @@ def log_loss_valid(_task_key: str = "Default") -> dict: f"The profiling trace has been saved to: {self.profiling_file}" ) - def save_model(self, save_path: str, lr: float = 0.0, step: int = 0) -> None: + def _collect_checkpoint_states( + self, + *, + use_ema_weights: bool = False, + include_optimizer: bool = True, + ) -> tuple[dict[str, Any], dict[str, Any] | None]: + """Collect model and optimizer states for checkpointing. + + Parameters + ---------- + use_ema_weights : bool + If True, temporarily swap in EMA shadow weights before collecting + the model state dict. + include_optimizer : bool + If False, skip collecting the optimizer state. Used for EMA ckpts + where the optimizer state is meaningless: EMA has no optimizer of + its own, and the main model's optimizer state corresponds to a + different parameter trajectory. + + Returns + ------- + tuple[dict[str, Any], dict[str, Any] | None] + (model_state, optim_state). optim_state is None when + include_optimizer is False. + """ module = self._get_inner_module() - module.train_infos["lr"] = float(lr) - module.train_infos["step"] = step + ema_context = ( + self.model_ema.apply_shadow(self.model) + if use_ema_weights and self.model_ema is not None + else nullcontext() + ) + with ema_context: + if self.zero_stage >= 2: + # FSDP2: collective op, all ranks participate; rank 0 gets full state + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + model_state = get_model_state_dict(self.wrapper, options=options) + optim_state = ( + get_optimizer_state_dict( + self.wrapper, self.optimizer, options=options + ) + if include_optimizer + else None + ) + elif self.zero_stage == 1: + # ZeRO-1: consolidate sharded optimizer state to rank 0. + model_state = module.state_dict() + if use_ema_weights: + # state_dict() tensors share storage with live parameters; clone + # them before the EMA context restores the original weights. + model_state = deepcopy(model_state) + if include_optimizer: + self.optimizer.consolidate_state_dict(to=0) + optim_state = self.optimizer.state_dict() if self.rank == 0 else {} + else: + optim_state = None + else: + model_state = module.state_dict() + if use_ema_weights: + # Same storage-sharing issue as zero_stage == 1. 
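+                    # Without this copy, torch.save would later serialize the restored live weights instead of the EMA values.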
+                    model_state = deepcopy(model_state)
+                optim_state = self.optimizer.state_dict() if include_optimizer else None
+        return model_state, optim_state
 
-        # === Collect state dicts ===
-        if self.zero_stage >= 2:
-            # FSDP2: collective op, all ranks participate; rank 0 gets full state
-            options = StateDictOptions(full_state_dict=True, cpu_offload=True)
-            model_state = get_model_state_dict(self.wrapper, options=options)
-            optim_state = get_optimizer_state_dict(
-                self.wrapper, self.optimizer, options=options
-            )
-        elif self.zero_stage == 1:
-            # ZeRO-1: consolidate sharded optimizer state to rank 0
-            model_state = module.state_dict()
-            self.optimizer.consolidate_state_dict(to=0)
-            optim_state = (
-                deepcopy(self.optimizer.state_dict()) if self.rank == 0 else {}
-            )
-        else:
-            model_state = module.state_dict()
-            optim_state = deepcopy(self.optimizer.state_dict())
+    @staticmethod
+    def _parse_checkpoint_step(path: Path, prefix_name: str) -> int | None:
+        """Parse the checkpoint step from ``<prefix>-<step>.pt`` filenames."""
+        checkpoint_prefix = f"{prefix_name}-"
+        if path.suffix != ".pt" or not path.name.startswith(checkpoint_prefix):
+            return None
+        step_text = path.name[len(checkpoint_prefix) : -len(path.suffix)]
+        if not step_text.isdigit():
+            return None
+        return int(step_text)
+
+    def _write_checkpoint(
+        self,
+        save_path: Path,
+        checkpoint_data: dict[str, Any],
+        *,
+        ckpt_prefix: str,
+        max_ckpt_keep: int,
+    ) -> None:
+        """Write a checkpoint file and apply prefix-based cleanup."""
+        prefix_name = Path(ckpt_prefix).name
 
         # === Only rank 0 writes to disk ===
         if self.rank != 0:
             return
-        for item in optim_state["param_groups"]:
-            item["lr"] = float(item["lr"])
-        torch.save(
-            {"model": model_state, "optimizer": optim_state},
+        optim_state = checkpoint_data.get("optimizer")
+        if optim_state is not None:
+            for item in optim_state["param_groups"]:
+                item["lr"] = float(item["lr"])
+        torch.save(checkpoint_data, save_path)
+        checkpoint_dir = save_path.parent
+        checkpoint_files = []
+        for checkpoint_file in checkpoint_dir.glob("*.pt"):
+            step = self._parse_checkpoint_step(checkpoint_file, prefix_name)
+            if checkpoint_file.is_symlink() or step is None:
+                continue
+            checkpoint_files.append((checkpoint_file, step))
+
+        current_step = self._parse_checkpoint_step(save_path, prefix_name)
+        if current_step is not None:
+            fresh_checkpoint_files = []
+            for checkpoint_file, step in checkpoint_files:
+                if step > current_step:
+                    checkpoint_file.unlink()
+                else:
+                    fresh_checkpoint_files.append((checkpoint_file, step))
+            checkpoint_files = fresh_checkpoint_files
+
+        checkpoint_files.sort(key=lambda item: (item[1], item[0].name))
+        while len(checkpoint_files) > max_ckpt_keep:
+            checkpoint_files.pop(0)[0].unlink()
+
+    def save_model(
+        self,
+        save_path: str | Path,
+        lr: float = 0.0,
+        step: int = 0,
+        *,
+        ckpt_prefix: str | None = None,
+        max_ckpt_keep: int | None = None,
+        use_ema_weights: bool = False,
+        include_ema_state: bool = True,
+        include_optimizer: bool = True,
+    ) -> None:
+        module = self._get_inner_module()
+        module.train_infos["lr"] = float(lr)
+        module.train_infos["step"] = step
+        model_state, optim_state = self._collect_checkpoint_states(
+            use_ema_weights=use_ema_weights,
+            include_optimizer=include_optimizer,
+        )
+        checkpoint_data: dict[str, Any] = {"model": model_state}
+        if optim_state is not None:
+            checkpoint_data["optimizer"] = optim_state
+        if include_ema_state and self.model_ema is not None and self.rank == 0:
+            checkpoint_data[EMA_CHECKPOINT_KEY] = self.model_ema.state_dict()
+        self._write_checkpoint(
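+            # save_ema_model passes ckpt_prefix and max_ckpt_keep overrides for the EMA family; the defaults below are the regular checkpoint settings.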
Path(save_path), + checkpoint_data, + ckpt_prefix=self.save_ckpt if ckpt_prefix is None else ckpt_prefix, + max_ckpt_keep=( + self.max_ckpt_keep if max_ckpt_keep is None else max_ckpt_keep + ), + ) + + def save_ema_model( + self, save_path: str | Path, lr: float = 0.0, step: int = 0 + ) -> None: + """Save an EMA-weight checkpoint using the regular checkpoint format.""" + if self.model_ema is None: + raise ValueError( + "EMA checkpoint saving requires `training.enable_ema=true`." + ) + self.save_model( save_path, + lr=lr, + step=step, + ckpt_prefix=self.ema_save_ckpt, + max_ckpt_keep=self.ema_ckpt_keep, + use_ema_weights=True, + include_ema_state=False, + include_optimizer=False, ) - checkpoint_dir = save_path.parent - checkpoint_files = [ - f - for f in checkpoint_dir.glob("*.pt") - if not f.is_symlink() and f.name.startswith(self.save_ckpt) - ] - if len(checkpoint_files) > self.max_ckpt_keep: - checkpoint_files.sort(key=lambda x: x.stat().st_mtime) - checkpoint_files[0].unlink() def get_data( self, is_train: bool = True, task_key: str = "Default" diff --git a/deepmd/pt/train/validation.py b/deepmd/pt/train/validation.py index fa3d2b364c..ace3fb244d 100644 --- a/deepmd/pt/train/validation.py +++ b/deepmd/pt/train/validation.py @@ -8,6 +8,9 @@ import logging import re import traceback +from contextlib import ( + nullcontext, +) from dataclasses import ( dataclass, ) @@ -63,6 +66,7 @@ if TYPE_CHECKING: from collections.abc import ( + Callable, Iterator, ) @@ -81,14 +85,13 @@ TOPK_RECORDS_INFO_KEY = "full_validation_topk_records" BEST_METRIC_NAME_INFO_KEY = "full_validation_metric" -BEST_CKPT_GLOB = "best.ckpt-*.t-*.pt" -BEST_CKPT_PATTERN = re.compile(r"^best\.ckpt-(\d+)\.t-(\d+)\.pt$") STALE_FULL_VALIDATION_INFO_KEYS = ( "full_validation_best_metric", "full_validation_best_step", "full_validation_best_path", "full_validation_best_records", ) +BEST_CKPT_PREFIX = "best.ckpt" VAL_LOG_SIGNIFICANT_DIGITS = 5 VAL_LOG_COLUMN_GAP = " " VAL_LOG_HEADER_PREFIX = "# " @@ -119,6 +122,16 @@ class BestCheckpointRecord: step: int +def build_best_checkpoint_glob(best_checkpoint_prefix: str) -> str: + """Build the glob pattern for managed best checkpoints.""" + return f"{best_checkpoint_prefix}-*.t-*.pt" + + +def build_best_checkpoint_pattern(best_checkpoint_prefix: str) -> re.Pattern[str]: + """Build the regex pattern for managed best checkpoints.""" + return re.compile(rf"^{re.escape(best_checkpoint_prefix)}-(\d+)\.t-(\d+)\.pt$") + + def parse_validation_metric(metric: str) -> tuple[str, str]: """Parse the configured full validation metric.""" normalized_metric = normalize_full_validation_metric(metric) @@ -187,22 +200,39 @@ def __init__( validating_params: dict[str, Any], validation_data: Any, model: torch.nn.Module, - train_infos: dict[str, Any], + state_store: dict[str, Any], num_steps: int, rank: int, zero_stage: int, restart_training: bool, checkpoint_dir: Path | None = None, + full_val_file: str | Path | None = None, + best_checkpoint_prefix: str = BEST_CKPT_PREFIX, + metric_name_info_key: str = BEST_METRIC_NAME_INFO_KEY, + topk_records_info_key: str = TOPK_RECORDS_INFO_KEY, + stale_state_keys: tuple[str, ...] 
= STALE_FULL_VALIDATION_INFO_KEYS, + emit_best_save_log: bool = True, + model_eval_context: Callable[[], Any] | None = None, ) -> None: self.validation_data = validation_data self.model = model - self.train_infos = train_infos + self.state_store = state_store self.rank = rank self.zero_stage = zero_stage self.checkpoint_dir = ( Path(checkpoint_dir) if checkpoint_dir is not None else Path(".") ) self.is_distributed = dist.is_available() and dist.is_initialized() + self.metric_name_info_key = metric_name_info_key + self.topk_records_info_key = topk_records_info_key + self.stale_state_keys = stale_state_keys + self.best_checkpoint_prefix = best_checkpoint_prefix + self.best_checkpoint_glob = build_best_checkpoint_glob(best_checkpoint_prefix) + self.best_checkpoint_pattern = build_best_checkpoint_pattern( + best_checkpoint_prefix + ) + self.emit_best_save_log = emit_best_save_log + self.model_eval_context = model_eval_context or nullcontext self.full_validation = bool(validating_params.get("full_validation", False)) self.validation_freq = int(validating_params.get("validation_freq", 5000)) @@ -211,7 +241,12 @@ def __init__( self.metric_name, self.metric_key = parse_validation_metric( str(validating_params.get("validation_metric", "E:MAE")) ) - self.full_val_file = Path(validating_params.get("full_val_file", "val.log")) + resolved_log_file = ( + full_val_file + if full_val_file is not None + else validating_params.get("full_val_file", "val.log") + ) + self.full_val_file = Path(resolved_log_file) self.start_step = resolve_full_validation_start_step( validating_params.get("full_val_start", 0.5), num_steps, @@ -236,7 +271,7 @@ def __init__( ) self.topk_records = self._load_topk_records() - self._sync_train_infos() + self._sync_state_store() if self.rank == 0: self._initialize_best_checkpoints(restart_training=restart_training) @@ -335,8 +370,9 @@ def _evaluate(self, display_step: int) -> FullValidationResult: was_training = bool(getattr(self.model, "training", True)) self.model.eval() try: - # === Step 2. Evaluate All Systems === - metrics = self.evaluate_all_systems() + with self.model_eval_context(): + # === Step 2. 
Evaluate All Systems === + metrics = self.evaluate_all_systems() finally: self.model.train(was_training) @@ -584,27 +620,27 @@ def _update_best_state( return None self.topk_records = updated_records - self._sync_train_infos() + self._sync_state_store() if not self.save_best: return None candidate_rank = self.topk_records.index(candidate) + 1 return str(self._best_checkpoint_path(display_step, candidate_rank)) - def _sync_train_infos(self) -> None: - """Synchronize top-K validation state into train infos.""" - for key in STALE_FULL_VALIDATION_INFO_KEYS: - self.train_infos.pop(key, None) - self.train_infos[BEST_METRIC_NAME_INFO_KEY] = self.metric_name - self.train_infos[TOPK_RECORDS_INFO_KEY] = [ + def _sync_state_store(self) -> None: + """Synchronize top-K validation state into the configured state store.""" + for key in self.stale_state_keys: + self.state_store.pop(key, None) + self.state_store[self.metric_name_info_key] = self.metric_name + self.state_store[self.topk_records_info_key] = [ {"metric": record.metric, "step": record.step} for record in self.topk_records ] def _load_topk_records(self) -> list[BestCheckpointRecord]: - """Load top-K records from train infos for the current metric.""" - if self.train_infos.get(BEST_METRIC_NAME_INFO_KEY) != self.metric_name: + """Load top-K records from the configured state store.""" + if self.state_store.get(self.metric_name_info_key) != self.metric_name: return [] - raw_records = self.train_infos.get(TOPK_RECORDS_INFO_KEY, []) + raw_records = self.state_store.get(self.topk_records_info_key, []) if not isinstance(raw_records, list): return [] records = [] @@ -624,7 +660,7 @@ def _load_topk_records(self) -> list[BestCheckpointRecord]: def _best_checkpoint_name(self, step: int, rank: int) -> str: """Build the best-checkpoint filename for one step.""" - return f"best.ckpt-{step}.t-{rank}.pt" + return f"{self.best_checkpoint_prefix}-{step}.t-{rank}.pt" def _best_checkpoint_path(self, step: int, rank: int) -> Path: """Build the best-checkpoint path for one step.""" @@ -634,7 +670,7 @@ def _list_best_checkpoints(self) -> list[Path]: """List all managed best checkpoints in the checkpoint directory.""" best_checkpoints = [ path - for path in self.checkpoint_dir.glob(BEST_CKPT_GLOB) + for path in self.checkpoint_dir.glob(self.best_checkpoint_glob) if path.is_file() and not path.is_symlink() ] best_checkpoints.sort(key=lambda path: path.stat().st_mtime) @@ -654,7 +690,7 @@ def _reconcile_best_checkpoints(self) -> None: files_by_step: dict[int, list[Path]] = {} stale_files: list[Path] = [] for checkpoint_path in current_files: - match = BEST_CKPT_PATTERN.match(checkpoint_path.name) + match = self.best_checkpoint_pattern.match(checkpoint_path.name) if match is None: stale_files.append(checkpoint_path) continue @@ -722,7 +758,7 @@ def _log_result(self, result: FullValidationResult | None) -> None: if result is None: raise ValueError("Full validation logging requires a result on rank 0.") self._write_log_file(result) - if result.saved_best_path is not None: + if self.emit_best_save_log and result.saved_best_path is not None: metric_label, metric_value, metric_unit = format_metric_for_log( self.metric_name, result.selected_metric_value ) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index e1795afe16..e33f61b039 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3946,6 +3946,21 @@ def training_args( "The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. " "Defaults to 5." 
) + doc_enable_ema = ( + "Whether to maintain an exponential moving average (EMA) of model " + "parameters during training and save periodic EMA checkpoints with an " + "`_ema` suffix in the checkpoint prefix." + ) + doc_ema_decay = ( + "The decay factor used for the exponential moving average of model " + "parameters. The EMA update is " + "`ema = ema_decay * ema + (1 - ema_decay) * param`." + ) + doc_ema_ckpt_keep = ( + "The maximum number of periodic EMA checkpoints to keep. " + "EMA checkpoints use the same prefix-based cleanup rule as regular " + "training checkpoints, but with an EMA-specific checkpoint prefix." + ) doc_change_bias_after_training = ( "Whether to change the output bias after the last training step, " "by performing predictions using trained model on training data and " @@ -4059,6 +4074,31 @@ def training_args( "save_ckpt", str, optional=True, default="model.ckpt", doc=doc_save_ckpt ), Argument("max_ckpt_keep", int, optional=True, default=5, doc=doc_max_ckpt_keep), + Argument( + "enable_ema", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_enable_ema, + ), + Argument( + "ema_decay", + float, + optional=True, + default=0.999, + doc=doc_only_pt_supported + doc_ema_decay, + extra_check=lambda x: 0.0 <= x < 1.0, + extra_check_errmsg="must be greater than or equal to 0 and less than 1", + ), + Argument( + "ema_ckpt_keep", + int, + optional=True, + default=3, + doc=doc_only_pt_supported + doc_ema_ckpt_keep, + extra_check=lambda x: x > 0, + extra_check_errmsg="must be greater than 0", + ), Argument( "change_bias_after_training", bool, @@ -4153,6 +4193,11 @@ def training_extra_check(data: dict | None) -> bool: num_epoch = data.get("numb_epoch") num_epoch_dict = data.get("num_epoch_dict", {}) model_prob = data.get("model_prob", {}) + zero_stage = int(data.get("zero_stage", 0)) + if data.get("enable_ema", False) and zero_stage >= 2: + raise ValueError( + "training.enable_ema currently only supports training.zero_stage < 2." + ) if multi_task: if num_epoch is not None: raise ValueError( @@ -4254,6 +4299,15 @@ def validating_args() -> Argument: "Whether to save an extra checkpoint when the selected full validation " "metric reaches a new best value." ) + doc_ema_full_validation = ( + "Whether to additionally run the same full validation flow on the " + "EMA-smoothed model when `validating.full_validation=true`. This reuses " + "the existing full validation schedule, metric, start step, and " + "best-checkpoint settings, writes results to an EMA-specific validation " + "log such as `val_ema.log`, and saves EMA best checkpoints with a " + "`best_ema.ckpt` prefix. Requires " + "`training.enable_ema=true`." + ) doc_max_best_ckpt = ( "The maximum number of top-ranked best checkpoints to keep. 
The best " "checkpoints are ranked by the selected validation metric in ascending " @@ -4284,6 +4338,13 @@ def validating_args() -> Argument: default=False, doc=doc_only_pt_supported + doc_full_validation, ), + Argument( + "ema_full_validation", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_ema_full_validation, + ), Argument( "validation_freq", int, @@ -4355,9 +4416,17 @@ def validate_full_validation_config( ) -> None: """Validate cross-section constraints for full validation.""" validating = data.get("validating") or {} - training = data.get("training", {}) - if not validating.get("full_validation", False): + training_params = data.get("training", {}) or {} + full_validation_enabled = bool(validating.get("full_validation", False)) + ema_full_validation_enabled = bool(validating.get("ema_full_validation", False)) + if not full_validation_enabled: return + if float(validating.get("full_val_start", 0.0)) == 1.0: + return + if ema_full_validation_enabled and not training_params.get("enable_ema", False): + raise ValueError( + "validating.ema_full_validation requires `training.enable_ema=true`." + ) metric = str(validating.get("validation_metric", "E:MAE")) if not is_valid_full_validation_metric(metric): @@ -4386,13 +4455,13 @@ def validate_full_validation_config( f"training with loss.type='ener'; got loss.type={loss_type!r}." ) - if not training.get("validation_data"): + if not training_params.get("validation_data"): raise ValueError( - "validating.full_validation requires `training.validation_data`. " - "It is only supported for single-task energy training." + "full validation requires `training.validation_data`. It is only " + "supported for single-task energy training." ) - zero_stage = int(training.get("zero_stage", 0)) + zero_stage = int(training_params.get("zero_stage", 0)) if zero_stage >= 2: raise ValueError( "validating.full_validation only supports single-task energy " diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 9e840dd9f2..eeb53265fb 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -1,15 +1,25 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import functools import json import os import shutil +import signal import tempfile import unittest +from collections.abc import ( + Callable, +) from copy import ( deepcopy, ) from pathlib import ( Path, ) +from typing import ( + Any, + TypeVar, + cast, +) from unittest.mock import ( patch, ) @@ -21,6 +31,9 @@ get_trainer, ) from deepmd.pt.entrypoints.main import train as train_entry +from deepmd.pt.train.ema import ( + EMA_CHECKPOINT_KEY, +) from deepmd.pt.utils.finetune import ( get_finetune_rules, ) @@ -44,6 +57,36 @@ model_zbl, ) +_F = TypeVar("_F", bound=Callable[..., Any]) + + +def _training_timeout(seconds: int) -> Callable[[_F], _F]: + """Limit real training tests on platforms that support SIGALRM.""" + + def decorate(func: _F) -> _F: + if not hasattr(signal, "SIGALRM"): + return func + + @functools.wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Any: + def raise_timeout(signum: int, frame: Any) -> None: + raise TimeoutError(f"training test exceeded {seconds} seconds") + + previous_handler = signal.signal(signal.SIGALRM, raise_timeout) + signal.alarm(seconds) + try: + return func(*args, **kwargs) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, previous_handler) + + return cast("_F", wrapped) + + return decorate + + +TRAINING_TEST_TIMEOUT = _training_timeout(60) + class DPTrainTest: test_zbl_from_standard: 
bool = False @@ -664,11 +707,9 @@ class TestModelChangeOutBiasFittingStat(unittest.TestCase): """ def test_fitting_stat_consistency(self) -> None: + import deepmd.pt.train.training as training_module from deepmd.pt.model.model import get_model as get_model_pt from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT - from deepmd.pt.train.training import ( - model_change_out_bias, - ) from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch from deepmd.utils.argcheck import model_args as model_args_fn @@ -742,7 +783,7 @@ def test_fitting_stat_consistency(self) -> None: # Model B: use the NEW code path via model_change_out_bias sample_func = lambda: merged # noqa: E731 - model_change_out_bias(model_b, sample_func, "set-by-statistic") + training_module.model_change_out_bias(model_b, sample_func, "set-by-statistic") # Compare out_bias bias_a = torch_to_numpy(model_a.get_out_bias()) @@ -906,5 +947,203 @@ def test_full_validation_rejects_multitask(self) -> None: normalize(config, multi_task=True) +class TestEMATraining(unittest.TestCase): + def setUp(self) -> None: + import deepmd.pt.train.training as training_module + import deepmd.pt.utils.env as env_module + + self._num_workers_state = ( + (env_module, env_module.NUM_WORKERS), + (training_module, training_module.NUM_WORKERS), + ) + env_module.NUM_WORKERS = 0 + training_module.NUM_WORKERS = 0 + self._cwd = os.getcwd() + self._tmpdir = tempfile.TemporaryDirectory() + os.chdir(self._tmpdir.name) + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config = convert_optimizer_v31_to_v32(self.config, warning=False) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.config["training"]["numb_steps"] = 4 + self.config["training"]["save_freq"] = 1 + self.config["training"]["max_ckpt_keep"] = 3 + self.config["training"]["disp_training"] = False + self.config["training"]["enable_ema"] = True + self.config["training"]["ema_decay"] = 0.9 + self.config["training"]["ema_ckpt_keep"] = 2 + self.config["validating"] = { + "full_validation": False, + "ema_full_validation": False, + "validation_freq": 1, + "save_best": True, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + } + + def tearDown(self) -> None: + for module, num_workers in self._num_workers_state: + module.NUM_WORKERS = num_workers + os.chdir(self._cwd) + self._tmpdir.cleanup() + + @TRAINING_TEST_TIMEOUT + def test_ema_checkpoint_rotation(self) -> None: + trainer = get_trainer(deepcopy(self.config)) + ema_prefix = trainer.ema_save_ckpt + Path(f"{ema_prefix}-999.pt").touch() + trainer.run() + + self.assertFalse(Path(f"{ema_prefix}-999.pt").exists()) + self.assertTrue(Path(f"{ema_prefix}-3.pt").exists()) + self.assertTrue(Path(f"{ema_prefix}-4.pt").exists()) + self.assertFalse(Path(f"{ema_prefix}-2.pt").exists()) + self.assertFalse(Path("checkpoint_ema").exists()) + + def test_ema_checkpoint_cleanup_removes_future_steps(self) -> None: + trainer = get_trainer(deepcopy(self.config)) + trainer.ema_ckpt_keep = 10 + ema_prefix = trainer.ema_save_ckpt + Path(f"{ema_prefix}-999.pt").touch() + + trainer.save_ema_model(f"{ema_prefix}-1.pt", lr=0.0, step=0) + + 
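+        # The stale future-step file is pruned while the checkpoint just written survives.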
self.assertFalse(Path(f"{ema_prefix}-999.pt").exists()) + self.assertTrue(Path(f"{ema_prefix}-1.pt").exists()) + + @TRAINING_TEST_TIMEOUT + @patch("deepmd.pt.train.training.model_change_out_bias") + def test_ema_checkpoint_keeps_changed_out_bias( + self, mocked_change_out_bias + ) -> None: + def change_out_bias(model, sample_func, _bias_adjust_mode): + model.set_out_bias(model.get_out_bias() + 1.0) + return model + + mocked_change_out_bias.side_effect = change_out_bias + config = deepcopy(self.config) + config["training"]["numb_steps"] = 1 + config["training"]["change_bias_after_training"] = True + trainer = get_trainer(config) + trainer.run() + + regular_checkpoint = torch.load( + trainer.latest_model, map_location="cpu", weights_only=True + ) + ema_checkpoint = torch.load( + trainer.latest_ema_model, map_location="cpu", weights_only=True + ) + regular_out_bias = { + key: value + for key, value in regular_checkpoint["model"].items() + if key.endswith("out_bias") + } + ema_out_bias = { + key: value + for key, value in ema_checkpoint["model"].items() + if key.endswith("out_bias") + } + + self.assertTrue(regular_out_bias) + self.assertEqual(regular_out_bias.keys(), ema_out_bias.keys()) + for key, regular_value in regular_out_bias.items(): + torch.testing.assert_close(regular_value, ema_out_bias[key]) + + def test_ema_rejects_zero_stage_2_during_normalization(self) -> None: + config = deepcopy(self.config) + config["training"]["zero_stage"] = 2 + config = update_deepmd_input(config, warning=False) + with self.assertRaisesRegex(ValueError, "training.zero_stage < 2"): + normalize(config) + + @TRAINING_TEST_TIMEOUT + def test_restart_restores_ema_state(self) -> None: + trainer = get_trainer(deepcopy(self.config)) + first_key = next(iter(trainer.model_ema.shadow_params)) + trainer.model_ema.shadow_params[first_key].add_(1.2345) + trainer.model_ema.validation_state.update( + { + "full_validation_metric": "e:mae", + "full_validation_topk_records": [ + {"metric": 0.5, "step": 3}, + ], + } + ) + trainer.save_model("model.ckpt-0.pt", lr=0.1, step=0) + + checkpoint = torch.load( + "model.ckpt-0.pt", map_location="cpu", weights_only=True + ) + self.assertIn(EMA_CHECKPOINT_KEY, checkpoint) + + restarted = get_trainer( + deepcopy(self.config), + restart_model="model.ckpt-0.pt", + ) + torch.testing.assert_close( + restarted.model_ema.shadow_params[first_key], + trainer.model_ema.shadow_params[first_key], + ) + self.assertEqual( + restarted.model_ema.validation_state, + trainer.model_ema.validation_state, + ) + + @TRAINING_TEST_TIMEOUT + @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems") + def test_ema_full_validation_writes_separate_outputs(self, mocked_eval) -> None: + mocked_eval.side_effect = [ + {"mae_e_per_atom": 10.0}, + {"mae_e_per_atom": 1.0}, + {"mae_e_per_atom": 10.0}, + {"mae_e_per_atom": 0.5}, + {"mae_e_per_atom": 10.0}, + {"mae_e_per_atom": 0.75}, + {"mae_e_per_atom": 10.0}, + {"mae_e_per_atom": 0.25}, + ] + config = deepcopy(self.config) + config["validating"]["full_validation"] = True + config["validating"]["ema_full_validation"] = True + trainer = get_trainer(config) + trainer.run() + + self.assertTrue(Path("val.log").exists()) + self.assertTrue(Path("val_ema.log").exists()) + self.assertTrue(Path("best_ema.ckpt-4.t-1.pt").exists()) + self.assertFalse(Path("best_ema.ckpt-1.t-1.pt").exists()) + train_infos = trainer._get_inner_module().train_infos + self.assertNotIn("full_validation_ema_metric", train_infos) + self.assertNotIn("full_validation_ema_topk_records", 
train_infos) + self.assertEqual( + trainer.model_ema.validation_state["full_validation_topk_records"], + [{"metric": 0.25, "step": 4}], + ) + + @TRAINING_TEST_TIMEOUT + @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems") + def test_ema_full_validation_ignored_without_full_validation( + self, mocked_eval + ) -> None: + config = deepcopy(self.config) + config["training"]["enable_ema"] = False + config["validating"]["full_validation"] = False + config["validating"]["ema_full_validation"] = True + trainer = get_trainer(config) + trainer.run() + + mocked_eval.assert_not_called() + self.assertFalse(Path("val.log").exists()) + self.assertFalse(Path("val_ema.log").exists()) + self.assertIsNone(trainer.model_ema) + self.assertIsNone(trainer.ema_full_validator) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_validation.py b/source/tests/pt/test_validation.py index eb4e32b065..2bb19dce4a 100644 --- a/source/tests/pt/test_validation.py +++ b/source/tests/pt/test_validation.py @@ -196,7 +196,7 @@ def test_full_validator_rotates_best_checkpoint(self) -> None: }, validation_data=_DummyValidationData(), model=_DummyModel(), - train_infos=train_infos, + state_store=train_infos, num_steps=10, rank=0, zero_stage=0, @@ -266,7 +266,7 @@ def test_full_validator_restores_top_k_checkpoints(self) -> None: }, validation_data=_DummyValidationData(), model=_DummyModel(), - train_infos=train_infos, + state_store=train_infos, num_steps=10, rank=0, zero_stage=0,
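A minimal sketch of driving the new `ModelEMA` helper directly, outside the `Trainer`; the toy linear model, SGD settings, and random data are illustrative assumptions, not part of this change:

import torch

from deepmd.pt.train.ema import ModelEMA

net = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
ema = ModelEMA(net, decay=0.9)

for _ in range(10):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(net)  # shadow = 0.9 * shadow + 0.1 * param

with ema.apply_shadow(net):
    # inside this block the live weights are swapped for the EMA shadow
    with torch.no_grad():
        averaged_pred = net(torch.randn(2, 4))
# the live weights are restored once the context exits

state = ema.state_dict()  # keys: "decay", "model", "validation_state"
restored = ModelEMA(net, decay=0.9, state=state)  # round-trips across a restart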