From e06a8055d09b3a8cd509bf4da8473b6a2c2ebabb Mon Sep 17 00:00:00 2001 From: Valentin De Bortoli Date: Mon, 8 Jun 2026 11:19:39 -0700 Subject: [PATCH] Add MNIST config and fix CheckpointedEvaluator bug. PiperOrigin-RevId: 928672610 --- README.md | 38 +++++ .../{imagenet64_unet.py => mnist_unet.py} | 133 ++++++++++-------- hackable_diffusion/kdiff/evals.py | 2 +- 3 files changed, 116 insertions(+), 57 deletions(-) rename hackable_diffusion/kdiff/configs/{imagenet64_unet.py => mnist_unet.py} (61%) diff --git a/README.md b/README.md index 480274c..c7a059d 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,44 @@ The `notebooks/` directory contains several tutorials to get you started: * **`mnist_discrete.ipynb`**: An example of discrete diffusion. * **`mnist_multimodal.ipynb`**: A showcase of the multimodal capabilities, generating images and labels jointly. +## Training configs + +The `kdiff/configs/` directory contains example configurations for training: + +* **`mnist_unet.py`**: Standard diffusion training configuration on MNIST. + +To run a config locally, create a small launcher script (e.g. `train.py`): + +```python +import os +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + +import multiprocessing +from kauldron import konfig + +def main(): + import importlib.util + spec = importlib.util.spec_from_file_location( + "config", "kdiff/configs/mnist_unet.py" + ) + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) + + cfg = config_module.get_config() + cfg.workdir = "/tmp/mnist_workdir" + trainer = konfig.resolve(cfg) + trainer.train() + +if __name__ == "__main__": + multiprocessing.set_start_method("spawn", force=True) + main() +``` + +> **Note:** `XLA_PYTHON_CLIENT_PREALLOCATE=false` must be set *before* +> importing JAX to prevent GPU memory preallocation conflicts with data +> loading workers. The `if __name__ == "__main__"` guard is required for +> multiprocessing compatibility. + ## Installation To install the necessary dependencies, you can use pip with the provided diff --git a/hackable_diffusion/kdiff/configs/imagenet64_unet.py b/hackable_diffusion/kdiff/configs/mnist_unet.py similarity index 61% rename from hackable_diffusion/kdiff/configs/imagenet64_unet.py rename to hackable_diffusion/kdiff/configs/mnist_unet.py index 7135d9e..3a1810b 100644 --- a/hackable_diffusion/kdiff/configs/imagenet64_unet.py +++ b/hackable_diffusion/kdiff/configs/mnist_unet.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -r"""Simple diffusion config for training a UNet on ImageNet 64x64. - -Should reach about FID=4.1 within 500k steps, and FID=2.7 within 2M steps. +r"""Simple diffusion config for training a UNet on MNIST. """ from kauldron import konfig @@ -26,7 +24,6 @@ from jax.experimental import checkify from hackable_diffusion import hd from hackable_diffusion.kdiff import core - from hackable_diffusion.kdiff import data from hackable_diffusion.kdiff import evals @@ -40,21 +37,27 @@ def get_config(): cfg.seed = 1337 cfg.aux = konfig.ConfigDict() - cfg.aux.cond_embedding_dim = 192 + cfg.aux.cond_embedding_dim = 128 + ############################################################################## # MARK: Corruption Process + ############################################################################## + corruption_process = hd.corruption.GaussianProcess( schedule=hd.corruption.RFSchedule(), ) - # MARK: Conditioning + ############################################################################## + # MARK: Conditioning networks + ############################################################################## + time_encoder = hd.architecture.SinusoidalTimeEmbedder( activation="gelu", num_features=cfg.ref.aux.cond_embedding_dim, embedding_dim=cfg.ref.aux.cond_embedding_dim, ) label_encoder = hd.architecture.LabelEmbedder( - num_classes=1000, + num_classes=10, num_features=cfg.ref.aux.cond_embedding_dim, conditioning_key="label", ) @@ -65,17 +68,20 @@ def get_config(): }, merge_embeddings_fn=hd.architecture.SumEmbeddings(), conditioning_rules={ - "label": 'adaptive_norm', - "time": 'adaptive_norm', + "label": "adaptive_norm", + "time": "adaptive_norm", }, conditioning_dropout_rate=0.1, ) - # MARK: Backbone + ############################################################################## + # MARK: Backbone network + ############################################################################## + backbone_network = hd.architecture.Unet( - base_channels=192, - channels_multiplier=(1, 2, 3, 4), - num_residual_blocks=(3, 3, 3, 3), + base_channels=64, + channels_multiplier=(1, 2, 2, 2), + num_residual_blocks=(2, 2, 2, 2), downsample_fn=hd.architecture.AvgPoolDownsample(), upsample_fn=hd.architecture.ImageResizeUpsample(resize_method="nearest"), dropout_rate=(0.0, 0.1, 0.1, 0.1), @@ -93,7 +99,9 @@ def get_config(): skip_connection_fn=hd.architecture.UnnormalizedAddSkip(), ) - # MARK: Model / Diffusion + ############################################################################## + # MARK: Diffusion model + ############################################################################## cfg.model = core.Diffusion( x0="batch.image", @@ -109,12 +117,18 @@ def get_config(): ), ) + ############################################################################## # MARK: Training - cfg.num_train_steps = 3_000_000 + ############################################################################## + + cfg.num_train_steps = 100_000 - cfg.train_ds = _make_ds(training=True, batch_size=512) + cfg.train_ds = _make_ds(training=True, batch_size=256) + ############################################################################## # MARK: Losses + ############################################################################## + cfg.train_losses = { "diffusion_loss": core.KauldronLossWrapper( loss=hd.training.SiD2Loss( @@ -125,12 +139,15 @@ def get_config(): ), } + ############################################################################## # MARK: Optimizer + ############################################################################## + cfg.schedules = { "learning_rate": optax.warmup_constant_schedule( init_value=0.0, - peak_value=3e-5, - warmup_steps=10_000, + peak_value=3e-4, + warmup_steps=1_000, ) } @@ -138,10 +155,13 @@ def get_config(): "clip": optax.clip_by_global_norm(max_norm=1.0), "adam": optax.scale_by_adam(b1=0.9, b2=0.99, eps=1e-12), "lr": optax.scale_by_learning_rate(cfg.ref.schedules["learning_rate"]), - "ema": kd.optim.ema_params(decay=0.9999), + "ema": kd.optim.ema_params(decay=0.999), }) + ############################################################################# # MARK: Metrics + ############################################################################## + cfg.train_metrics = { "grad_norm": kd.metrics.SkipIfMissing( kd.metrics.TreeReduce( @@ -152,61 +172,58 @@ def get_config(): ), } cfg.train_summaries = { - "overview": kd.contrib.summaries.ImageGrid.simple( - columns={ - "gt": "batch.image", - "x0_pred": "preds.output.x0", - "xt": "preds.xt", - }, - in_vrange=(-1.0, 1.0), + "gt": kd.summaries.ShowImages( + images="batch.image", in_vrange=(-1.0, 1.0) + ), + "x0_pred": kd.summaries.ShowImages( + images="preds.output.x0", in_vrange=(-1.0, 1.0) ), + "xt": kd.summaries.ShowImages(images="preds.xt", in_vrange=(-1.0, 1.0)), } + ############################################################################## # MARK: Evals - cfg.eval_ds = _make_ds(training=False, batch_size=512) + ############################################################################## + + cfg.eval_ds = _make_ds(training=False, batch_size=256) cfg.evals = { "sample_DDIM": evals.SamplingEvaluator( - run=kd.evals.EveryNSteps(50_000, skip_first=True), - init_transform=kd.optim.UseEmaParams(), + run=kd.evals.EveryNSteps(10_000, skip_first=True), num_batches=None, sampler=hd.sampling.DiffusionSampler( time_schedule=hd.sampling.UniformTimeSchedule(), stepper=hd.sampling.DDIMStep( - stoch_coeff=0.25, + stoch_coeff=0.0, corruption_process=cfg.ref.model.corruption_process, ), - num_steps=250, + num_steps=50, ), - metrics={ - # TODO(klausg): FID with stats from training set. - # TODO(klausg): opensource FID metric. - "fid": kd.metrics.Fid( - pred="samples.xt", - target="batch.image", - in_vrange=(-1.0, 1.0), - ), - }, + metrics={}, summaries={ - "overview": kd.contrib.summaries.ImageGrid.simple( - columns={ - "gt": "batch.image", - "sample": "samples.xt", - }, - in_vrange=(-1.0, 1.0), - num_images=10, - ) + "gt": kd.summaries.ShowImages( + images="batch.image", in_vrange=(-1.0, 1.0) + ), + "sample": kd.summaries.ShowImages( + images="samples.xt", in_vrange=(-1.0, 1.0) + ), }, ), } + ############################################################################## # MARK: Checkpointer + ############################################################################## + cfg.checkpointer = kd.ckpts.Checkpointer( save_interval_steps=10_000, max_to_keep=3, ) + ############################################################################## # MARK: Other + ############################################################################## + # hackable diffusion requires checkify to be activated. cfg.checkify_error_categories = checkify.user_checks # Set up random streams. @@ -219,24 +236,28 @@ def get_config(): return cfg -# MARK: _make_ds +################################################################################ +# MARK: Make dataset +################################################################################ + + def _make_ds(training: bool, batch_size: int, split: str | None = None): - """Imagenet 64x64 dataset.""" + """MNIST dataset.""" transforms = [ kd.data.Elements(keep=["image", "label"]), - kd.data.tf.Resize(key="image", height=64, width=64, method="area"), + kd.data.py.Resize(key="image", size=(32, 32), method="bilinear"), kd.data.ValueRange(key="image", in_vrange=(0, 255), vrange=(-1, 1)), kd.data.Rearrange(key="label", pattern="... -> ... 1"), ] if training: - transforms.append(kd.data.tf.RandomFlipLeftRight(key="image")) + # No random flip for MNIST as digits are not flip-invariant + pass if split is None: - split = "train" if training else "validation" + split = "train" if training else "test" - return kd.data.tf.Tfds( - name="imagenet2012", - decoders={"image": data.decode_and_central_square_crop()}, + return kd.data.py.Tfds( + name="mnist", split=split, shuffle=True if training else False, num_epochs=None if training else 1, diff --git a/hackable_diffusion/kdiff/evals.py b/hackable_diffusion/kdiff/evals.py index cd15572..3eb58da 100644 --- a/hackable_diffusion/kdiff/evals.py +++ b/hackable_diffusion/kdiff/evals.py @@ -62,7 +62,7 @@ def from_model_and_context( ################################################################################ -class SamplingEvaluator(kd.contrib.evals.CheckpointedEvaluator): +class SamplingEvaluator(kd.evals.Evaluator): """Evaluator that samples from the model. Attributes: