diff --git a/hackable_diffusion/lib/multimodal.py b/hackable_diffusion/lib/multimodal.py index 38110ce..3896442 100644 --- a/hackable_diffusion/lib/multimodal.py +++ b/hackable_diffusion/lib/multimodal.py @@ -649,9 +649,7 @@ def __call__( xt_rescaled = xt # 2. Encode conditioning - conditioning_embeddings = cast(nn.Module, self.conditioning_encoder).copy( - name='ConditioningEncoder' - )( + conditioning_embeddings = self.conditioning_encoder( time=time_rescaled, conditioning=conditioning, is_training=is_training, @@ -674,9 +672,7 @@ def _create_zero_logits(xt_leaf, process_leaf): lambda x, z: jnp.concatenate([x, z], axis=-1), xt_rescaled, zero_logits ) - backbone_module = cast(nn.Module, self.backbone_network).copy( - name='Backbone' - ) + backbone_module = self.backbone_network first_output = backbone_module( x=xt_with_zeros,