Grayscale Image Support in Real-ESRGAN (1 Channel) Needs Manual Fixes #744

@mhach06

Problem

The RealESRGANPairedDataset in Real-ESRGAN currently does not support grayscale images (1 input/output channel) out of the box.

Even after modifying the num_in_ch and num_out_ch values in the .yml config file (network_g and network_d), the code fails when working with grayscale datasets.

The root cause is the img2tensor() function in BasicSR/basicsr/utils/img_util.py, which has no option for producing single-channel tensors.


Note

The fix proposed below only works if both num_in_ch and num_out_ch are set to 1.
If either one is left at another value while training on grayscale data, the code will crash (which is acceptable for most grayscale workflows).
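
One way to make that failure explicit is to validate the config before training starts. The sketch below is not part of the patch: `check_grayscale_config` is a hypothetical helper, and it assumes the parsed options are a plain nested dict with the `network_g` keys named above.

```python
def check_grayscale_config(opt):
    """Return True for a 1-channel setup, False for any other consistent setup.

    Hypothetical helper: raises early instead of letting training crash later.
    """
    net = opt['network_g']
    num_in_ch, num_out_ch = net['num_in_ch'], net['num_out_ch']
    # Exactly one of the two being 1 is the mismatch described in the note.
    if (num_in_ch == 1) != (num_out_ch == 1):
        raise ValueError(
            'Grayscale training needs num_in_ch == num_out_ch == 1, '
            f'got num_in_ch={num_in_ch}, num_out_ch={num_out_ch}')
    return num_in_ch == 1


print(check_grayscale_config({'network_g': {'num_in_ch': 1, 'num_out_ch': 1}}))  # True
print(check_grayscale_config({'network_g': {'num_in_ch': 3, 'num_out_ch': 3}}))  # False
```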


Proposed Fix

The following changes (3 files) allow Real-ESRGAN to train/validate on grayscale datasets without errors.


Change 1: BasicSR/basicsr/utils/img_util.py

-def img2tensor(imgs, bgr2rgb=True, float32=True):
+def img2tensor(imgs, bgr2rgb=True, float32=True, grayscale=False):

-    def _totensor(img, bgr2rgb, float32):
+    def _totensor(img, bgr2rgb, float32, grayscale):
         if img.ndim == 3 and img.shape[2] == 3:
             if img.dtype == 'float64':
                 img = img.astype('float32')
             if bgr2rgb:
                 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            if grayscale:
+                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+
+        if img.ndim == 2:
+            img = img[:, :, None]
 
     if isinstance(imgs, list):
-        return [_totensor(img, bgr2rgb, float32) for img in imgs]
+        return [_totensor(img, bgr2rgb, float32, grayscale) for img in imgs]
     else:
-        return _totensor(imgs, bgr2rgb, float32)
+        return _totensor(imgs, bgr2rgb, float32, grayscale)
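
To illustrate why the extra `img.ndim == 2` branch is needed: `cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)` returns a 2-D array, which has no channel axis left to transpose. The following is a rough NumPy-only sketch of the shape handling, not the patched function itself. `to_chw` is a hypothetical name, and a plain channel mean stands in for the OpenCV conversion as an assumption to keep the example dependency-free.

```python
import numpy as np


def to_chw(img, grayscale=False):
    """Mimic the shape handling of the patched _totensor, using NumPy only.

    Assumption: a channel mean replaces cv2.cvtColor(..., COLOR_RGB2GRAY).
    For images whose three channels are identical the values are
    essentially the same.
    """
    if img.ndim == 3 and img.shape[2] == 3 and grayscale:
        img = img.mean(axis=2)      # HWC with 3 channels -> HW
    if img.ndim == 2:
        img = img[:, :, None]       # HW -> HWC with a single channel
    return np.ascontiguousarray(img.transpose(2, 0, 1))  # HWC -> CHW


color = np.zeros((4, 6, 3), dtype=np.float32)
gray = np.zeros((4, 6), dtype=np.float32)

print(to_chw(color).shape)                  # (3, 4, 6)
print(to_chw(color, grayscale=True).shape)  # (1, 4, 6)
print(to_chw(gray).shape)                   # (1, 4, 6)
```

One thing worth checking in the real patch: Change 2 below calls img2tensor with bgr2rgb=False, so the array is still in BGR order when COLOR_RGB2GRAY is applied. That makes no difference for images whose channels are identical, but it would swap the red and blue weights for true color inputs.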

Change 2: BasicSR/basicsr/data/realesrgan_paired_dataset.py

-        # BGR to RGB, HWC to CHW, numpy to tensor
-        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=False, float32=True)
+        # BGR to RGB, HWC to CHW, numpy to tensor (grayscale if num_in_ch == 1)
+        grayscale = self.opt['network_g']['num_in_ch'] == 1
+        img_gt, img_lq = img2tensor([img_gt, img_lq],
+                                    bgr2rgb=False,
+                                    float32=True,
+                                    grayscale=grayscale)
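
Note that this lookup depends on Change 3: `self.opt['network_g']` only exists once train.py has copied it into the dataset options, so a dataset built any other way would raise a KeyError. A more defensive variant could look like the sketch below, where `wants_grayscale` is a hypothetical helper and the default of 3 channels is an assumption.

```python
def wants_grayscale(dataset_opt, default_in_ch=3):
    """Return True when the propagated generator config asks for 1 input channel.

    Hypothetical helper: falls back to default_in_ch when 'network_g' or
    'num_in_ch' is missing from the dataset options.
    """
    network_g = dataset_opt.get('network_g') or {}
    return network_g.get('num_in_ch', default_in_ch) == 1


print(wants_grayscale({'network_g': {'num_in_ch': 1}}))  # True
print(wants_grayscale({'name': 'val'}))                  # False
```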

Change 3: BasicSR/basicsr/train.py

     for phase, dataset_opt in opt['datasets'].items():
+        dataset_opt['network_g'] = opt['network_g']
         if phase == 'train':
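
The effect of this change can be seen on a plain dict standing in for the parsed .yml options. The dataset names below are placeholders, not values from a real config.

```python
# Stand-in for the parsed options; only the keys relevant here are shown.
opt = {
    'network_g': {'num_in_ch': 1, 'num_out_ch': 1},
    'datasets': {
        'train': {'name': 'train_set'},
        'val': {'name': 'val_set'},
    },
}

# Same propagation as in the patch: every phase gets the generator config.
for phase, dataset_opt in opt['datasets'].items():
    dataset_opt['network_g'] = opt['network_g']

print(opt['datasets']['val']['network_g']['num_in_ch'])  # 1
```

Because the line sits before the `if phase == 'train'` check, validation datasets receive the generator settings as well. The dict is shared by reference rather than copied, so every dataset sees the same `network_g` object.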

Result

After applying these changes:

  • Grayscale datasets (1-channel input/output) work without errors.
  • Training and validation proceed as expected.

I can open a PR if this seems useful.
