Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,44 @@ The `notebooks/` directory contains several tutorials to get you started:
* **`mnist_discrete.ipynb`**: An example of discrete diffusion.
* **`mnist_multimodal.ipynb`**: A showcase of the multimodal capabilities, generating images and labels jointly.

## Training configs

The `kdiff/configs/` directory contains example configurations for training:

* **`mnist_unet.py`**: Standard diffusion training configuration on MNIST.

To run a config locally, create a small launcher script (e.g. `train.py`):

```python
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import multiprocessing
from kauldron import konfig

def main():
import importlib.util
spec = importlib.util.spec_from_file_location(
"config", "kdiff/configs/mnist_unet.py"
)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)

cfg = config_module.get_config()
cfg.workdir = "/tmp/mnist_workdir"
trainer = konfig.resolve(cfg)
trainer.train()

if __name__ == "__main__":
multiprocessing.set_start_method("spawn", force=True)
main()
```

> **Note:** `XLA_PYTHON_CLIENT_PREALLOCATE=false` must be set *before*
> importing JAX to prevent GPU memory preallocation conflicts with data
> loading workers. The `if __name__ == "__main__"` guard is required for
> multiprocessing compatibility.

## Installation

To install the necessary dependencies, you can use pip with the provided
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Simple diffusion config for training a UNet on ImageNet 64x64.

Should reach about FID=4.1 within 500k steps, and FID=2.7 within 2M steps.
r"""Simple diffusion config for training a UNet on MNIST.
"""

from kauldron import konfig
Expand All @@ -26,7 +24,6 @@
from jax.experimental import checkify
from hackable_diffusion import hd
from hackable_diffusion.kdiff import core
from hackable_diffusion.kdiff import data
from hackable_diffusion.kdiff import evals


Expand All @@ -40,21 +37,27 @@ def get_config():
cfg.seed = 1337

cfg.aux = konfig.ConfigDict()
cfg.aux.cond_embedding_dim = 192
cfg.aux.cond_embedding_dim = 128

##############################################################################
# MARK: Corruption Process
##############################################################################

corruption_process = hd.corruption.GaussianProcess(
schedule=hd.corruption.RFSchedule(),
)

# MARK: Conditioning
##############################################################################
# MARK: Conditioning networks
##############################################################################

time_encoder = hd.architecture.SinusoidalTimeEmbedder(
activation="gelu",
num_features=cfg.ref.aux.cond_embedding_dim,
embedding_dim=cfg.ref.aux.cond_embedding_dim,
)
label_encoder = hd.architecture.LabelEmbedder(
num_classes=1000,
num_classes=10,
num_features=cfg.ref.aux.cond_embedding_dim,
conditioning_key="label",
)
Expand All @@ -65,17 +68,20 @@ def get_config():
},
merge_embeddings_fn=hd.architecture.SumEmbeddings(),
conditioning_rules={
"label": 'adaptive_norm',
"time": 'adaptive_norm',
"label": "adaptive_norm",
"time": "adaptive_norm",
},
conditioning_dropout_rate=0.1,
)

# MARK: Backbone
##############################################################################
# MARK: Backbone network
##############################################################################

backbone_network = hd.architecture.Unet(
base_channels=192,
channels_multiplier=(1, 2, 3, 4),
num_residual_blocks=(3, 3, 3, 3),
base_channels=64,
channels_multiplier=(1, 2, 2, 2),
num_residual_blocks=(2, 2, 2, 2),
downsample_fn=hd.architecture.AvgPoolDownsample(),
upsample_fn=hd.architecture.ImageResizeUpsample(resize_method="nearest"),
dropout_rate=(0.0, 0.1, 0.1, 0.1),
Expand All @@ -93,7 +99,9 @@ def get_config():
skip_connection_fn=hd.architecture.UnnormalizedAddSkip(),
)

# MARK: Model / Diffusion
##############################################################################
# MARK: Diffusion model
##############################################################################

cfg.model = core.Diffusion(
x0="batch.image",
Expand All @@ -109,12 +117,18 @@ def get_config():
),
)

##############################################################################
# MARK: Training
cfg.num_train_steps = 3_000_000
##############################################################################

cfg.num_train_steps = 100_000

cfg.train_ds = _make_ds(training=True, batch_size=512)
cfg.train_ds = _make_ds(training=True, batch_size=256)

##############################################################################
# MARK: Losses
##############################################################################

