Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions biapy/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,12 +1162,14 @@ def __init__(self, job_dir: str, job_identifier: str):
# * Semantic segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2'
# * Instance segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2'
# * Detection: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2'
# * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2'
# * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2', 'nafnet'
# * Super-resolution: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2'
# * Self-supervision: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2'
# * Classification: 'simple_cnn', 'vit', 'efficientnet_b[0-7]' (only 2D)
# * Image to image: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2'
_C.MODEL.ARCHITECTURE = "unet"
# Discriminator architecture used when training with an adversarial (GAN) loss. Possible values are:
# * 'patchgan'
# Number of feature maps on each level of the network.
_C.MODEL.FEATURE_MAPS = [16, 32, 64, 128, 256]
# Values to make the dropout with. Set to 0 to prevent dropout. When using it with 'ViT' or 'unetr'
Expand Down Expand Up @@ -1306,6 +1308,27 @@ def __init__(self, job_dir: str, job_identifier: str):
# Whether to use a pretrained version of STUNet on ImageNet
_C.MODEL.STUNET.PRETRAINED = False

# NAFNet (Nonlinear Activation-Free Network)
_C.MODEL.NAFNET = CN()
# Base number of channels (width) used in the first layer and base levels.
_C.MODEL.NAFNET.WIDTH = 16
# Number of NAFBlocks stacked at the bottleneck (deepest level).
_C.MODEL.NAFNET.MIDDLE_BLK_NUM = 12
# Number of NAFBlocks assigned to each downsampling level of the encoder.
_C.MODEL.NAFNET.ENC_BLK_NUMS = [2, 2, 4, 8]
# Number of NAFBlocks assigned to each upsampling level of the decoder.
_C.MODEL.NAFNET.DEC_BLK_NUMS = [2, 2, 2, 2]
# Channel expansion factor for the depthwise convolution within the gating unit.
_C.MODEL.NAFNET.DW_EXPAND = 2
# Expansion factor for the hidden layer within the feed-forward network.
_C.MODEL.NAFNET.FFN_EXPAND = 2
# Discriminator architecture
_C.MODEL.NAFNET.ARCHITECTURE_D = "patchgan"
# Discriminator PATCHGAN
_C.MODEL.NAFNET.PATCHGAN = CN()
# Number of initial convolutional filters in the first layer of the discriminator.
_C.MODEL.NAFNET.PATCHGAN.BASE_FILTERS = 64

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Loss
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -1371,22 +1394,36 @@ def __init__(self, job_dir: str, job_identifier: str):
_C.LOSS.CONTRAST.MEMORY_SIZE = 5000
_C.LOSS.CONTRAST.PROJ_DIM = 256
_C.LOSS.CONTRAST.PIXEL_UPD_FREQ = 10


