diff --git a/hackable_diffusion/lib/architecture/attention.py b/hackable_diffusion/lib/architecture/attention.py
index 9171f0f..58ccdbc 100644
--- a/hackable_diffusion/lib/architecture/attention.py
+++ b/hackable_diffusion/lib/architecture/attention.py
@@ -127,6 +127,8 @@ def _dot_product_attention(
     rescale: Float["..."],
     *,
     mask: Bool["batch sequence_key"] | None = None,
+    dropout_rate: float = 0.0,
+    is_training: bool = True,
 ) -> Float["batch sequence_query head*dim"]:
   """Performs dot product attention.
 
@@ -137,6 +139,8 @@ def _dot_product_attention(
     rescale: Rescale factor for the attention scores.
     mask: Mask tensor. Mask is True for tokens we want to keep and False for
       tokens we want to mask. If None, no masking is performed.
+    dropout_rate: Dropout rate for the attention weights; 0.0 disables it.
+    is_training: If True, attention dropout is active; if False, it is off.
 
   Returns:
     The output tensor.
@@ -156,6 +160,11 @@ def _dot_product_attention(
   # Softmax and attention weights
   attn_weights = _stable_softmax(logits=attn_logits)
 
+  if dropout_rate > 0.0:
+    attn_weights = nn.Dropout(rate=dropout_rate)(
+        attn_weights, deterministic=not is_training
+    )
+
   # Calculate attention output
   attn_output = jnp.einsum("bhts,bhsd->bhtd", attn_weights, v)
 
@@ -194,6 +203,7 @@ class MultiHeadAttention(nn.Module):
       use_rope is True.
     zero_init_output: If True, the kernel of the final output projection layer
       is initialized to zeros.
+    dropout_rate: Dropout rate for the attention weights; 0.0 disables it.
     dtype: The data type of the computation.
   """
 
@@ -203,6 +213,7 @@ class MultiHeadAttention(nn.Module):
   use_rope: bool = False
   rope_position_type: RoPEPositionType = RoPEPositionType.SQUARE
   zero_init_output: bool = False
+  dropout_rate: float = 0.0
   dtype: DType = jnp.float32
 
   def setup(self):