cfg.train_losses = {
"diffusion_loss": core.KauldronLossWrapper(
loss=hd.training.SiD2Loss(
Expand All @@ -125,23 +139,29 @@ def get_config():
),
}

##############################################################################
# MARK: Optimizer
##############################################################################

cfg.schedules = {
"learning_rate": optax.warmup_constant_schedule(
init_value=0.0,
peak_value=3e-5,
warmup_steps=10_000,
peak_value=3e-4,
warmup_steps=1_000,
)
}

cfg.optimizer = kd.optim.named_chain(**{
"clip": optax.clip_by_global_norm(max_norm=1.0),
"adam": optax.scale_by_adam(b1=0.9, b2=0.99, eps=1e-12),
"lr": optax.scale_by_learning_rate(cfg.ref.schedules["learning_rate"]),
"ema": kd.optim.ema_params(decay=0.9999),
"ema": kd.optim.ema_params(decay=0.999),
})

#############################################################################
# MARK: Metrics
##############################################################################

cfg.train_metrics = {
"grad_norm": kd.metrics.SkipIfMissing(
kd.metrics.TreeReduce(
Expand All @@ -152,61 +172,58 @@ def get_config():
),
}
cfg.train_summaries = {
"overview": kd.contrib.summaries.ImageGrid.simple(
columns={
"gt": "batch.image",
"x0_pred": "preds.output.x0",
"xt": "preds.xt",
},
in_vrange=(-1.0, 1.0),
"gt": kd.summaries.ShowImages(
images="batch.image", in_vrange=(-1.0, 1.0)
),
"x0_pred": kd.summaries.ShowImages(
images="preds.output.x0", in_vrange=(-1.0, 1.0)
),
"xt": kd.summaries.ShowImages(images="preds.xt", in_vrange=(-1.0, 1.0)),
}

##############################################################################
# MARK: Evals
cfg.eval_ds = _make_ds(training=False, batch_size=512)
##############################################################################

cfg.eval_ds = _make_ds(training=False, batch_size=256)

cfg.evals = {
"sample_DDIM": evals.SamplingEvaluator(
run=kd.evals.EveryNSteps(50_000, skip_first=True),
init_transform=kd.optim.UseEmaParams(),
run=kd.evals.EveryNSteps(10_000, skip_first=True),
num_batches=None,
sampler=hd.sampling.DiffusionSampler(
time_schedule=hd.sampling.UniformTimeSchedule(),
stepper=hd.sampling.DDIMStep(
stoch_coeff=0.25,
stoch_coeff=0.0,
corruption_process=cfg.ref.model.corruption_process,
),
num_steps=250,
num_steps=50,
),
metrics={
# TODO(klausg): FID with stats from training set.
# TODO(klausg): opensource FID metric.
"fid": kd.metrics.Fid(
pred="samples.xt",
target="batch.image",
in_vrange=(-1.0, 1.0),
),
},
metrics={},
summaries={
"overview": kd.contrib.summaries.ImageGrid.simple(
columns={
"gt": "batch.image",
"sample": "samples.xt",
},
in_vrange=(-1.0, 1.0),
num_images=10,
)
"gt": kd.summaries.ShowImages(
images="batch.image", in_vrange=(-1.0, 1.0)
),
"sample": kd.summaries.ShowImages(
images="samples.xt", in_vrange=(-1.0, 1.0)
),
},
),
}

##############################################################################
# MARK: Checkpointer
##############################################################################

cfg.checkpointer = kd.ckpts.Checkpointer(
save_interval_steps=10_000,
max_to_keep=3,
)

##############################################################################
# MARK: Other
##############################################################################

# hackable diffusion requires checkify to be activated.
cfg.checkify_error_categories = checkify.user_checks
# Set up random streams.
Expand All @@ -219,24 +236,28 @@ def get_config():
return cfg


# MARK: _make_ds
################################################################################
# MARK: Make dataset
################################################################################


def _make_ds(training: bool, batch_size: int, split: str | None = None):
"""Imagenet 64x64 dataset."""
"""MNIST dataset."""
transforms = [
kd.data.Elements(keep=["image", "label"]),
kd.data.tf.Resize(key="image", height=64, width=64, method="area"),
kd.data.py.Resize(key="image", size=(32, 32), method="bilinear"),
kd.data.ValueRange(key="image", in_vrange=(0, 255), vrange=(-1, 1)),
kd.data.Rearrange(key="label", pattern="... -> ... 1"),
]
if training:
transforms.append(kd.data.tf.RandomFlipLeftRight(key="image"))
# No random flip for MNIST as digits are not flip-invariant
pass

if split is None:
split = "train" if training else "validation"
split = "train" if training else "test"

return kd.data.tf.Tfds(
name="imagenet2012",
decoders={"image": data.decode_and_central_square_crop()},
return kd.data.py.Tfds(
name="mnist",
split=split,
shuffle=True if training else False,
num_epochs=None if training else 1,
Expand Down
2 changes: 1 addition & 1 deletion hackable_diffusion/kdiff/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def from_model_and_context(
################################################################################


class SamplingEvaluator(kd.contrib.evals.CheckpointedEvaluator):
class SamplingEvaluator(kd.evals.Evaluator):
"""Evaluator that samples from the model.

Attributes:
Expand Down
Loading