Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hackable_diffusion/lib/architecture/arch_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class NormalizationType(enum.StrEnum):
RMS_NORM = "rms_norm"
GROUP_NORM = "group_norm"
LAYER_NORM = "layer_norm"
ADA_RMS_NORM = "ada_rms_norm"


class DownsampleType(enum.StrEnum):
Expand Down
53 changes: 39 additions & 14 deletions hackable_diffusion/lib/architecture/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- RMSNorm: https://arxiv.org/abs/1910.07467
- GroupNorm: https://arxiv.org/abs/1803.08494
- LayerNorm: https://arxiv.org/abs/1607.06450
- AdaRMSNorm: RMSNorm with scale-only adaptive conditioning (no shift).
"""

import einops
Expand Down Expand Up @@ -135,6 +136,15 @@ def __call__(
feature_axes=-1, # Per channel scale.
use_scale=self.use_scale,
)(x=x, mask=mask)
elif self.normalization_method == NormalizationType.ADA_RMS_NORM:
# AdaRMSNorm: RMSNorm without learnable scale (scale is adaptive).
x = nn.RMSNorm(
epsilon=self.epsilon,
dtype=self.dtype,
reduction_axes=-1,
feature_axes=-1,
use_scale=False,
)(x=x, mask=mask)
elif self.normalization_method == NormalizationType.GROUP_NORM:

# If using GroupNorm the mask data must be such that the last dimension
Expand Down Expand Up @@ -170,19 +180,34 @@ def __call__(

if self.conditional:

scale_and_shift = nn.Dense(
ch * 2,
kernel_init=nn.zeros_init(),
bias_init=nn.zeros_init(),
dtype=self.dtype,
)(c)
scale, shift = jnp.split(scale_and_shift, 2, axis=-1) # (B, ch) each.

x = einops.rearrange(x, "b ... c -> b c ...") # (B, ch, ...).
scale = jax_helpers.bcast_right(scale, x.ndim)
shift = jax_helpers.bcast_right(shift, x.ndim)
x = (1.0 + scale) * x + shift
x = einops.rearrange(x, "b c ... -> b ... c")
if self.normalization_method == NormalizationType.ADA_RMS_NORM:
# Scale-only adaptive conditioning (no shift).
scale = nn.Dense(
ch,
kernel_init=nn.zeros_init(),
bias_init=nn.zeros_init(),
dtype=self.dtype,
)(c)

x = einops.rearrange(x, "b ... c -> b c ...") # (B, ch, ...).
scale = jax_helpers.bcast_right(scale, x.ndim)
x = (1.0 + scale) * x
x = einops.rearrange(x, "b c ... -> b ... c")
else:
# Scale + shift adaptive conditioning.
scale_and_shift = nn.Dense(
ch * 2,
kernel_init=nn.zeros_init(),
bias_init=nn.zeros_init(),
dtype=self.dtype,
)(c)
scale, shift = jnp.split(scale_and_shift, 2, axis=-1) # (B, ch) each.

x = einops.rearrange(x, "b ... c -> b c ...") # (B, ch, ...).
scale = jax_helpers.bcast_right(scale, x.ndim)
shift = jax_helpers.bcast_right(shift, x.ndim)
x = (1.0 + scale) * x + shift
x = einops.rearrange(x, "b c ... -> b ... c")

return x

Expand Down Expand Up @@ -247,7 +272,7 @@ def unconditional_norm(
def conditional_norm(
self, norm_name: str = "ConditionalNorm"
) -> NormalizationLayer:
"""Returns a factory for creating conditional normalization layers."""
"""Returns a conditional normalization layer."""
return NormalizationLayer(
normalization_method=self.normalization_method,
conditional=True,
Expand Down
71 changes: 71 additions & 0 deletions hackable_diffusion/lib/architecture/normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,77 @@ def test_rmsnorm_mask_equivalence(self):
),
)

def test_conditional_ada_rmsnorm_at_init(self):
  """Checks that a freshly initialized AdaRMSNorm equals plain RMSNorm.

  The layer is built with `conditional=True`, initialized, and applied to
  the fixture input and conditioning. The result is compared against a
  hand-computed RMS normalization over the last axis with no learnable
  scale.
  """
  from jax import lax  # pylint: disable=g-import-not-at-top,reimported

  layer = normalization.NormalizationLayer(
      normalization_method=NormalizationType.ADA_RMS_NORM,
      conditional=True,
  )
  variables = layer.init(self.rng, self.x, self.c)
  actual = layer.apply(variables, self.x, self.c)
  self.assertEqual(actual.shape, self.x_shape)

  # Reference: x / sqrt(mean(x^2) + eps), reduced over the channel axis.
  mean_square = jnp.mean(self.x**2, axis=-1, keepdims=True)
  inv_rms = lax.rsqrt(mean_square + layer.epsilon)
  expected = self.x * inv_rms
  np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)

def test_conditional_ada_rmsnorm_perturbed(self):
  """Checks that perturbed AdaRMSNorm params move the output off RMSNorm.

  After the initialized parameters are perturbed, the conditional output
  must keep the input shape but must no longer match the hand-computed
  unconditional RMS normalization.
  """
  from jax import lax  # pylint: disable=g-import-not-at-top,reimported

  layer = normalization.NormalizationLayer(
      normalization_method=NormalizationType.ADA_RMS_NORM,
      conditional=True,
  )
  variables = _perturb_params(
      params=layer.init(self.rng, self.x, self.c), key=self.rng
  )
  actual = layer.apply(variables, self.x, self.c)

  # Unconditional RMSNorm baseline: x / sqrt(mean(x^2) + eps).
  mean_square = jnp.mean(self.x**2, axis=-1, keepdims=True)
  baseline = self.x * lax.rsqrt(mean_square + layer.epsilon)

  self.assertEqual(actual.shape, self.x_shape)
  matches_baseline = np.allclose(actual, baseline, rtol=1e-5, atol=1e-5)
  self.assertFalse(
      matches_baseline,
      msg=(
          "AdaRMSNorm output should differ from unconditional after"
          " perturbing params."
      ),
  )

def test_ada_rmsnorm_scale_only_no_shift(self):
  """Checks that the AdaRMSNorm projection emits `ch` values, not `2 * ch`.

  A scale-only layer needs one output per channel, so the kernel of its
  Dense projection must have shape (conditioning dim, channel dim).
  """
  layer = normalization.NormalizationLayer(
      normalization_method=NormalizationType.ADA_RMS_NORM,
      conditional=True,
  )
  variables = layer.init(self.rng, self.x, self.c)

  cond_dim = self.c_shape[-1]
  num_channels = self.x_shape[-1]
  kernel = variables["params"]["Dense_0"]["kernel"]
  self.assertEqual(kernel.shape, (cond_dim, num_channels))

def test_ada_rmsnorm_padding_invariance(self):
  """Checks that AdaRMSNorm output on the unpadded region ignores padding.

  The same (perturbed) parameters are applied to the small and the large
  fixture inputs; the first `unpadded_seq_len` positions along the third
  axis must agree between the two outputs.
  """
  layer = normalization.NormalizationLayer(
      normalization_method=NormalizationType.ADA_RMS_NORM,
      conditional=False,
  )
  # NOTE(review): the layer is built with conditional=False, so it may hold
  # no learnable parameters and the perturbation may be a no-op — confirm.
  variables = _perturb_params(
      params=layer.init(self.rng, self.x_small), key=self.rng
  )

  small_out = layer.apply(variables, self.x_small)
  large_out = layer.apply(variables, self.x_large)

  # Restrict the comparison to the positions shared by both inputs.
  shared = np.s_[:, :, : self.unpadded_seq_len, :]
  np.testing.assert_allclose(small_out[shared], large_out[shared], atol=1e-5)


if __name__ == "__main__":
absltest.main()
Loading