Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions hackable_diffusion/lib/architecture/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def _dot_product_attention(
rescale: Float["..."],
*,
mask: Bool["batch sequence_key"] | None = None,
dropout_rate: float = 0.0,
is_training: bool = True,
) -> Float["batch sequence_query head*dim"]:
"""Performs dot product attention.

Expand All @@ -137,6 +139,8 @@ def _dot_product_attention(
rescale: Rescale factor for the attention scores.
mask: Mask tensor. Mask is True for tokens we want to keep and False for
tokens we want to mask. If None, no masking is performed.
dropout_rate: The dropout rate for the attention weights.
is_training: Whether the model is in training mode.

Returns:
The output tensor.
Expand All @@ -156,6 +160,11 @@ def _dot_product_attention(
# Softmax and attention weights
attn_weights = _stable_softmax(logits=attn_logits)

if dropout_rate > 0.0:
attn_weights = nn.Dropout(rate=dropout_rate)(
attn_weights, deterministic=not is_training
)

# Calculate attention output
attn_output = jnp.einsum("bhts,bhsd->bhtd", attn_weights, v)

Expand Down Expand Up @@ -194,6 +203,7 @@ class MultiHeadAttention(nn.Module):
use_rope is True.
zero_init_output: If True, the kernel of the final output projection layer
is initialized to zeros.
dropout_rate: The dropout rate for the attention weights.
dtype: The data type of the computation.
"""

Expand All @@ -203,6 +213,7 @@ class MultiHeadAttention(nn.Module):
use_rope: bool = False
rope_position_type: RoPEPositionType = RoPEPositionType.SQUARE
zero_init_output: bool = False
dropout_rate: float = 0.0
dtype: DType = jnp.float32