# Fine-grained GAN composition. Set any weight to 0.0 to disable that term.
# Used when LOSS.TYPE == "COMPOSED_GAN".
_C.LOSS.COMPOSED_GAN = CN()
# Weight for adversarial BCE term.
_C.LOSS.COMPOSED_GAN.LAMBDA_GAN = 1.0
# Weight for L1 reconstruction term.
_C.LOSS.COMPOSED_GAN.LAMBDA_RECON = 10.0
# Weight for MSE reconstruction term.
_C.LOSS.COMPOSED_GAN.DELTA_MSE = 0.0
# Weight for VGG perceptual term.
_C.LOSS.COMPOSED_GAN.ALPHA_PERCEPTUAL = 0.0
# Weight for SSIM term.
_C.LOSS.COMPOSED_GAN.GAMMA_SSIM = 1.0

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Training phase
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
_C.TRAIN = CN()
_C.TRAIN.ENABLE = False
# Enable verbosity
_C.TRAIN.VERBOSE = False
# Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW"
_C.TRAIN.OPTIMIZER = "SGD"
# Learning rate
_C.TRAIN.LR = 1.0e-4
# Optimizer(s) to use. Possible values: "SGD", "ADAM" or "ADAMW".
_C.TRAIN.OPTIMIZER = ["SGD"]
# Learning rate(s).
_C.TRAIN.LR = [1.0e-4]
# Weight decay
_C.TRAIN.W_DECAY = 0.02
# Coefficients used for computing running averages of the gradient and its square. Used in ADAM and ADAMW optimizers
_C.TRAIN.OPT_BETAS = (0.9, 0.999)
_C.TRAIN.OPT_BETAS = [(0.9, 0.999)]
# Batch size
_C.TRAIN.BATCH_SIZE = 2
# If memory or # gpus is limited, use this variable to maintain the effective batch size, which is
Expand Down
4 changes: 2 additions & 2 deletions biapy/data/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def create_train_val_augmentors(
dic["zflip"] = cfg.AUGMENTOR.ZFLIP
if cfg.PROBLEM.TYPE == "INSTANCE_SEG":
dic["instance_problem"] = True
elif cfg.PROBLEM.TYPE == "DENOISING":
elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.MODEL.ARCHITECTURE != 'nafnet':
dic["n2v"] = True
dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX
dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR
Expand Down Expand Up @@ -293,7 +293,7 @@ def create_train_val_augmentors(
)
if cfg.PROBLEM.TYPE == "INSTANCE_SEG":
dic["instance_problem"] = True
elif cfg.PROBLEM.TYPE == "DENOISING":
elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.MODEL.ARCHITECTURE != 'nafnet':
dic["n2v"] = True
dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX
dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR
Expand Down
114 changes: 70 additions & 44 deletions biapy/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def prepare_optimizer(
cfg: CN,
model_without_ddp: nn.Module | nn.parallel.DistributedDataParallel,
steps_per_epoch: int,
) -> Tuple[Optimizer, Scheduler | None]:
) -> Tuple[list[Optimizer], list[Scheduler | None]]:
"""
Create and configure the optimizer and learning rate scheduler for the given model.

Expand All @@ -40,50 +40,76 @@ def prepare_optimizer(

Returns
-------
optimizer : Optimizer
Configured optimizer for the model.
lr_scheduler : Scheduler or None
Configured learning rate scheduler, or None if not specified.
optimizers : List[Optimizer]
Configured optimizers for the models.
lr_schedulers : List[Scheduler | None]
Configured learning rate schedulers, or None if not specified.
"""
lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR
opt_args = {}
if cfg.TRAIN.OPTIMIZER in ["ADAM", "ADAMW"]:
opt_args["betas"] = cfg.TRAIN.OPT_BETAS
optimizer = timm.optim.create_optimizer_v2(
model_without_ddp,
opt=cfg.TRAIN.OPTIMIZER,
lr=lr,
weight_decay=cfg.TRAIN.W_DECAY,
**opt_args,
)
print(optimizer)

# Learning rate schedulers
lr_scheduler = None
if cfg.TRAIN.LR_SCHEDULER.NAME != "":
if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau":
lr_scheduler = ReduceLROnPlateau(
optimizer,
patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE,
factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR,
min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR,
)
elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine":
lr_scheduler = WarmUpCosineDecayScheduler(
lr=cfg.TRAIN.LR,
min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR,
warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS,
epochs=cfg.TRAIN.EPOCHS,
)
elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle":
lr_scheduler = OneCycleLR(
optimizer,
cfg.TRAIN.LR,
epochs=cfg.TRAIN.EPOCHS,
steps_per_epoch=steps_per_epoch,
)

return optimizer, lr_scheduler

optimizers = []
lr_schedulers = []

if hasattr(model_without_ddp, 'discriminator') and model_without_ddp.discriminator is not None:
param_groups = [
# Generator
[p for n, p in model_without_ddp.named_parameters() if not n.startswith("discriminator.")], # should this be and p.requires_grad, same below?
# Discriminator
[p for p in model_without_ddp.discriminator.parameters()]
]
else:
param_groups = [[p for p in model_without_ddp.parameters()]]

    # NOTE(review): this optimizer/parameter-group count validation may belong in check_configuration — confirm placement
if len(cfg.TRAIN.OPTIMIZER) != len(param_groups):
raise ValueError(
f"Configuration mismatch: You requested {len(cfg.TRAIN.OPTIMIZER)} optimizers, "
f"but the model has {len(param_groups)} parameter group(s). "
f"Check your TRAIN.OPTIMIZER list in the config."
)

