NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Indentation (assumed 2 spaces) and the position of blank context lines are
inferred from the hunk line counts -- verify against the target tree before
applying. The scale/shift conditioning branch in NormalizationLayer.__call__
was de-duplicated and a review note was added to the padding test, so hunk
sizes and the post-image `index` hashes no longer match the original patch.

diff --git a/hackable_diffusion/lib/architecture/arch_typing.py b/hackable_diffusion/lib/architecture/arch_typing.py
index 09fd006..0afe105 100644
--- a/hackable_diffusion/lib/architecture/arch_typing.py
+++ b/hackable_diffusion/lib/architecture/arch_typing.py
@@ -61,6 +61,7 @@ class NormalizationType(enum.StrEnum):
   RMS_NORM = "rms_norm"
   GROUP_NORM = "group_norm"
   LAYER_NORM = "layer_norm"
+  ADA_RMS_NORM = "ada_rms_norm"
 
 
 class DownsampleType(enum.StrEnum):
diff --git a/hackable_diffusion/lib/architecture/normalization.py b/hackable_diffusion/lib/architecture/normalization.py
index 6f91d7e..f26fada 100644
--- a/hackable_diffusion/lib/architecture/normalization.py
+++ b/hackable_diffusion/lib/architecture/normalization.py
@@ -18,6 +18,7 @@
 - RMSNorm: https://arxiv.org/abs/1910.07467
 - GroupNorm: https://arxiv.org/abs/1803.08494
 - LayerNorm: https://arxiv.org/abs/1607.06450
