In this article, we will implement a U-Net model (as depicted in the diagram below) and train it on a popular image segmentation dataset. Training a U-Net from scratch is hard, so instead we will leverage transfer learning to get good results after only a few epochs of training.

For reference, you can read the original U-Net paper on arxiv.org.

U-Net architecture

Before starting, import the needed dependencies

import cv2
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import tensorflow as tf
import tensorflow_datasets as tfdata
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import *
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import RMSprop

Data

We will use a simple segmentation dataset known as the Oxford-IIIT Pet Dataset. As a quick introduction, the dataset contains images of dogs and cats along with a segmentation mask for each image. Each pixel of the segmentation mask belongs to one of the following classes:

  • 1: The pixel belongs to a pet (i.e. cat or dog).
  • 2: The pixel belongs to the contour of a pet.
  • 3: The pixel belongs to the surroundings.

For a detailed introduction to the dataset, check its website here.

The dataset, along with its metadata, is available in tensorflow-datasets here, with images already preprocessed and ready to use with the TensorFlow Data API. Let's download this dataset and cache it locally

dataset, info = tfdata.load('oxford_iiit_pet', with_info=True)
Downloading and preparing dataset oxford_iiit_pet/3.2.0 (download: 773.52 MiB, generated: 774.69 MiB, total: 1.51 GiB) to /root/tensorflow_datasets/oxford_iiit_pet/3.2.0...



Shuffling and writing examples to /root/tensorflow_datasets/oxford_iiit_pet/3.2.0.incomplete3ZMH6C/oxford_iiit_pet-train.tfrecord
Shuffling and writing examples to /root/tensorflow_datasets/oxford_iiit_pet/3.2.0.incomplete3ZMH6C/oxford_iiit_pet-test.tfrecord
Dataset oxford_iiit_pet downloaded and prepared to /root/tensorflow_datasets/oxford_iiit_pet/3.2.0. Subsequent calls will reuse this data.
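
The info object returned by tfdata.load holds the dataset metadata. As an optional sketch, we can inspect the split sizes (used later to derive the training parameters) and the segmentation mask feature:

# Optional: inspect the metadata returned by tfdata.load
print(info.splits['train'].num_examples)   # size of the training split
print(info.splits['test'].num_examples)    # size of the test split
print(info.features['segmentation_mask'])  # description of the mask feature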

Let's examine a few images and their ground truth labels from the dataset

figure, axis = plt.subplots(2, 2, figsize=(10, 8))
for row, example in enumerate(dataset['train'].take(2)):
    file_name = example['file_name'].numpy().decode('utf8')
    axis[row, 0].set_title(file_name)
    axis[row, 0].imshow(example['image'].numpy())
    axis[row, 0].axis('off')
    mask = np.squeeze(example['segmentation_mask'].numpy(), axis=2)
    axis[row, 1].set_title('mask')
    axis[row, 1].imshow(mask, cmap='gray')
    axis[row, 1].axis('off')
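
As a quick sanity check, we can also confirm that the segmentation masks contain exactly the three class values listed above (a small optional sketch):

# Optional: verify the class labels present in a segmentation mask
sample = next(iter(dataset['train'].take(1)))
print(np.unique(sample['segmentation_mask'].numpy()))  # expected: [1 2 3]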
    

Model

As you can see from the U-Net architecture, the model consists of an Encoder depicted by a contracting path (left side) and a Decoder depicted by an expansive path (right side). Skip connections pass feature maps from the left side to the right side in order to improve the performance of the decoding.

Because the Encoder is very similar to a convolutional classification network, instead of creating this part and training it from scratch we will use a pretrained model and simply select the output layers with the appropriate shapes to make skip connections to the Decoder. This way, we can train the model faster as we will have only the up-sampling path to train.

First, let's define a function that creates the Encoder. We will use a MobileNetV2 pretrained on ImageNet as the backbone for the Encoder. The output of the Encoder consists of a few cherry-picked layers that we will use later to create the skip connections from this Encoder to the Decoder part of the model.

def create_down_path(input_size=(256, 256, 3)):
    """Create down path (Encoder) of U-Net model"""
    backbone = MobileNetV2(input_shape=input_size, include_top=False, weights='imagenet')
    # intermediate layers whose outputs will feed the skip connections
    target_layers = [
        'block_1_expand_relu',
        'block_3_expand_relu',
        'block_6_expand_relu',
        'block_13_expand_relu',
        'block_16_project',
    ]
    layers = [backbone.get_layer(l).output for l in target_layers]
    encoder = Model(inputs=backbone.input, outputs=layers)
    encoder.trainable = False
    return encoder

Note: with encoder.trainable = False we are freezing the weights of the Encoder; we don't want to train it, as all the selected layers were already trained on ImageNet.
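
As a small optional check, we can confirm the freeze took effect by counting the trainable variables of an encoder built with the function above:

encoder = create_down_path()
print(len(encoder.trainable_variables))      # 0, since encoder.trainable = False
print(len(encoder.non_trainable_variables))  # all the MobileNetV2 weights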

Second, we define a function that creates the up-sampling path, which consists of a sequence of blocks built around the Conv2DTranspose layer.

def create_up_path(size=4, dropout=False):
    decoder = []
    init = tf.random_normal_initializer(0.0, 0.02)
    for filters in (512, 256, 128, 64):
        block = Sequential()
        block.add(Conv2DTranspose(filters=filters, kernel_size=size, strides=2, padding='same', kernel_initializer=init, use_bias=False))
        block.add(BatchNormalization())
        if dropout:
            block.add(Dropout(rate=0.5))
        block.add(ReLU())
        decoder.append(block)
    return decoder
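
Because each block uses strides=2 with padding='same', it doubles the spatial resolution of its input. A minimal sketch to verify this on a dummy feature map (the 320 input channels here are just an arbitrary example):

block = create_up_path()[0]       # first up-sampling block (512 filters)
dummy = tf.zeros((1, 8, 8, 320))  # dummy feature map
print(block(dummy).shape)         # (1, 16, 16, 512)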

Finally, we use the previous functions to create the Encoder and Decoder components, and wire them with the skip connections.

def create_unet(input_size=(256, 256, 3)):
    down_stack = create_down_path(input_size)
    up_stack = create_up_path()
    # create skip connections
    inputs = Input(shape=input_size)
    x = inputs
    skip_layers = down_stack(x)
    x = skip_layers[-1]
    skip_layers = reversed(skip_layers[:-1])
    for up, skip_connection in zip(up_stack, skip_layers):
        x = up(x)
        x = Concatenate()([x, skip_connection])
    # output layer
    init = tf.random_normal_initializer(0.0, 0.02)
    output = Conv2DTranspose(filters=3, kernel_size=3, strides=2, padding='same', kernel_initializer=init)(x)
    return Model(inputs, outputs=output)

Note: The output of the U-Net model has the same spatial dimensions as the input image, and as many channels as the number of segmentation classes (because each pixel can be categorized into one of 3 classes).

Now we can create the model and inspect it. Notice how the output layers from the Encoder are used by the different blocks of the Decoder.

unet_model = create_unet()
WARNING:tensorflow:`input_shape` is undefined or non-square, or `rows` is not in [96, 128, 160, 192, 224]. Weights for input shape (224, 224) will be loaded as the default.
WARNING:tensorflow:`input_shape` is undefined or non-square, or `rows` is not in [96, 128, 160, 192, 224]. Weights for input shape (224, 224) will be loaded as the default.
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step
tf.keras.utils.plot_model(unet_model, show_shapes=True)
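
We can also verify the note above about the output shape directly (a quick sketch):

print(unet_model.output_shape)  # expected: (None, 256, 256, 3)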

Let's compile the model to get it ready for training. We will use RMSprop as the optimizer and SparseCategoricalCrossentropy as the loss function.

unet_model.compile(optimizer=RMSprop(), loss=SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

Training

We need to normalize the images so they match the expected input of the pretrained MobileNetV2 model

def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1
    return input_image, input_mask
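
Note that subtracting 1 shifts the mask labels from {1, 2, 3} to {0, 1, 2}, which is what SparseCategoricalCrossentropy expects with 3 output channels. A small optional check:

example = next(iter(dataset['train']))
_, shifted_mask = normalize(example['image'], example['segmentation_mask'])
print(np.unique(shifted_mask.numpy()))  # expected: [0 1 2]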

Let's define a function that will be used by the TF Dataset to load images and their labels, normalize the images and perform some basic augmentation (flip the image and mask) only for training.

@tf.function
def load_image_fn(example, train=True):
    input_image = tf.image.resize(example['image'], (256, 256))
    input_mask = tf.image.resize(example['segmentation_mask'], (256, 256))
    if train:
        # use a TF random op so the flip decision is re-drawn on every call,
        # instead of being frozen at tracing time as np.random.uniform() would be
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            input_mask = tf.image.flip_left_right(input_mask)
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

Let's define some training parameters like batch size, steps per epoch, etc.

TRAIN_SIZE = info.splits['train'].num_examples
VALIDATION_SIZE = info.splits['test'].num_examples
BATCH_SIZE = 64
STEPS_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE
VALIDATION_SUBSPLITS = 5
VALIDATION_STEPS = VALIDATION_SIZE // BATCH_SIZE
VALIDATION_STEPS //= VALIDATION_SUBSPLITS
BUFFER_SIZE = 1000
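
With the split sizes loaded from the dataset metadata, these settings determine how many batches are consumed per epoch. Printing them is a quick way to relate the configuration to the training logs below (a small optional sketch):

print(STEPS_PER_EPOCH)   # 57 with the default splits, matching the logs below
print(VALIDATION_STEPS)  # validation batches used per epoch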

Now we create the training and test datasets and use the previously defined functions to load the images and labels. We also batch and prefetch the training dataset

train_dataset = dataset['train']\
.map(load_image_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
.cache()\
.shuffle(BUFFER_SIZE)\
.batch(BATCH_SIZE)\
.repeat()\
.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

test_dataset = dataset['test']\
.map(lambda d: load_image_fn(d, train=False), num_parallel_calls=tf.data.experimental.AUTOTUNE)\
.batch(BATCH_SIZE)

Now we can start training our U-Net model

hist = unet_model.fit(train_dataset, epochs=10, steps_per_epoch=STEPS_PER_EPOCH, validation_steps=VALIDATION_STEPS, validation_data=test_dataset)
Epoch 1/10
57/57 [==============================] - 53s 582ms/step - loss: 0.4203 - accuracy: 0.8352 - val_loss: 2.0189 - val_accuracy: 0.6889
Epoch 2/10
57/57 [==============================] - 34s 571ms/step - loss: 0.2736 - accuracy: 0.8894 - val_loss: 0.4313 - val_accuracy: 0.8765
Epoch 3/10
57/57 [==============================] - 33s 576ms/step - loss: 0.2554 - accuracy: 0.8945 - val_loss: 0.2991 - val_accuracy: 0.8937
Epoch 4/10
57/57 [==============================] - 33s 585ms/step - loss: 0.2449 - accuracy: 0.8978 - val_loss: 0.2617 - val_accuracy: 0.9006
Epoch 5/10
57/57 [==============================] - 34s 592ms/step - loss: 0.2354 - accuracy: 0.9008 - val_loss: 0.3233 - val_accuracy: 0.8708
Epoch 6/10
57/57 [==============================] - 34s 599ms/step - loss: 0.2268 - accuracy: 0.9042 - val_loss: 0.2499 - val_accuracy: 0.9042
Epoch 7/10
57/57 [==============================] - 34s 606ms/step - loss: 0.2244 - accuracy: 0.9045 - val_loss: 0.2453 - val_accuracy: 0.9022
Epoch 8/10
57/57 [==============================] - 35s 607ms/step - loss: 0.2125 - accuracy: 0.9092 - val_loss: 0.2443 - val_accuracy: 0.9057
Epoch 9/10
57/57 [==============================] - 34s 606ms/step - loss: 0.2124 - accuracy: 0.9089 - val_loss: 0.2930 - val_accuracy: 0.8884
Epoch 10/10
57/57 [==============================] - 34s 606ms/step - loss: 0.2080 - accuracy: 0.9102 - val_loss: 0.2561 - val_accuracy: 0.9016

MobileNets come with pretrained weights for a set of predefined input sizes (96, 128, 160, 192 and 224); in our case we resize the images to 256, which is not one of them, hence the warnings in the output and the fallback to the default 224 weights. There is nothing to worry about: the model is trained properly and validation is performed.

Now we can examine the loss and accuracy progress per epoch

figure, axis = plt.subplots(1, 2, figsize=(15, 5))
# accuracy
axis[0].plot(hist.history['accuracy'])
axis[0].plot(hist.history['val_accuracy'])
axis[0].set_title('model accuracy')
axis[0].set_ylabel('accuracy')
axis[0].set_xlabel('epoch')
axis[0].legend(['train', 'valid'], loc='upper left')

# loss
axis[1].plot(hist.history['loss'])
axis[1].plot(hist.history['val_loss'])
axis[1].set_title('model loss')
axis[1].set_ylabel('loss')
axis[1].set_xlabel('epoch')
axis[1].legend(['train', 'valid'], loc='upper left')


plt.show()

Evaluation

We can evaluate our U-Net model on the entire test dataset simply as follows

result = unet_model.evaluate(test_dataset)
58/58 [==============================] - 16s 277ms/step - loss: 0.2512 - accuracy: 0.9024
print(f'Test Accuracy: {result[1] * 100:.2f}%')
Test Accuracy: 90.24%

For a visual examination and inspection of the predicted segmentation, we need some helper functions to transform the output of the trained U-Net model into an actual valid mask, hence the following functions

def process_mask(mask):
    mask = (mask.numpy() * 127.5).astype('uint8')
    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
    return mask

def create_mask(prediction_mask):
    prediction_mask = tf.argmax(prediction_mask, axis=-1)
    prediction_mask = prediction_mask[..., tf.newaxis]
    return prediction_mask[0]
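
To see what create_mask does, here is a minimal sketch on a dummy prediction: the 3-channel per-pixel scores are collapsed with argmax into a single-channel mask of class indices.

dummy_prediction = tf.random.uniform((1, 256, 256, 3))  # batch of one fake prediction
print(create_mask(dummy_prediction).shape)              # (256, 256, 1)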

From the test dataset we take sample images and pass them through the U-Net model to generate predictions that we will later compare against the ground truth masks.

def generate_predictions(model, dataset, sample_size=2):
    output = []
    for image, mask in tqdm(dataset.take(sample_size)):
        output_mask = model.predict(image)
        predicted_mask = process_mask(create_mask(output_mask))
        image = (image[0].numpy() * 255.0).astype('uint8')
        ground_truth_mask = process_mask(mask[0])
        output.append((image, ground_truth_mask, predicted_mask))
    return output

The following helper function will be used to plot a set of predictions

def plot_predictions(predictions):
    figure, axis = plt.subplots(len(predictions), 3, figsize=(10, 20))
    for row, (image, ground_truth_mask, predicted_mask) in enumerate(predictions):
        # plot the image
        axis[row, 0].imshow(image)
        axis[row, 0].axis('off')
        axis[row, 0].set_title('image')
        # plot the ground truth mask
        axis[row, 1].imshow(ground_truth_mask)
        axis[row, 1].axis('off')
        axis[row, 1].set_title('ground truth mask')
        # plot the predicted mask
        axis[row, 2].imshow(predicted_mask)
        axis[row, 2].axis('off')
        axis[row, 2].set_title('predicted mask')
    plt.tight_layout()
    plt.show()

Now we can generate predicted masks for a sample of the images in the test set and plot them for inspection

predictions = generate_predictions(unet_model, test_dataset, 5)
100%|██████████| 5/5 [00:03<00:00,  1.33it/s]
plot_predictions(predictions)

As you can see from the predicted masks, the model was able to do a good job even though it was trained for only a few epochs. It did very well on the easy examples but could not do as good a job on the harder ones, for instance on darker images or when the image contains more than just a pet.

I put a lot of effort into writing every article; I hope you found this one useful as well as easy to digest.

Feel free to leave a comment or reach out on twitter @bachiirc