for i in range(len(cfg.TRAIN.OPTIMIZER)):
lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else [cfg.TRAIN.LR_SCHEDULER.MIN_LR] * len(cfg.TRAIN.LR)
opt_args = {}
if cfg.TRAIN.OPTIMIZER[i] in ["ADAM", "ADAMW"]:
opt_args["betas"] = cfg.TRAIN.OPT_BETAS[i] if i < len(cfg.TRAIN.OPT_BETAS) else cfg.TRAIN.OPT_BETAS[0]
optimizer = timm.optim.create_optimizer_v2(
param_groups[i],
opt=cfg.TRAIN.OPTIMIZER[i],
lr=lr[i],
weight_decay=cfg.TRAIN.W_DECAY,
**opt_args,
)
print(optimizer)
optimizers.append(optimizer)

# Learning rate schedulers
lr_scheduler = None
if cfg.TRAIN.LR_SCHEDULER.NAME != "":
if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau":
lr_scheduler = ReduceLROnPlateau(
optimizer,
patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE,
factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR,
min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR,
)
elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine":
lr_scheduler = WarmUpCosineDecayScheduler(
lr=cfg.TRAIN.LR[i],
min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR,
warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS,
epochs=cfg.TRAIN.EPOCHS,
)
elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle":
lr_scheduler = OneCycleLR(
optimizer,
cfg.TRAIN.LR[i],
epochs=cfg.TRAIN.EPOCHS,
steps_per_epoch=steps_per_epoch,
)

lr_schedulers.append(lr_scheduler)

return optimizers, lr_schedulers


def build_callbacks(cfg: CN) -> EarlyStopping | None:
Expand Down
13 changes: 13 additions & 0 deletions biapy/engine/base_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,19 @@ def train(self):
self.plot_values["loss"].append(train_stats["loss"])
if self.val_generator:
self.plot_values["val_loss"].append(test_stats["loss"])
extra_loss_keys = [k for k in train_stats if "loss" in k and k != "loss"]
for loss_key in extra_loss_keys:
val_loss_key = f"val_{loss_key}"

if loss_key not in self.plot_values:
self.plot_values[loss_key] = []
if self.val_generator:
self.plot_values[val_loss_key] = []

# Append the values
self.plot_values[loss_key].append(train_stats[loss_key])
if self.val_generator:
self.plot_values[val_loss_key].append(test_stats.get(loss_key, 0.0))
for i in range(len(self.train_metric_names)):
self.plot_values[self.train_metric_names[i]].append(train_stats[self.train_metric_names[i]])
if self.val_generator:
Expand Down
48 changes: 36 additions & 12 deletions biapy/engine/check_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,7 @@ def sort_key(item):
], "LOSS.CLASS_REBALANCE not in ['none', 'auto'] for INSTANCE_SEG workflow"
elif cfg.PROBLEM.TYPE == "DENOISING":
loss = "MSE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE
assert loss == "MSE", "LOSS.TYPE must be 'MSE'"
assert loss in ["MSE", "COMPOSED_GAN"], "LOSS.TYPE must be in ['MSE', 'COMPOSED_GAN'] for DENOISING"
elif cfg.PROBLEM.TYPE == "CLASSIFICATION":
loss = "CE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE
assert loss == "CE", "LOSS.TYPE must be 'CE'"
Expand Down Expand Up @@ -1797,12 +1797,19 @@ def sort_key(item):

#### Denoising ####
elif cfg.PROBLEM.TYPE == "DENOISING":
if cfg.DATA.TEST.LOAD_GT:
raise ValueError(
"Denoising is made in an unsupervised way so there is no ground truth required. Disable 'DATA.TEST.LOAD_GT'"
)
if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX):
raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range")
if cfg.PROBLEM.DENOISING.LOAD_GT_DATA or cfg.LOSS.TYPE == "COMPOSED_GAN":
if not cfg.DATA.TRAIN.GT_PATH and not cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA:
raise ValueError(
"Supervised denoising (e.g., with COMPOSED_GAN or LOAD_GT_DATA=True) "
"requires ground truth. 'DATA.TRAIN.GT_PATH' must be provided."
)
else:
if cfg.DATA.TEST.LOAD_GT:
raise ValueError(
"Denoising is made in an unsupervised way so there is no ground truth required. Disable 'DATA.TEST.LOAD_GT'"
)
if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX):
raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range")
if cfg.MODEL.SOURCE == "torchvision":
raise ValueError("'MODEL.SOURCE' as 'torchvision' is not available in denoising workflow")