@@ -226,6 +237,7 @@ def __call__(
       c: Float["batch sequence2 dim2"] | None,
       *,
       mask: Bool["batch sequence1|sequence2"] | None = None,
+      is_training: bool = True,
   ) -> Float["batch sequence1 dim1"]:
     """Computes multi-head attention.
 
@@ -319,6 +331,8 @@ def __call__(
         v=v,
         rescale=scale,
         mask=mask,
+        dropout_rate=self.dropout_rate,
+        is_training=is_training,
     )
 
     attn_output = nn.Dense(
diff --git a/hackable_diffusion/lib/architecture/attention_test.py b/hackable_diffusion/lib/architecture/attention_test.py
index a1c3298..7a31767 100644
--- a/hackable_diffusion/lib/architecture/attention_test.py
+++ b/hackable_diffusion/lib/architecture/attention_test.py
@@ -408,6 +408,112 @@ def test_multi_head_attention_invalid_mask_shape_raises_error(
     ):
       module.init(self.rng, self.x, c, mask=invalid_mask)
 
+  # MARK: Dropout Tests
+
+  def test_multi_head_attention_dropout_disabled_during_evaluation(self):
+    """Verifies dropout is inactive when is_training=False (evaluation mode)."""
+    # Initialize with an aggressive dropout rate (e.g., 0.5)
+    module = attention.MultiHeadAttention(
+        num_heads=self.num_heads,
+        dropout_rate=0.5,
+    )
+
+    # Generate random inputs to capture exact matrix values
+    rng1, rng2 = jax.random.split(self.rng)
+    x_rand = jax.random.normal(
+        rng1, (self.batch_size, self.seq_len_q, self.dim)
+    )
+
+    variables = module.init(rng2, x_rand, c=None)
+
+    # Run twice with evaluation mode (is_training=False).
+    # Even with a 50% dropout rate, the outputs should be completely identical.
+    output_eval_1 = module.apply(variables, x_rand, c=None, is_training=False)
+    output_eval_2 = module.apply(variables, x_rand, c=None, is_training=False)
+
+    np.testing.assert_allclose(
+        output_eval_1,
+        output_eval_2,
+        atol=1e-6,
+    )
+
+  def test_multi_head_attention_dropout_active_during_training(self):
+    """Verifies dropout alters outputs randomly when is_training=True."""
+    module = attention.MultiHeadAttention(
+        num_heads=self.num_heads,
+        dropout_rate=0.5,
+    )
+
+    rng1, rng2, rng_dropout1, rng_dropout2 = jax.random.split(self.rng, 4)
+    x_rand = jax.random.normal(
+        rng1, (self.batch_size, self.seq_len_q, self.dim)
+    )
+
+    variables = module.init(rng2, x_rand, c=None)
+
+    # Flax requires a 'dropout' RNG stream state passed inside a dict
+    # whenever execution hits an active nn.Dropout layer during training.
+    output_train_1 = module.apply(
+        variables,
+        x_rand,
+        c=None,
+        is_training=True,
+        rngs={"dropout": rng_dropout1},
+    )
+    output_train_2 = module.apply(
+        variables,
+        x_rand,
+        c=None,
+        is_training=True,
+        rngs={"dropout": rng_dropout2},
+    )
+
+    # Since two distinct keys were injected into the dropout stream,
+    # different masks were dropped, meaning outputs must differ.
+    self.assertFalse(jnp.allclose(output_train_1, output_train_2, atol=1e-5))
+
+  def test_multi_head_attention_dropout_scales_retained_activations(self):
+    """Verifies inverted dropout leaves the expected output unchanged."""
+    # With a 50% rate, kept attention weights are scaled by 1 / (1 - 0.5) = 2.
+    rate = 0.5
+    num_dropout_samples = 1024
+    module = attention.MultiHeadAttention(
+        num_heads=self.num_heads,
+        dropout_rate=rate,
+    )
+
+    rng1, rng2, rng_dropout = jax.random.split(self.rng, 3)
+    x_rand = jax.random.normal(
+        rng1, (self.batch_size, self.seq_len_q, self.dim)
+    )
+
+    variables = module.init(rng2, x_rand, c=None)
+
+    output_eval = module.apply(variables, x_rand, c=None, is_training=False)
+
+    def apply_train(dropout_rng):
+      return module.apply(
+          variables,
+          x_rand,
+          c=None,
+          is_training=True,
+          rngs={"dropout": dropout_rng},
+      )
+
+    dropout_rngs = jax.random.split(rng_dropout, num_dropout_samples)
+    mean_output_train = jnp.mean(jax.vmap(apply_train)(dropout_rngs), axis=0)
+
+    # Everything downstream of the attention weights looks affine (einsum with
+    # the values, then the output projection), so inverted dropout -- kept
+    # weights scaled by 1 / (1 - rate) -- must leave the mean training output
+    # equal to the evaluation output. Without that rescaling the mean would
+    # shrink by about (1 - rate), i.e. a relative error close to `rate`.
+    relative_error = float(
+        jnp.linalg.norm(mean_output_train - output_eval)
+        / jnp.linalg.norm(output_eval)
+    )
+    self.assertLess(relative_error, 0.5 * rate)
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/hackable_diffusion/lib/architecture/normalization.py b/hackable_diffusion/lib/architecture/normalization.py
index 6f91d7e..756d3a0 100644
--- a/hackable_diffusion/lib/architecture/normalization.py
+++ b/hackable_diffusion/lib/architecture/normalization.py
@@ -79,6 +79,8 @@ class NormalizationLayer(nn.Module):
     dtype: The data type of the computation.
     use_bias: Whether to use bias in the normalization layer.
     use_scale: Whether to use scale in the normalization layer.
+    use_conditional_shift: If True, conditioning predicts scale and shift; if
+      False, scale only. Only used when `conditional` is True.
   """
 
   normalization_method: NormalizationType
