diff --git a/hackable_diffusion/kdiff/core.py b/hackable_diffusion/kdiff/core.py index 13c2b2c..edfb749 100644 --- a/hackable_diffusion/kdiff/core.py +++ b/hackable_diffusion/kdiff/core.py @@ -63,11 +63,11 @@ Conditioning = hd_typing.Conditioning DataTree = hd_typing.DataTree LossOutput = hd_typing.LossOutput +PyTree = hd_typing.PyTree TargetInfo = hd_typing.TargetInfo TimeArray = hd_typing.TimeArray typechecked = hd_typing.typechecked - ################################################################################ # MARK: Diffusion ################################################################################ @@ -113,7 +113,7 @@ def __call__( self, x0: DataTree, cond: Conditioning | None = None, - ) -> dict[str, dict[str, Array] | Array]: + ) -> dict[str, PyTree]: """Run the diffusion training step. Samples timesteps, corrupts the input data according to the corruption