Expand Down Expand Up @@ -2341,6 +2348,7 @@ def sort_key(item):
"hrnet48",
"hrnet64",
"stunet",
"nafnet",
], "MODEL.ARCHITECTURE not in ['unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'simple_cnn', 'efficientnet_b[0-7]', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2', 'hrnet18', 'hrnet32', 'hrnet48', 'hrnet64', 'stunet']"
if (
model_arch
Expand Down Expand Up @@ -2514,6 +2522,7 @@ def sort_key(item):
"hrnet48",
"hrnet64",
"stunet",
"nafnet",
]:
raise ValueError(
"Architectures available for {} are: ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet', 'resunet_se', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2', 'hrnet18', 'hrnet32', 'hrnet48', 'hrnet64', 'stunet']".format(
Expand Down Expand Up @@ -2714,11 +2723,26 @@ def sort_key(item):
assert cfg.MODEL.OUT_CHECKPOINT_FORMAT in ["pth", "safetensors"], "MODEL.OUT_CHECKPOINT_FORMAT not in ['pth', 'safetensors']"

### Train ###
assert cfg.TRAIN.OPTIMIZER in [
"SGD",
"ADAM",
"ADAMW",
], "TRAIN.OPTIMIZER not in ['SGD', 'ADAM', 'ADAMW']"
if not isinstance(cfg.TRAIN.OPTIMIZER, list):
raise ValueError("'TRAIN.OPTIMIZER' must be a list")
if not isinstance(cfg.TRAIN.LR, list):
raise ValueError("'TRAIN.LR' must be a list")
if not isinstance(cfg.TRAIN.OPT_BETAS, list):
raise ValueError("'TRAIN.OPT_BETAS' must be a list")
if len(cfg.TRAIN.OPTIMIZER) != len(cfg.TRAIN.LR):
raise ValueError("'TRAIN.OPTIMIZER' and 'TRAIN.LR' must have the same length")
print(cfg.TRAIN.OPT_BETAS)
print(len(cfg.TRAIN.OPT_BETAS))
if len(cfg.TRAIN.OPT_BETAS) not in [1, len(cfg.TRAIN.OPTIMIZER)]:
raise ValueError("'TRAIN.OPT_BETAS' must have length 1 or match 'TRAIN.OPTIMIZER' length")

for beta_pair in cfg.TRAIN.OPT_BETAS:
if not isinstance(beta_pair, (list, tuple)) or len(beta_pair) != 2:
raise ValueError("Each entry in 'TRAIN.OPT_BETAS' must be a tuple/list of length 2")

for opt in cfg.TRAIN.OPTIMIZER:
if opt not in ["SGD", "ADAM", "ADAMW"]:
raise ValueError("'TRAIN.OPTIMIZER' values must be in ['SGD', 'ADAM', 'ADAMW']")

if cfg.TRAIN.ENABLE and cfg.TRAIN.LR_SCHEDULER.NAME != "":
if cfg.TRAIN.LR_SCHEDULER.NAME not in [
Expand Down
10 changes: 8 additions & 2 deletions biapy/engine/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from biapy.engine.base_workflow import Base_Workflow
from biapy.data.data_manipulation import save_tif
from biapy.utils.misc import to_pytorch_format, is_main_process, MetricLogger
from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation
from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation, ComposedGANLoss


class Denoising_Workflow(Base_Workflow):
Expand Down Expand Up @@ -166,6 +166,8 @@ def define_metrics(self):
# print("Overriding 'LOSS.TYPE' to set it to N2V loss (masked MSE)")
if self.cfg.LOSS.TYPE == "MSE":
self.loss = loss_encapsulation(n2v_loss_mse)
elif self.cfg.LOSS.TYPE == "COMPOSED_GAN":
self.loss = ComposedGANLoss(cfg=self.cfg, device=self.device)

super().define_metrics()

Expand Down Expand Up @@ -232,7 +234,11 @@ def metric_calculation(

with torch.no_grad():
for i, metric in enumerate(list_to_use):
val = metric(_output.contiguous(), _targets[:, _output.shape[1]:].contiguous())
if _targets.shape[1] == _output.shape[1]:
target_for_metric = _targets.contiguous()
else:
target_for_metric = _targets[:, _output.shape[1]:].contiguous()
val = metric(_output.contiguous(), target_for_metric)
val = val.item() if not torch.isnan(val) else 0
out_metrics[list_names_to_use[i]] = val

Expand Down
Loading