diff --git a/hackable_diffusion/lib/corruption/discrete.py b/hackable_diffusion/lib/corruption/discrete.py index 6cd7fb7..a39a86f 100644 --- a/hackable_diffusion/lib/corruption/discrete.py +++ b/hackable_diffusion/lib/corruption/discrete.py @@ -162,7 +162,7 @@ class CategoricalProcess(CorruptionProcess): """ schedule: DiscreteSchedule - invariant_probs: Sequence[float] + invariant_probs: Sequence[float] = dataclasses.field(repr=False) num_categories: int unused_token: int = UNUSED_TOKEN post_corruption_fn: PostCorruptionFn = IdentityPostCorruptionFn() diff --git a/hackable_diffusion/lib/corruption/discrete_test.py b/hackable_diffusion/lib/corruption/discrete_test.py index 45ece19..dfce81b 100644 --- a/hackable_diffusion/lib/corruption/discrete_test.py +++ b/hackable_diffusion/lib/corruption/discrete_test.py @@ -361,6 +361,14 @@ def test_unused_mask_gives_always_false_on_other_masks(self, process_type): self.assertFalse(is_corrupted_mask[1]) self.assertFalse(is_corrupted_mask[3]) + def test_repr_does_not_include_invariant_probs(self): + process = discrete.CategoricalProcess.uniform_process( + schedule=self.schedule, num_categories=self.num_categories + ) + process_repr = repr(process) + self.assertNotIn('invariant_probs', process_repr) + self.assertIn('num_categories', process_repr) + if __name__ == '__main__': absltest.main()