@@ -88,6 +90,7 @@ class NormalizationLayer(nn.Module):
   dtype: DType = jnp.float32
   use_bias: bool = True
   use_scale: bool = True
+  use_conditional_shift: bool = True
 
   def setup(self):
     if (
@@ -169,18 +172,28 @@ def __call__(
       )
 
     if self.conditional:
-
-      scale_and_shift = nn.Dense(
-          ch * 2,
-          kernel_init=nn.zeros_init(),
-          bias_init=nn.zeros_init(),
-          dtype=self.dtype,
-      )(c)
-      scale, shift = jnp.split(scale_and_shift, 2, axis=-1)  # (B, ch) each.
-
       x = einops.rearrange(x, "b ... c -> b c ...")  # (B, ch, ...).
-      scale = jax_helpers.bcast_right(scale, x.ndim)
-      shift = jax_helpers.bcast_right(shift, x.ndim)
+      # Scale + shift adaptive conditioning.
+      if self.use_conditional_shift:
+        scale_and_shift = nn.Dense(
+            ch * 2,
+            kernel_init=nn.zeros_init(),
+            bias_init=nn.zeros_init(),
+            dtype=self.dtype,
+        )(c)
+        scale, shift = jnp.split(scale_and_shift, 2, axis=-1)  # (B, ch) each.
+        scale = jax_helpers.bcast_right(scale, x.ndim)
+        shift = jax_helpers.bcast_right(shift, x.ndim)
+      else:
+        # Scale-only adaptive conditioning (no shift).
+        scale = nn.Dense(
+            ch,
+            kernel_init=nn.zeros_init(),
+            bias_init=nn.zeros_init(),
+            dtype=self.dtype,
+        )(c)
+        scale = jax_helpers.bcast_right(scale, x.ndim)
+        shift = jnp.zeros_like(scale)
       x = (1.0 + scale) * x + shift
       x = einops.rearrange(x, "b c ... -> b ... c")
 
@@ -211,6 +224,8 @@ class NormalizationLayerFactory:
     dtype: The data type of the computation.
     use_bias: Whether to use bias in the normalization layer.
     use_scale: Whether to use scale in the normalization layer.
+    use_conditional_shift: If True, conditional layers predict scale and shift;
+      if False, scale only. Only affects layers with `conditional=True`.
   """
 
   def __init__(
@@ -221,6 +236,7 @@ def __init__(
       dtype: DType = jnp.float32,
       use_bias: bool = True,
       use_scale: bool = True,
+      use_conditional_shift: bool = True,
   ):
     self.normalization_method = normalization_method
     self.epsilon = epsilon
@@ -228,6 +244,7 @@ def __init__(
     self.dtype = dtype
     self.use_bias = use_bias
     self.use_scale = use_scale
+    self.use_conditional_shift = use_conditional_shift
 
   def unconditional_norm(
       self, norm_name: str = "UnconditionalNorm"
@@ -242,12 +259,13 @@ def unconditional_norm(
         dtype=self.dtype,
         use_bias=self.use_bias,
         use_scale=self.use_scale,
+        use_conditional_shift=self.use_conditional_shift,
     )
 
   def conditional_norm(
       self, norm_name: str = "ConditionalNorm"
   ) -> NormalizationLayer:
-    """Returns a factory for creating conditional normalization layers."""
+    """Returns a conditional normalization layer."""
     return NormalizationLayer(
         normalization_method=self.normalization_method,
         conditional=True,
@@ -257,4 +275,5 @@ def conditional_norm(
         dtype=self.dtype,
         use_bias=self.use_bias,
         use_scale=self.use_scale,
+        use_conditional_shift=self.use_conditional_shift,
     )
diff --git a/hackable_diffusion/lib/architecture/normalization_test.py b/hackable_diffusion/lib/architecture/normalization_test.py
index 54263e5..6b84ca7 100644
--- a/hackable_diffusion/lib/architecture/normalization_test.py
+++ b/hackable_diffusion/lib/architecture/normalization_test.py
@@ -468,6 +468,205 @@ def test_rmsnorm_mask_equivalence(self):
         ),
     )
 
+  def test_conditional_rmsnorm_scale_only_at_init(self):
+    """Tests conditional RMSNorm with scale-only (no shift) at init."""
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=NormalizationType.RMS_NORM,
+        conditional=True,
+        use_conditional_shift=False,
+    )
+    params = norm_layer.init(self.rng, self.x, self.c)
+    output = norm_layer.apply(params, self.x, self.c)
+    self.assertEqual(output.shape, self.x_shape)
+
+    # At init, scale=0, so output should match plain RMSNorm.
+    x2 = jnp.mean(self.x**2, -1, keepdims=True)
+    output_ref = self.x * lax.rsqrt(x2 + norm_layer.epsilon)
+    np.testing.assert_allclose(output, output_ref, rtol=1e-5, atol=1e-5)
+
+  def test_conditional_rmsnorm_scale_only_perturbed(self):
+    """Tests conditional RMSNorm scale-only with perturbed params."""
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=NormalizationType.RMS_NORM,
+        conditional=True,
+        use_conditional_shift=False,
+    )
+    params = norm_layer.init(self.rng, self.x, self.c)
+    params_perturbed = _perturb_params(params=params, key=self.rng)
+    output_perturbed = norm_layer.apply(params_perturbed, self.x, self.c)
+
+    # Compute unconditional RMSNorm for comparison.
+    x2 = jnp.mean(self.x**2, -1, keepdims=True)
+    output_ref = self.x * lax.rsqrt(x2 + norm_layer.epsilon)
+
+    self.assertEqual(output_perturbed.shape, self.x_shape)
+    self.assertFalse(
+        np.allclose(output_perturbed, output_ref, rtol=1e-5, atol=1e-5),
+        msg=(
+            'Scale-only conditional output should differ from unconditional'
+            ' output after perturbing params.'
+        ),
+    )
+
+  def test_conditional_scale_only_projects_to_ch(self):
+    """Tests that scale-only conditioning projects to ch (not ch*2)."""
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=NormalizationType.RMS_NORM,
+        conditional=True,
+        use_conditional_shift=False,
+    )
+    params = norm_layer.init(self.rng, self.x, self.c)
+    # The Dense layer should project to ch (not ch * 2).
+    dense_kernel = params['params']['Dense_0']['kernel']
+    expected_shape = (self.c_shape[-1], self.x_shape[-1])  # (cond_dim, ch)
+    self.assertEqual(dense_kernel.shape, expected_shape)
+
+  def test_conditional_scale_shift_projects_to_ch_times_2(self):
+    """Tests that scale+shift conditioning projects to ch * 2."""
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=NormalizationType.RMS_NORM,
+        conditional=True,
+        use_conditional_shift=True,
+    )
+    params = norm_layer.init(self.rng, self.x, self.c)
+    dense_kernel = params['params']['Dense_0']['kernel']
+    expected_shape = (self.c_shape[-1], self.x_shape[-1] * 2)
+    self.assertEqual(dense_kernel.shape, expected_shape)
+
+  def test_conditional_rmsnorm_scale_only_padding_invariance(self):
+    """Tests scale-only conditional RMSNorm padding invariance."""
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=NormalizationType.RMS_NORM,
+        conditional=True,
+        use_conditional_shift=False,
+    )
+    c_small = jax.random.normal(self.rng, self.c_shape)
+    params = norm_layer.init(self.rng, self.x_small, c_small)
+    params_perturbed = _perturb_params(params=params, key=self.rng)
+
+    out_small = norm_layer.apply(params_perturbed, self.x_small, c_small)
+    out_large = norm_layer.apply(params_perturbed, self.x_large, c_small)
+    np.testing.assert_allclose(
+        out_small[:, :, : self.unpadded_seq_len, :],
+        out_large[:, :, : self.unpadded_seq_len, :],
+        atol=1e-5,
+    )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='unconditional_rmsnorm_float32',
+          normalization_method=NormalizationType.RMS_NORM,
+          conditional=False,
+          dtype=jnp.float32,
+      ),
+      dict(
+          testcase_name='unconditional_rmsnorm_bfloat16',
+          normalization_method=NormalizationType.RMS_NORM,
+          conditional=False,
+          dtype=jnp.bfloat16,
+      ),
+      dict(
+          testcase_name='unconditional_layernorm_float32',
+          normalization_method=NormalizationType.LAYER_NORM,
+          conditional=False,
+          dtype=jnp.float32,
+      ),
+      dict(
+          testcase_name='unconditional_layernorm_bfloat16',
+          normalization_method=NormalizationType.LAYER_NORM,
+          conditional=False,
+          dtype=jnp.bfloat16,
+      ),
+      dict(
+          testcase_name='unconditional_groupnorm_float32',
+          normalization_method=NormalizationType.GROUP_NORM,
+          conditional=False,
+          dtype=jnp.float32,
+      ),
+      dict(
+          testcase_name='unconditional_groupnorm_bfloat16',
+          normalization_method=NormalizationType.GROUP_NORM,
+          conditional=False,
+          dtype=jnp.bfloat16,
+      ),
+      dict(
+          testcase_name='conditional_rmsnorm_float32',
+          normalization_method=NormalizationType.RMS_NORM,
+          conditional=True,
+          dtype=jnp.float32,
+      ),
+      dict(
+          testcase_name='conditional_rmsnorm_bfloat16',
+          normalization_method=NormalizationType.RMS_NORM,
+          conditional=True,
+          dtype=jnp.bfloat16,
+      ),
+      dict(
+          testcase_name='conditional_layernorm_float32',
+          normalization_method=NormalizationType.LAYER_NORM,
+          conditional=True,
+          dtype=jnp.float32,
+      ),
+      dict(
+          testcase_name='conditional_layernorm_bfloat16',
+          normalization_method=NormalizationType.LAYER_NORM,
+          conditional=True,
+          dtype=jnp.bfloat16,
+      ),
+      dict(
+          testcase_name='conditional_groupnorm_float32',
+          normalization_method=NormalizationType.GROUP_NORM,
+          conditional=True,
+          dtype=jnp.float32,
+      ),
+      dict(
+          testcase_name='conditional_groupnorm_bfloat16',
+          normalization_method=NormalizationType.GROUP_NORM,
+          conditional=True,
+          dtype=jnp.bfloat16,
+      ),
+      dict(
+          testcase_name='conditional_scale_only_rmsnorm_bfloat16',
+          normalization_method=NormalizationType.RMS_NORM,
+          conditional=True,
+          dtype=jnp.bfloat16,
+          use_conditional_shift=False,
+      ),
+  )
+  def test_output_dtype(
+      self,
+      normalization_method,
+      conditional,
+      dtype,
+      use_conditional_shift=True,
+  ):
+    """Tests that the output dtype matches the configured dtype."""
+    num_groups = (
+        self.num_groups
+        if (normalization_method == NormalizationType.GROUP_NORM)
+        else None
+    )
+
+    norm_layer = normalization.NormalizationLayer(
+        normalization_method=normalization_method,
+        conditional=conditional,
+        num_groups=num_groups,
+        dtype=dtype,
+        use_conditional_shift=use_conditional_shift,
+    )
+
+    x = self.x.astype(dtype)
+    if conditional:
+      c = self.c.astype(dtype)
+      params = norm_layer.init(self.rng, x, c)
+      output = norm_layer.apply(params, x, c)
+    else:
+      params = norm_layer.init(self.rng, x)
+      output = norm_layer.apply(params, x)
+
+    self.assertEqual(output.dtype, dtype)
+    self.assertEqual(output.shape, self.x_shape)
+
 
-if __name__ == "__main__":
+if __name__ == '__main__':
   absltest.main()