def setup(self):
Expand All @@ -226,6 +237,7 @@ def __call__(
c: Float["batch sequence2 dim2"] | None,
*,
mask: Bool["batch sequence1|sequence2"] | None = None,
is_training: bool = True,
) -> Float["batch sequence1 dim1"]:
"""Computes multi-head attention.

Expand Down Expand Up @@ -319,6 +331,8 @@ def __call__(
v=v,
rescale=scale,
mask=mask,
dropout_rate=self.dropout_rate,
is_training=is_training,
)

attn_output = nn.Dense(
Expand Down
96 changes: 96 additions & 0 deletions hackable_diffusion/lib/architecture/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,102 @@ def test_multi_head_attention_invalid_mask_shape_raises_error(
):
module.init(self.rng, self.x, c, mask=invalid_mask)

# MARK: Dropout Tests

def test_multi_head_attention_dropout_disabled_during_evaluation(self):
  """Verifies dropout is a no-op when is_training=False (evaluation mode).

  Two properties are checked with an aggressive dropout rate:
    * repeated evaluation calls, made without any 'dropout' RNG, agree; and
    * the evaluation output matches a dropout-free module that reuses the
      same variables.
  """
  module = attention.MultiHeadAttention(
      num_heads=self.num_heads,
      dropout_rate=0.5,
  )
  # Dropout-free reference. Flax's nn.Dropout declares no variables, so both
  # modules are expected to accept the same variable tree.
  reference = attention.MultiHeadAttention(
      num_heads=self.num_heads,
      dropout_rate=0.0,
  )

  rng_x, rng_init = jax.random.split(self.rng)
  x_rand = jax.random.normal(
      rng_x, (self.batch_size, self.seq_len_q, self.dim)
  )

  # Initialize in evaluation mode: parameter creation does not need dropout,
  # and this avoids depending on how Flax resolves a missing 'dropout' RNG
  # stream while initializing.
  variables = module.init(rng_init, x_rand, c=None, is_training=False)

  # No 'dropout' RNG is supplied on purpose: evaluation must not sample one.
  output_eval_1 = module.apply(variables, x_rand, c=None, is_training=False)
  output_eval_2 = module.apply(variables, x_rand, c=None, is_training=False)
  output_reference = reference.apply(
      variables, x_rand, c=None, is_training=False
  )

  # Even with a 50% dropout rate, evaluation outputs must be repeatable ...
  np.testing.assert_allclose(
      output_eval_1,
      output_eval_2,
      atol=1e-6,
  )
  # ... and must equal what the module computes without any dropout at all.
  np.testing.assert_allclose(
      output_eval_1,
      output_reference,
      atol=1e-6,
  )

def test_multi_head_attention_dropout_active_during_training(self):
  """Verifies dropout randomizes outputs when is_training=True.

  Different 'dropout' keys must produce different outputs, while reusing the
  same key must reproduce the same output (the randomness comes only from the
  supplied key).
  """
  module = attention.MultiHeadAttention(
      num_heads=self.num_heads,
      dropout_rate=0.5,
  )

  rng_x, rng_init, rng_dropout1, rng_dropout2 = jax.random.split(self.rng, 4)
  x_rand = jax.random.normal(
      rng_x, (self.batch_size, self.seq_len_q, self.dim)
  )

  # Initialize in evaluation mode: parameter creation does not need dropout,
  # and this avoids depending on how Flax resolves a missing 'dropout' RNG
  # stream while initializing.
  variables = module.init(rng_init, x_rand, c=None, is_training=False)

  def apply_in_training_mode(dropout_rng):
    # An active nn.Dropout layer draws its mask from the 'dropout' RNG
    # stream, which is passed to `apply` through the `rngs` dict.
    return module.apply(
        variables,
        x_rand,
        c=None,
        is_training=True,
        rngs={"dropout": dropout_rng},
    )

  output_train_1 = apply_in_training_mode(rng_dropout1)
  output_train_2 = apply_in_training_mode(rng_dropout2)
  output_train_1_repeat = apply_in_training_mode(rng_dropout1)

  # Two distinct keys sample two distinct dropout masks, so outputs differ.
  self.assertFalse(
      bool(jnp.allclose(output_train_1, output_train_2, atol=1e-5)),
      msg="Training outputs are identical for two different dropout keys.",
  )
  # The same key samples the same mask, so the output is reproducible.
  np.testing.assert_allclose(
      output_train_1,
      output_train_1_repeat,
      atol=1e-6,
  )

def test_multi_head_attention_dropout_scales_retained_activations(self):
  """Verifies retained attention weights are scaled by 1 / (1 - rate).

  Inverted dropout rescales the surviving entries so that the expected value
  of the dropped-out tensor equals the original tensor. Comparing the maximum
  magnitude of one training output against the evaluation output does not
  test this: a single random mask gives no guaranteed ordering. Instead, the
  training output is averaged over many dropout keys and compared against two
  hypotheses:
    * with rescaling, the average approaches the evaluation output;
    * without rescaling, it would approach (1 - rate) * evaluation output.

  NOTE(review): this assumes everything applied after dropout (the weighted
  sum over values and the output projection with zero-initialized bias) is
  linear in the attention weights -- confirm against the full module.
  """
  rate = 0.5
  num_samples = 256
  module = attention.MultiHeadAttention(
      num_heads=self.num_heads,
      dropout_rate=rate,
  )

  rng_x, rng_init, rng_dropout = jax.random.split(self.rng, 3)
  x_rand = jax.random.normal(
      rng_x, (self.batch_size, self.seq_len_q, self.dim)
  )

  # Initialize in evaluation mode: parameter creation does not need dropout,
  # and this avoids depending on how Flax resolves a missing 'dropout' RNG
  # stream while initializing.
  variables = module.init(rng_init, x_rand, c=None, is_training=False)

  output_eval = module.apply(variables, x_rand, c=None, is_training=False)

  def apply_in_training_mode(dropout_rng):
    return module.apply(
        variables,
        x_rand,
        c=None,
        is_training=True,
        rngs={"dropout": dropout_rng},
    )

  # One forward pass per dropout key, batched with vmap to keep the test fast.
  dropout_rngs = jax.random.split(rng_dropout, num_samples)
  outputs_train = jax.vmap(apply_in_training_mode)(dropout_rngs)
  mean_output_train = jnp.mean(outputs_train, axis=0)

  # Mean absolute deviation of the averaged training output from each
  # hypothesis. Only sampling noise remains under the correct hypothesis.
  error_with_rescaling = float(
      jnp.mean(jnp.abs(mean_output_train - output_eval))
  )
  error_without_rescaling = float(
      jnp.mean(jnp.abs(mean_output_train - (1.0 - rate) * output_eval))
  )

  self.assertLess(
      error_with_rescaling,
      error_without_rescaling,
      msg=(
          "Averaged training output is closer to (1 - rate) * eval output"
          " than to the eval output; retained attention weights do not"
          " appear to be scaled by 1 / (1 - rate)."
      ),
  )


if __name__ == "__main__":
absltest.main()
43 changes: 31 additions & 12 deletions hackable_diffusion/lib/architecture/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class NormalizationLayer(nn.Module):
dtype: The data type of the computation.
use_bias: Whether to use bias in the normalization layer.
use_scale: Whether to use scale in the normalization layer.
use_conditional_shift: Whether to use conditional shift in the normalization
layer (only applies when `conditional` is True).
"""

