```python
_C.MODEL.ARCHITECTURE = "unet"
# Architecture of the network. Possible values are:
# * 'patchgan'
_C.MODEL.ARCHITECTURE_D = "patchgan"
```

This is a feature only of NAFNet, so introduce it inside MODEL.NAFNET.
```python
_C.MODEL.NAFNET.FFN_EXPAND = 2

# Discriminator PATCHGAN
_C.MODEL.PATCHGAN = CN()
```

Move MODEL.PATCHGAN inside MODEL.NAFNET.
```python
_C.LOSS.COMPOSED_GAN.GAMMA_SSIM = 1.0

# Backward-compatible alias for previous naming.
_C.LOSS.GAN = CN()
```

What's the purpose of this? It shouldn't be necessary.
```python
# Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW"
_C.TRAIN.OPTIMIZER = "SGD"
# Optimizer for the GAN discriminator. Possible values: "SGD", "ADAM" or "ADAMW"
_C.TRAIN.OPTIMIZER_D = "SGD"
```

Now that more than one optimizer is used, change TRAIN.OPTIMIZER to be a list of str. For most models only one optimizer should be required, but the GAN at hand needs two.
```python
# Learning rate
_C.TRAIN.LR = 1.0e-4
# Learning rate for the GAN discriminator
_C.TRAIN.LR_D = 1.0e-4
```

Same as with the optimizers: TRAIN.LR should be converted to a list now. Check in check_configuration.py that the optimizers and learning rates have the same length.
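One way to realize the list-based settings plus the requested validation, sketched as a standalone helper; the function name and its placement in check_configuration.py are assumptions, not the repo's actual API:

```python
def check_optimizers_and_lrs(optimizers, lrs):
    """Validate that TRAIN.OPTIMIZER and TRAIN.LR line up.

    Both are lists now: one entry for a plain model, two for a GAN
    (generator and discriminator). Helper name is illustrative.
    """
    allowed = {"SGD", "ADAM", "ADAMW"}
    if len(optimizers) != len(lrs):
        raise ValueError(
            f"TRAIN.OPTIMIZER has {len(optimizers)} entries but "
            f"TRAIN.LR has {len(lrs)}; they must have the same length."
        )
    for opt in optimizers:
        if opt.upper() not in allowed:
            raise ValueError(f"Unknown optimizer: {opt!r}")

# A GAN run provides one optimizer/LR pair per network.
check_optimizers_and_lrs(["ADAM", "ADAM"], [1.0e-4, 1.0e-4])
```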
```python
return model, str(callable_model.__name__), collected_sources, all_import_lines, scanned_files, args, network_stride  # type: ignore


def build_discriminator(cfg: CN, device: torch.device):
```

This should be inside nafnet.
```python
@@ -0,0 +1,165 @@
import torch
```

Comment the functions and extend the descriptions. Check other models to see how we do it, and try to do it in the same way.
```python
@@ -0,0 +1,23 @@
import torch.nn as nn
```

Same as with nafnet.py.
```python
jobname,
epoch,
model_without_ddp,
optimizer,
```

This should be a list now. Loop over the optimizers to store them, and remove optimizer_d.
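A sketch of the list-based checkpoint assembly; the helper name and dict keys are illustrative, and the optimizer objects are only assumed to expose a torch-style `state_dict()`:

```python
def build_checkpoint(jobname, epoch, model_state, optimizers):
    """Collect one checkpoint dict, storing every optimizer in a list.

    `optimizers` replaces the old `optimizer` / `optimizer_d` pair; for a
    GAN it holds [generator_opt, discriminator_opt]. Names and keys are
    illustrative, not the repo's exact API.
    """
    return {
        "jobname": jobname,
        "epoch": epoch,
        "model": model_state,
        # One state dict per optimizer, in the same order as TRAIN.OPTIMIZER.
        "optimizers": [opt.state_dict() for opt in optimizers],
    }
```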
```python
    Optional discriminator model to include in checkpoints for GAN training.
optimizer_d : Optional[torch.optim.Optimizer], optional
    Optional discriminator optimizer state to include in checkpoints for GAN training.
extra_checkpoint_items : Optional[dict], optional
```

What is this? Is it necessary? If not, please remove it.