Single image super-resolution (SR) is a classical computer vision problem that aims at recovering a high-resolution image from a lower-resolution one. Extensive research has been conducted in this area, and with the advance of Deep Learning great results have been achieved.

In this post, we will examine one of the Deep Learning approaches to super-resolution, the Super-Resolution Convolutional Neural Network (SRCNN). This technique works end to end: it extracts patches from the low-resolution image and passes them through convolutional layers to finally map them to higher-resolution output pixels, as depicted in the diagram below.

We will implement the SRCNN model in TensorFlow, train it, and then test it on a low-resolution image.

(Figure: the SRCNN pipeline, from low-resolution input patches through the convolutional layers to the high-resolution output.)

As the image dataset for training the model, we will use a Kaggle-hosted dataset called Dog and Cat Detection. We will download it with the Kaggle CLI, which requires your Kaggle API key; alternatively, you can download the dataset manually from the website.

%%capture
%%bash

pip install -q kaggle

mkdir -p ~/.kaggle
echo '{"username":"KAGGLE_USERNAME","key":"KAGGLE_KEY"}' > ~/.kaggle/kaggle.json
chmod 600 ~/.kaggle/kaggle.json
%%capture
%%bash

kaggle datasets download andrewmvd/dog-and-cat-detection
unzip -q dog-and-cat-detection.zip

To save model checkpoints and other artifacts, we will mount Google Drive into this Colab container.

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Import dependencies

import os
import pathlib
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Conv2D, ReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import load_img, img_to_array

Set a random seed for reproducibility

SEED = 31
np.random.seed(SEED)
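
Note that np.random.seed only seeds NumPy, which is what the image and patch sampling below relies on. If you also want TensorFlow's weight initialization to be reproducible, you can seed it as well (an optional addition):

# Optional: also seed TensorFlow so that weight initialization is reproducible.
tf.random.set_seed(SEED)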

Data

We need a function to resize images by a scale factor; this will be used later to generate low-resolution versions of a given image.

def resize_image(image_array, factor):
    original_image = Image.fromarray(image_array)
    new_size = np.array(original_image.size) * factor
    new_size = new_size.astype(np.int32)
    new_size = tuple(new_size)
    resized = original_image.resize(new_size)
    resized = img_to_array(resized)
    resized = resized.astype(np.uint8)
    return resized
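
Note that PIL's resize uses nearest-neighbour resampling by default, while the original SRCNN paper generates its low-resolution inputs with bicubic interpolation. If you want to match the paper, the resize call inside resize_image can request it explicitly (an optional tweak):

# Inside resize_image: use bicubic resampling instead of PIL's default,
# as the original SRCNN paper does for its low-resolution inputs.
resized = original_image.resize(new_size, Image.BICUBIC)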

This function uses the resizing to generate a low-resolution image by downsizing and then upsizing back, which discards high-frequency detail:

def downsize_upsize_image(image, scale):
    scaled = resize_image(image, 1.0 / scale)  # downsize by the scale factor
    scaled = resize_image(scaled, scale)       # upsize back to the original size
    return scaled

When we extract patches, we will slide a window over the original image; for the window to fit evenly, we first crop the image so its dimensions are multiples of the scale:

def tight_crop_image(image, scale):
    height, width = image.shape[:2]
    width -= int(width % scale)
    height -= int(height % scale)
    return image[:height, :width]

The following function extracts a patch from an input image with a sliding window. The INPUT_DIM parameter is the height and width of the input patches expected by the network:

def crop_input(image, x, y):
    y_slice = slice(y, y + INPUT_DIM)
    x_slice = slice(x, x + INPUT_DIM)
    return image[y_slice, x_slice]

Similarly, we need to crop patches from the output images, with LABEL_SIZE the height and width of the network's output. The crop is offset by PAD pixels so that each label patch sits at the center of its input patch, compensating for the border the network's valid convolutions remove (we will verify this below):

def crop_output(image, x, y):
    y_slice = slice(y + PAD, y + PAD + LABEL_SIZE)
    x_slice = slice(x + PAD, x + PAD + LABEL_SIZE)
    return image[y_slice, x_slice]

Now let's read all image paths

file_pattern = str(pathlib.Path('/content') / 'images' / '*.png')
dataset_paths = glob(file_pattern)

We don't need the entire dataset, as training on all of it would take too long; instead, we will sample 1000 images from it.

SUBSET_SIZE = 1000
# sample without replacement so no image is picked twice
dataset_paths = np.random.choice(dataset_paths, SUBSET_SIZE, replace=False)

Here is an example image from the dataset

path = np.random.choice(dataset_paths)
img = plt.imread(path)
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7fb5cde796d0>

Here we define some parameters: the scale for resizing, the input and output patch sizes, the amount of padding that needs to be applied to output patch coordinates, and the stride, which is the number of pixels we'll slide in both the horizontal and vertical axes to extract patches.

SCALE = 2.0
INPUT_DIM = 33
LABEL_SIZE = 21
PAD = int((INPUT_DIM - LABEL_SIZE) / 2.0)
STRIDE = 14
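
As a quick sanity check, PAD = 6 is exactly the border that the network's valid convolutions remove: the SRCNN we build below uses 9×9, 1×1 and 5×5 kernels, shrinking a 33-pixel input to 21 output pixels.

# 9x9 conv: 33 - 8 = 25; 1x1 conv: 25; 5x5 conv: 25 - 4 = 21
assert INPUT_DIM - 8 - 4 == LABEL_SIZE
assert PAD == (INPUT_DIM - LABEL_SIZE) // 2 == 6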

Now, let's build the dataset: we read the input images, generate a low-resolution version of each, and slide a window over both the low-resolution image and the original to extract training patches. We save the patches to disk, and later build a training data generator that loads them in batches.

%%bash

mkdir -p data
mkdir -p training

for image_path in tqdm(dataset_paths):
    filename = pathlib.Path(image_path).stem
    image = load_img(image_path)
    image = img_to_array(image)
    image = image.astype(np.uint8)
    image = tight_crop_image(image, SCALE)
    scaled = downsize_upsize_image(image, SCALE)

    height, width = image.shape[:2]

    for y in range(0, height - INPUT_DIM + 1, STRIDE):
        for x in range(0, width - INPUT_DIM + 1, STRIDE):
            crop = crop_input(scaled, x, y)
            target = crop_output(image, x, y)
            # note: np.save appends '.npy', so the files end in '_input.np.npy'
            np.save(f'data/{filename}_{x}_{y}_input.np', crop)
            np.save(f'data/{filename}_{x}_{y}_output.np', target)
100%|██████████| 1500/1500 [18:00<00:00,  1.39it/s]

We cannot hold all the patches in memory, hence we saved them to disk in the previous step. Now we need a dataset loader that loads patches and their labels and feeds them to the network in batches during training. This is achieved with the PatchesDataset class (check this example to learn more about generators - link).

class PatchesDataset(tf.keras.utils.Sequence):
    def __init__(self, batch_size, *args, **kwargs):
        self.batch_size = batch_size
        # np.save appended '.npy' to the '.np' suffix used above
        self.input = sorted(glob('data/*_input.np.npy'))
        self.output = sorted(glob('data/*_output.np.npy'))
        self.total_data = len(self.input)

    def __len__(self):
        # number of batches per epoch
        return self.total_data // self.batch_size

    def __getitem__(self, index):
        # returns one batch of randomly sampled patches and their labels
        indices = self.random_indices()
        input = np.array([np.load(self.input[idx]) for idx in indices])
        output = np.array([np.load(self.output[idx]) for idx in indices])
        return input, output

    def random_indices(self):
        # sample patch indices uniformly at random
        return np.random.choice(self.total_data, self.batch_size)

Define a batch size based on how much memory is available on your GPU, and create an instance of the dataset generator.

BATCH_SIZE = 1024
train_ds = PatchesDataset(BATCH_SIZE)
len(train_ds)
888

You can see the shape of the training batches

input, output = train_ds[0]
input.shape, output.shape
((1024, 33, 33, 3), (1024, 21, 21, 3))

Model

The architecture of the SRCNN model is very simple: it consists only of convolutional layers, the first one extracting features from the input patch, a 1×1 layer applying a non-linear mapping, and a final one reconstructing the high-resolution output. The following helper function creates an instance of the model.

def create_model(height, width, depth):
    input = Input(shape=(height, width, depth))
    # feature extraction
    x = Conv2D(filters=64, kernel_size=(9, 9), kernel_initializer='he_normal')(input)
    x = ReLU()(x)
    # non-linear mapping
    x = Conv2D(filters=32, kernel_size=(1, 1), kernel_initializer='he_normal')(x)
    x = ReLU()(x)
    # reconstruction
    output = Conv2D(filters=depth, kernel_size=(5, 5), kernel_initializer='he_normal')(x)
    return Model(input, output)

To train the network we will use Adam as the optimizer, with learning rate decay. Also, since the problem we are training the network for is a regression problem (we want to predict the high-resolution pixel values), we pick MSE as the loss function; this will make the model learn the filters that correctly map patches from low to high resolution.

EPOCHS = 12
optimizer = Adam(learning_rate=1e-3, decay=1e-3 / EPOCHS)
model = create_model(INPUT_DIM, INPUT_DIM, 3)
model.compile(loss='mse', optimizer=optimizer)
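
For reference, the legacy decay argument of the Keras optimizers scales the learning rate per iteration as lr_t = lr0 / (1 + decay * t); this is general Keras behavior, not something specific to SRCNN. A quick illustration of the schedule for our settings:

# Legacy Keras decay: lr_t = lr0 / (1 + decay * t), t = iteration (batch) count.
lr0, decay = 1e-3, 1e-3 / EPOCHS
steps_per_epoch = len(train_ds)
for t in (0, steps_per_epoch, steps_per_epoch * EPOCHS):  # start, 1 epoch, end
    print(f'iteration {t}: lr = {lr0 / (1 + decay * t):.6f}')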

You can see that the model is small, but astonishingly it will be able to achieve great results once trained for long enough; we will train it for 12 epochs.

model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 33, 33, 3)]       0         
                                                                 
 conv2d (Conv2D)             (None, 25, 25, 64)        15616     
                                                                 
 re_lu (ReLU)                (None, 25, 25, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 25, 25, 32)        2080      
                                                                 
 re_lu_1 (ReLU)              (None, 25, 25, 32)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 21, 21, 3)         2403      
                                                                 
=================================================================
Total params: 20,099
Trainable params: 20,099
Non-trainable params: 0
_________________________________________________________________

tf.keras.utils.plot_model(model, show_shapes=True, rankdir='LR')

Create a callback that saves the model's weights

checkpoint_path = "training/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1)
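
If the Colab runtime gets recycled mid-training, the latest saved weights can be restored before resuming; a minimal sketch reusing the same checkpoint_path:

# Rebuild the architecture and load the most recent checkpointed weights.
model = create_model(INPUT_DIM, INPUT_DIM, 3)
model.load_weights(checkpoint_path)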

Finally, we can train the network.

model.fit(train_ds, epochs=EPOCHS, callbacks=[cp_callback])
Epoch 1/12
888/888 [==============================] - ETA: 0s - loss: 258.0937
Epoch 00001: saving model to training/cp.ckpt
888/888 [==============================] - 1735s 2s/step - loss: 258.0937
Epoch 2/12
888/888 [==============================] - ETA: 0s - loss: 105.9775
Epoch 00002: saving model to training/cp.ckpt
888/888 [==============================] - 1428s 2s/step - loss: 105.9775
Epoch 3/12
888/888 [==============================] - ETA: 0s - loss: 102.4195
Epoch 00003: saving model to training/cp.ckpt
888/888 [==============================] - 1364s 2s/step - loss: 102.4195
Epoch 4/12
888/888 [==============================] - ETA: 0s - loss: 98.4859
Epoch 00004: saving model to training/cp.ckpt
888/888 [==============================] - 1347s 2s/step - loss: 98.4859
Epoch 5/12
888/888 [==============================] - ETA: 0s - loss: 97.5308
Epoch 00005: saving model to training/cp.ckpt
888/888 [==============================] - 1352s 2s/step - loss: 97.5308
Epoch 6/12
888/888 [==============================] - ETA: 0s - loss: 96.0889
Epoch 00006: saving model to training/cp.ckpt
888/888 [==============================] - 1347s 2s/step - loss: 96.0889
Epoch 7/12
888/888 [==============================] - ETA: 0s - loss: 94.7550
Epoch 00007: saving model to training/cp.ckpt
888/888 [==============================] - 1355s 2s/step - loss: 94.7550
Epoch 8/12
888/888 [==============================] - ETA: 0s - loss: 93.3618
Epoch 00008: saving model to training/cp.ckpt
888/888 [==============================] - 1332s 1s/step - loss: 93.3618
Epoch 9/12
888/888 [==============================] - ETA: 0s - loss: 93.5235
Epoch 00009: saving model to training/cp.ckpt
888/888 [==============================] - 1346s 2s/step - loss: 93.5235
Epoch 10/12
888/888 [==============================] - ETA: 0s - loss: 92.4781
Epoch 00010: saving model to training/cp.ckpt
888/888 [==============================] - 1356s 2s/step - loss: 92.4781
Epoch 11/12
888/888 [==============================] - ETA: 0s - loss: 91.5945
Epoch 00011: saving model to training/cp.ckpt
888/888 [==============================] - 1348s 2s/step - loss: 91.5945
Epoch 12/12
888/888 [==============================] - ETA: 0s - loss: 91.0127
Epoch 00012: saving model to training/cp.ckpt
888/888 [==============================] - 1336s 2s/step - loss: 91.0127
<keras.callbacks.History at 0x7fb5c9edecd0>

Create the super_resolution folder in Google Drive if it doesn't exist, and copy the training artifacts there.

%%bash

mkdir -p /content/drive/MyDrive/super_resolution
cp -r training/* /content/drive/MyDrive/super_resolution

Save the model to Google Drive, and load it back to check the saved file.

path = '/content/drive/MyDrive/super_resolution/model.h5'
model.save(path)
new_model = tf.keras.models.load_model(path)

Evaluation

After training the model for enough time we can evaluate it. Let's pick a random image from the dataset (or you can use any other image) and transform it into a low-resolution image that we can pass to the SRCNN model.

path = np.random.choice(dataset_paths)
image = load_img(path)
image = img_to_array(image)
image = image.astype(np.uint8)
image = tight_crop_image(image, SCALE)
scaled = downsize_upsize_image(image, SCALE)

We need a placeholder array in which we will put the output patches to assemble the final image.

output = np.zeros(scaled.shape)
height, width = output.shape[:2]

Now we extract patches from the input image, pass them through the trained model to generate high-resolution patches, and place each patch at the right position in the placeholder. After processing every patch of the input image, we will have the final output image.

for y in range(0, height - INPUT_DIM + 1, LABEL_SIZE):
    for x in range(0, width - INPUT_DIM + 1, LABEL_SIZE):
        crop = crop_input(scaled, x, y)
        image_batch = np.expand_dims(crop, axis=0)
        prediction = model.predict(image_batch)
        new_shape = (LABEL_SIZE, LABEL_SIZE, 3)
        prediction = prediction.reshape(new_shape)
        output_y_slice = slice(y + PAD, y + PAD + LABEL_SIZE)
        output_x_slice = slice(x + PAD, x + PAD + LABEL_SIZE)
        output[output_y_slice, output_x_slice] = prediction
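
Calling predict once per patch is slow. As an optional optimization, the crops can be collected and predicted in a single batch; a sketch that produces the same output as the loop above:

# Collect all patch positions, predict them in one batch, then scatter
# the predictions back into the placeholder.
positions = [(x, y)
             for y in range(0, height - INPUT_DIM + 1, LABEL_SIZE)
             for x in range(0, width - INPUT_DIM + 1, LABEL_SIZE)]
batch = np.array([crop_input(scaled, x, y) for x, y in positions])
predictions = model.predict(batch)
for (x, y), pred in zip(positions, predictions):
    output[y + PAD:y + PAD + LABEL_SIZE, x + PAD:x + PAD + LABEL_SIZE] = pred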

Now we can display, side by side, the low-resolution image and the resulting higher-resolution output image.

figure, axis = plt.subplots(1, 2, figsize=(15, 8))
axis[0].imshow(np.array(scaled, np.int32))
axis[0].set_title('Low resolution image (Downsize + Upsize)')
axis[0].axis('off')

axis[1].imshow(np.array(output, np.int32))
axis[1].set_title('Super resolution result (SRCNN output)')
axis[1].axis('off')
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
(-0.5, 299.5, 197.5, -0.5)

A very impressive result considering how small the trained model is: as you can see, it was able to considerably improve the perceived resolution of the input image.
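
To go beyond eyeballing, a common way to quantify the improvement is PSNR (peak signal-to-noise ratio). A minimal sketch, assuming we compare against the original image over a rough central crop (the borders of output were never filled by the sliding window):

# PSNR of the baseline (downsize + upsize) and of the SRCNN output against
# the original image, over a central crop that avoids the unfilled borders.
def psnr(a, b):
    a = tf.convert_to_tensor(a, tf.float32)
    b = tf.convert_to_tensor(b, tf.float32)
    return tf.image.psnr(a, b, max_val=255).numpy()

region = slice(PAD, -INPUT_DIM)  # rough central crop
srcnn = np.clip(output, 0, 255)
print('baseline PSNR:', psnr(image[region, region], scaled[region, region]))
print('SRCNN PSNR:   ', psnr(image[region, region], srcnn[region, region]))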