normalization_method: NormalizationType
Expand All @@ -88,6 +90,7 @@ class NormalizationLayer(nn.Module):
dtype: DType = jnp.float32
use_bias: bool = True
use_scale: bool = True
use_conditional_shift: bool = True

def setup(self):
if (
Expand Down Expand Up @@ -169,18 +172,28 @@ def __call__(
)

if self.conditional:

scale_and_shift = nn.Dense(
ch * 2,
kernel_init=nn.zeros_init(),
bias_init=nn.zeros_init(),
dtype=self.dtype,
)(c)
scale, shift = jnp.split(scale_and_shift, 2, axis=-1) # (B, ch) each.

x = einops.rearrange(x, "b ... c -> b c ...") # (B, ch, ...).
scale = jax_helpers.bcast_right(scale, x.ndim)
shift = jax_helpers.bcast_right(shift, x.ndim)
# Scale + shift adaptive conditioning.
if self.use_conditional_shift:
scale_and_shift = nn.Dense(
ch * 2,
kernel_init=nn.zeros_init(),
bias_init=nn.zeros_init(),
dtype=self.dtype,
)(c)
scale, shift = jnp.split(scale_and_shift, 2, axis=-1) # (B, ch) each.
scale = jax_helpers.bcast_right(scale, x.ndim)
shift = jax_helpers.bcast_right(shift, x.ndim)
else:
# Scale-only adaptive conditioning (no shift).
scale = nn.Dense(
ch,
kernel_init=nn.zeros_init(),
bias_init=nn.zeros_init(),
dtype=self.dtype,
)(c)
scale = jax_helpers.bcast_right(scale, x.ndim)
shift = jnp.zeros_like(scale)
x = (1.0 + scale) * x + shift
x = einops.rearrange(x, "b c ... -> b ... c")

Expand Down Expand Up @@ -211,6 +224,8 @@ class NormalizationLayerFactory:
dtype: The data type of the computation.
use_bias: Whether to use bias in the normalization layer.
use_scale: Whether to use scale in the normalization layer.
use_conditional_shift: Whether to use conditional shift in the normalization
layer (only applies when `conditional` is True).
"""

def __init__(
Expand All @@ -221,13 +236,15 @@ def __init__(
dtype: DType = jnp.float32,
use_bias: bool = True,
use_scale: bool = True,
use_conditional_shift: bool = True,
):
self.normalization_method = normalization_method
self.epsilon = epsilon
self.num_groups = num_groups
self.dtype = dtype
self.use_bias = use_bias
self.use_scale = use_scale
self.use_conditional_shift = use_conditional_shift

def unconditional_norm(
self, norm_name: str = "UnconditionalNorm"
Expand All @@ -242,12 +259,13 @@ def unconditional_norm(
dtype=self.dtype,
use_bias=self.use_bias,
use_scale=self.use_scale,
use_conditional_shift=self.use_conditional_shift,
)

def conditional_norm(
self, norm_name: str = "ConditionalNorm"
) -> NormalizationLayer:
"""Returns a factory for creating conditional normalization layers."""
"""Returns a conditional normalization layer."""
return NormalizationLayer(
normalization_method=self.normalization_method,
conditional=True,
Expand All @@ -257,4 +275,5 @@ def conditional_norm(
dtype=self.dtype,
use_bias=self.use_bias,
use_scale=self.use_scale,
use_conditional_shift=self.use_conditional_shift,
)
Loading
Loading