+- AdaRMSNorm: RMSNorm with scale-only adaptive conditioning (no shift).
 """
 
 import einops
@@ -135,6 +136,15 @@ def __call__(
           feature_axes=-1,  # Per channel scale.
           use_scale=self.use_scale,
       )(x=x, mask=mask)
+    elif self.normalization_method == NormalizationType.ADA_RMS_NORM:
+      # AdaRMSNorm: RMSNorm without learnable scale (scale is adaptive).
+      x = nn.RMSNorm(
+          epsilon=self.epsilon,
+          dtype=self.dtype,
+          reduction_axes=-1,
+          feature_axes=-1,
+          use_scale=False,
+      )(x=x, mask=mask)
 
     elif self.normalization_method == NormalizationType.GROUP_NORM:
       # If using GroupNorm the mask data must be such that the last dimension
@@ -170,19 +180,26 @@ def __call__(
 
 
     if self.conditional:
-      scale_and_shift = nn.Dense(
-          ch * 2,
-          kernel_init=nn.zeros_init(),
-          bias_init=nn.zeros_init(),
-          dtype=self.dtype,
-      )(c)
-      scale, shift = jnp.split(scale_and_shift, 2, axis=-1)  # (B, ch) each.
-
-      x = einops.rearrange(x, "b ... c -> b c ...")  # (B, ch, ...).
-      scale = jax_helpers.bcast_right(scale, x.ndim)
-      shift = jax_helpers.bcast_right(shift, x.ndim)
-      x = (1.0 + scale) * x + shift
-      x = einops.rearrange(x, "b c ... -> b ... c")
+      # AdaRMSNorm conditions with a scale only (no shift); every other
+      # normalization method predicts both a scale and a shift.
+      use_shift = self.normalization_method != NormalizationType.ADA_RMS_NORM
+      modulation = nn.Dense(
+          ch * 2 if use_shift else ch,
+          kernel_init=nn.zeros_init(),
+          bias_init=nn.zeros_init(),
+          dtype=self.dtype,
+      )(c)
+
+      x = einops.rearrange(x, "b ... c -> b c ...")  # (B, ch, ...).
+      if use_shift:
+        scale, shift = jnp.split(modulation, 2, axis=-1)  # (B, ch) each.
+        scale = jax_helpers.bcast_right(scale, x.ndim)
+        shift = jax_helpers.bcast_right(shift, x.ndim)
+        x = (1.0 + scale) * x + shift
+      else:
+        scale = jax_helpers.bcast_right(modulation, x.ndim)
+        x = (1.0 + scale) * x
+      x = einops.rearrange(x, "b c ... -> b ... c")
 
     return x
 
@@ -247,7 +264,7 @@ def unconditional_norm(
   def conditional_norm(
       self, norm_name: str = "ConditionalNorm"
   ) -> NormalizationLayer:
-    """Returns a factory for creating conditional normalization layers."""
+    """Returns a conditional normalization layer."""
     return NormalizationLayer(
         normalization_method=self.normalization_method,
         conditional=True,
diff --git a/hackable_diffusion/lib/architecture/normalization_test.py b/hackable_diffusion/lib/architecture/normalization_test.py
index 54263e5..07b61ad 100644
--- a/hackable_diffusion/lib/architecture/normalization_test.py
+++ b/hackable_diffusion/lib/architecture/normalization_test.py
@@ -468,6 +468,79 @@ def test_rmsnorm_mask_equivalence(self):
         ),
     )
 
+  def test_conditional_ada_rmsnorm_at_init(self):
+    """Tests AdaRMSNorm at init: scale=0 so output equals unconditional RMSNorm."""
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=NormalizationType.ADA_RMS_NORM,
+        conditional=True,
+    )
+    params = norm_layer.init(self.rng, self.x, self.c)
+    output = norm_layer.apply(params, self.x, self.c)
+    self.assertEqual(output.shape, self.x_shape)
+
+    # At init, scale=0, so output should match plain RMSNorm (no scale param).
+    from jax import lax  # pylint: disable=g-import-not-at-top,reimported
+
+    x2 = jnp.mean(self.x**2, -1, keepdims=True)
+    output_ref = self.x * lax.rsqrt(x2 + norm_layer.epsilon)
+    np.testing.assert_allclose(output, output_ref, rtol=1e-5, atol=1e-5)
+
+  def test_conditional_ada_rmsnorm_perturbed(self):
+    """Tests AdaRMSNorm with perturbed params produces different output."""
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=NormalizationType.ADA_RMS_NORM,
+        conditional=True,
+    )
+    params = norm_layer.init(self.rng, self.x, self.c)
+    params_perturbed = _perturb_params(params=params, key=self.rng)
+    output_perturbed = norm_layer.apply(params_perturbed, self.x, self.c)
+
+    # Compute unconditional RMSNorm for comparison.
+    from jax import lax  # pylint: disable=g-import-not-at-top,reimported
+
+    x2 = jnp.mean(self.x**2, -1, keepdims=True)
+    output_ref = self.x * lax.rsqrt(x2 + norm_layer.epsilon)
+
+    self.assertEqual(output_perturbed.shape, self.x_shape)
+    self.assertFalse(
+        np.allclose(output_perturbed, output_ref, rtol=1e-5, atol=1e-5),
+        msg=(
+            "AdaRMSNorm output should differ from unconditional after"
+            " perturbing params."
+        ),
+    )
+
+  def test_ada_rmsnorm_scale_only_no_shift(self):
+    """Tests that AdaRMSNorm projects to ch (not ch*2) — scale-only."""
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=NormalizationType.ADA_RMS_NORM,
+        conditional=True,
+    )
+    params = norm_layer.init(self.rng, self.x, self.c)
+    # The Dense layer in ADA_RMS_NORM should project to ch (not ch * 2).
+    dense_kernel = params["params"]["Dense_0"]["kernel"]
+    expected_shape = (self.c_shape[-1], self.x_shape[-1])  # (cond_dim, ch)
+    self.assertEqual(dense_kernel.shape, expected_shape)
+
+  def test_ada_rmsnorm_padding_invariance(self):
+    """Tests AdaRMSNorm padding invariance (same as RMSNorm)."""
+    # NOTE(review): with conditional=False and use_scale=False this layer
+    # appears to have no parameters, so the perturbation is likely a no-op.
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=NormalizationType.ADA_RMS_NORM,
+        conditional=False,
+    )
+    params = norm_layer.init(self.rng, self.x_small)
+    params_perturbed = _perturb_params(params=params, key=self.rng)
+
+    out_small = norm_layer.apply(params_perturbed, self.x_small)
+    out_large = norm_layer.apply(params_perturbed, self.x_large)
+    np.testing.assert_allclose(
+        out_small[:, :, : self.unpadded_seq_len, :],
+        out_large[:, :, : self.unpadded_seq_len, :],
+        atol=1e-5,
+    )
+
 
 if __name__ == "__main__":
   absltest.main()