In this article, we will implement a more complete image captioning system on the Flickr8k dataset. We will use an attention mechanism that gives the model the ability to search for the parts of the source image that are most relevant for predicting the next word. Using attention will also let us see, in an intuitive way, where the network looks while producing captions.

Let's start by importing the needed dependencies:

import pandas as pd
import numpy as np
import json
import os
import time
from string import punctuation
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import *
from tensorflow.keras.layers import *
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import get_file

Set a seed for reproducibility

SEED = 31
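
If we want runs to be repeatable, a minimal sketch of applying the seed (here to NumPy and TensorFlow) could look like this:

# optionally seed the random number generators so results are repeatable
np.random.seed(SEED)
tf.random.set_seed(SEED)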

Data

We will use the Flickr8k dataset (available on Kaggle here). So we will install the Kaggle CLI and make sure the credentials are properly configured.

%%capture
%%bash

pip install kaggle --upgrade

mkdir -p ~/.kaggle
echo '{"username":"KAGGLE_USER","key":"KAGGLE_KEY"}' > ~/.kaggle/kaggle.json
chmod 600 ~/.kaggle/kaggle.json

Note: you need to replace KAGGLE_USER with your actual Kaggle username and KAGGLE_KEY with your API key for the download to work.

Download the dataset, unzip the files, and create a proper folder structure

%%capture
%%bash

kaggle datasets download adityajn105/flickr8k
mkdir -p flickr8k
unzip flickr8k.zip -d flickr8k
mkdir -p flickr8k/features

Now we can set variables with the paths to the images and annotations

BASE_PATH = 'flickr8k'
IMAGES_PATH = f'{BASE_PATH}/Images'
FEATURES_PATH = f'{BASE_PATH}/features'
CAPTIONS_PATH = f'{BASE_PATH}/captions.txt'

Let's read the captions file into a Pandas DataFrame and have a look at it

captions_df = pd.read_csv(CAPTIONS_PATH)
# captions_df = captions_df.groupby('image').first().reset_index()
captions_df.head()
image caption
0 1000268201_693b08cb0e.jpg A child in a pink dress is climbing up a set o...
1 1000268201_693b08cb0e.jpg A girl going into a wooden building .
2 1000268201_693b08cb0e.jpg A little girl climbing into a wooden playhouse .
3 1000268201_693b08cb0e.jpg A little girl climbing the stairs to her playh...
4 1000268201_693b08cb0e.jpg A little girl in a pink dress going into a woo...

With the images and captions loaded, we can take a random sample and display some images with their respective captions

samples = captions_df.sample(6).reset_index()
figure, axis = plt.subplots(2, 3, figsize=(18, 8))
for index, sample in samples.iterrows():
    image = plt.imread(f'{IMAGES_PATH}/{sample["image"]}')
    title = sample['caption'][:50] + '\n' + sample['caption'][50:]
    row, col = int(index / 3), index % 3
    axis[row, col].imshow(image)
    axis[row, col].set_title(title)
    axis[row, col].axis('off')

We need to clean the caption text (e.g. remove punctuation) to simplify training, and also add special tokens: an <sos> (start of sequence) token at the beginning of the text and an <eos> (end of sequence) token at the end.

def clean_caption(caption, start_token='<sos>', end_token='<eos>'):
    def remove_punctuation(word):
        translation = str.maketrans('', '', punctuation)
        return word.translate(translation)
    def is_valid_word(word):
        return len(word) > 1 and word.isalpha()
    caption = caption.lower().split(' ')
    caption = map(remove_punctuation, caption)
    caption = filter(is_valid_word, caption)
    cleaned_caption = f'{start_token} {" ".join(caption)} {end_token}'
    return cleaned_caption
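
As a quick check, applying the function to the first caption we saw above lowercases the text, drops punctuation and single-letter words, and wraps the result in the special tokens:

clean_caption('A child in a pink dress is climbing up a set of stairs.')
# '<sos> child in pink dress is climbing up set of stairs <eos>'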

Now we apply the cleaning function and update all captions text

captions_df['caption'] = captions_df.caption.apply(lambda x: clean_caption(x))

Here we define a helper function that we will use later to get the maximum sequence length

def get_max_length(tensor):
    return max(len(t) for t in tensor)

Model

The model that we will be using for captioning, as illustrated by the diagram below, is composed of three smaller models:

  • Image feature extraction, which is simply a CNN previously trained on image classification but without the classification head. The weights of this model are non-trainable.
  • CNN Encoder, which takes the image features and produces an embedding that is learned as the model is trained.
  • RNN Decoder with Attention, which uses the image embedding as well as the hidden state propagated as the decoder processes the caption tokens.

[Diagram: image feature extractor → CNN encoder → RNN decoder with attention]

The first part of the model is the CNN Encoder: it trains a dense layer that processes the image feature vectors and generates the embedding that will be used by the decoder's attention layer

class CNNEncoder(Model):
    def __init__(self, embedding_dim):
        super(CNNEncoder, self).__init__()
        self.fc = Dense(embedding_dim)

    def call(self, x):
        x = self.fc(x)
        x = tf.nn.relu(x)
        return x

The next component of our model is based on Bahdanau attention, which was a breakthrough when it was first introduced as an improvement to Encoder-Decoder models. It tries to address the problem encoders face when they have to squash the information extracted from very long sequences into a single fixed-size vector.

To learn more about this algorithm, you can read the original paper on arxiv.org or a detailed explanation on machinelearningmastery.com

[Diagram: Bahdanau attention mechanism]

class BahdanauAttention(Model):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = Dense(units)
        self.W2 = Dense(units)
        self.V = Dense(1)

    def call(self, features, hidden):
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_w = tf.nn.softmax(self.V(score), axis=1)
        ctx_vector = attention_w * features
        ctx_vector = tf.reduce_sum(ctx_vector, axis=1)
        return ctx_vector, attention_w
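
A minimal shape check (with dummy tensors, assuming 64 feature locations, an embedding size of 256, and 512 attention units) shows what the layer returns: one context vector per example and one attention weight per location:

attention = BahdanauAttention(units=512)
dummy_features = tf.random.uniform((1, 64, 256))  # (batch, locations, embedding_dim)
dummy_hidden = tf.zeros((1, 512))                 # (batch, units)
ctx_vector, attention_w = attention(dummy_features, dummy_hidden)
print(ctx_vector.shape, attention_w.shape)        # (1, 256) (1, 64, 1)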

Next, we define the decoder, which is an RNN that uses a GRU and attention to learn how to produce captions from the input text sequences and the visual feature vectors extracted from the input images:

class RNNDecoder(Model):
    def __init__(self, embedding_size, units, vocab_size):
        super(RNNDecoder, self).__init__()
        self.units = units
        self.embedding = Embedding(vocab_size, embedding_size)
        self.gru = GRU(self.units, return_sequences=True, return_state=True, recurrent_initializer='glorot_uniform')
        self.fc1 = Dense(self.units)
        self.fc2 = Dense(vocab_size)
        self.attention = BahdanauAttention(self.units)

    def call(self, x, features, hidden):
        # calculate the attention
        context_vector, attention_weights = self.attention(features, hidden)
        # calculate the embeddings of the input token
        x = self.embedding(x)
        expanded_context = tf.expand_dims(context_vector, 1)
        x = Concatenate(axis=-1)([expanded_context, x])
        # pass context vector and input embedding through GRU
        output, state = self.gru(x)
        x = self.fc1(output)
        x = tf.reshape(x, (-1, x.shape[2]))
        x = self.fc2(x)
        return x, state, attention_weights

    def reset_state(self, batch_size):
        return tf.zeros((batch_size, self.units))

Finally, we put the different pieces of the model together in the Trainer class. It creates the encoder, the decoder, the optimizer, and the loss function needed to train the whole system, and it keeps a reference to the tokenizer. This class also defines a function that performs a single training step (using teacher forcing: at each step the ground-truth token, not the model's prediction, is fed as the next decoder input) as well as the training loop over many epochs. During training, the loss is recorded so it can later be visualized with TensorBoard.

class Trainer(object):
    def __init__(self, embedding_size, units, vocab_size, tokenizer):
        self.tokenizer = tokenizer
        self.encoder = CNNEncoder(embedding_size)
        self.decoder = RNNDecoder(embedding_size, units, vocab_size)
        self.optimizer = Adam()
        self.loss = SparseCategoricalCrossentropy(from_logits=True, reduction='none')

    def loss_function(self, real, predicted):
        """Calculate the loss based on the ground truth caption and the predicted one"""
        mask = tf.math.logical_not(tf.math.equal(real,0))
        _loss = self.loss(real, predicted)
        mask = tf.cast(mask, dtype=_loss.dtype)
        _loss *= mask
        return tf.reduce_mean(_loss)

    @tf.function
    def train_step(self, image_tensor, target):
        """Perform one training step"""
        loss = 0
        hidden = self.decoder.reset_state(target.shape[0])
        start_token_idx = self.tokenizer.word_index['<sos>']
        init_batch = [start_token_idx] * target.shape[0]
        decoder_input = tf.expand_dims(init_batch, 1)

        with tf.GradientTape() as tape:
            features = self.encoder(image_tensor)
            for i in range(1, target.shape[1]):
                preds, hidden, _ = self.decoder(decoder_input, features, hidden)
                loss += self.loss_function(target[:, i], preds)
                decoder_input = tf.expand_dims(target[:, i],1)

        total_loss = loss / int(target.shape[1])
        trainable_vars = (self.encoder.trainable_variables + self.decoder.trainable_variables)
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients,trainable_vars))
        return loss, total_loss

    def train(self, dataset, epochs, num_steps):
        """Train and log metrics"""
        writer = tf.summary.create_file_writer('log_dir')
        for epoch in tqdm(range(epochs)):
            start = time.time()
            total_loss = 0
            for batch, (image_tensor, target) in enumerate(dataset):
                batch_loss, step_loss = self.train_step(image_tensor, target)
                total_loss += step_loss

            epoch_time = time.time() - start
            # write the loss value
            with writer.as_default():
                tf.summary.scalar('training loss', total_loss / num_steps, step=epoch+1)
                tf.summary.scalar('Epoch time (s)', epoch_time, step=epoch+1)

Training

Before we can start the training, we need to perform some data pre-processing. Let's first get an array of image paths and the corresponding captions

train_images = captions_df.image.apply(lambda image: f'{IMAGES_PATH}/{image}').values
train_captions = captions_df.caption.values

We download an instance of Inception V3 pre-trained on ImageNet and use it as the image feature extractor after removing the model's classification head

feature_extractor = InceptionV3(include_top=False, weights='imagenet')
feature_extractor = Model(feature_extractor.input, feature_extractor.layers[-1].output)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
87916544/87910968 [==============================] - 1s 0us/step
87924736/87910968 [==============================] - 1s 0us/step

We need to define a function that will load images based on their path and resize them as expected by the feature extractor model

def load_image_fn(image_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, (299, 299))
    image = preprocess_input(image)
    return image, image_path

We can use the previous function to create a tf.data.Dataset of the images

BATCH_SIZE = 8
image_dataset = tf.data.Dataset.from_tensor_slices(train_images)\
    .map(load_image_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
    .batch(BATCH_SIZE)

We iterate over all the images in the dataset and pass them through the feature extractor. As the extracted feature vectors cannot all fit in memory, we store them under the FEATURES_PATH folder

for image, path in tqdm(image_dataset):
    batch_features = feature_extractor.predict(image)
    batch_features = tf.reshape(batch_features, (batch_features.shape[0], -1, batch_features.shape[3]))
    for batch_feature, p in zip(batch_features, path):
        feature_path = Path(p.numpy().decode('UTF-8'))
        image_name = feature_path.stem
        np.save(f'{FEATURES_PATH}/{image_name}', batch_feature.numpy())
100%|██████████| 5057/5057 [11:51<00:00,  7.11it/s]
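
Each saved file should contain a feature tensor of shape (64, 2048): with a 299x299 input and no classification head, InceptionV3 outputs an 8x8x2048 feature map, which the reshape above flattens into 64 locations of 2048 channels. A quick sanity check on one of the images we saw earlier:

sample_feature = np.load(f'{FEATURES_PATH}/1000268201_693b08cb0e.npy')
print(sample_feature.shape)  # expected: (64, 2048)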

We need to create tokens from the captions, so we train a tokenizer on the top 5,000 words in our captions, which become our vocabulary. Note that we could take more words, but that would lead to a bigger memory footprint as we one-hot encode each token.

After that, we apply this tokenizer to each caption to generate a numeric sequence. We limit the sequence size and pad any shorter caption by appending the special <pad> token to the end.

top_k = 5000
filters = '!"#$%&()*+.,-/:;=?@[\]^_`{|}~ '
tokenizer = Tokenizer(num_words=top_k, oov_token='<unk>', filters=filters)
tokenizer.fit_on_texts(train_captions)
tokenizer.word_index['<pad>'] = 0
tokenizer.index_word[0] = '<pad>'
train_seqs = tokenizer.texts_to_sequences(train_captions)
captions_seqs = pad_sequences(train_seqs, padding='post')
max_length = get_max_length(train_seqs)
captions_seqs.shape
(40455, 34)
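
To see what the tokenizer produces, we can encode one cleaned caption and map the ids back to words (an illustrative round trip; the exact ids depend on the word frequencies in the corpus):

sample_seq = tokenizer.texts_to_sequences([train_captions[0]])[0]
print(sample_seq)                                     # ids for '<sos>', 'child', 'in', ...
print([tokenizer.index_word[i] for i in sample_seq])  # back to the original tokens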

Let's split the dataset into 80% for the actual training and 20% for later evaluating the trained model

(images_train, images_val, caption_train, caption_val) = train_test_split(train_images, captions_seqs, test_size=0.2, random_state=SEED)

We need a function that will load an image feature vector and the associated caption

def load_example_fn(image_name, caption):
    image_name = image_name.decode('utf-8')
    image_name = Path(image_name).stem
    image_tensor = np.load(f'{FEATURES_PATH}/{image_name}.npy')
    return image_tensor, caption

To create the training dataset, we batch the images with their captions into batches of BATCH_SIZE examples, using load_example_fn to load the feature vectors. For performance, we shuffle the dataset and prefetch upcoming batches so the input pipeline does not slow down training.

BATCH_SIZE = 64
BUFFER_SIZE = 1000
dataset = tf.data.Dataset.from_tensor_slices((images_train, caption_train))\
    .map(lambda i1, i2: tf.numpy_function(load_example_fn, [i1, i2], [tf.float32, tf.int32]), num_parallel_calls=tf.data.experimental.AUTOTUNE)\
    .shuffle(BUFFER_SIZE)\
    .batch(BATCH_SIZE)\
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
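
A quick look at one batch (a sanity check; the shapes follow from the 64x2048 feature vectors and the maximum caption length of 34 reported earlier) confirms what the pipeline yields:

for image_tensor, target in dataset.take(1):
    print(image_tensor.shape, target.shape)  # (64, 64, 2048) (64, 34)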

Now we are ready for the actual training. Let's create the helper Trainer, setting the token embedding size to 256 and the number of units for the decoder as well as the attention model to 512. We also pass the tokenizer and a vocabulary size of 5,000 + 1 (for the padding token).

trainer = Trainer(embedding_size=256, units=512, vocab_size=top_k + 1, tokenizer=tokenizer)

Now we train for a few epochs (note this may take quite some time to finish)

EPOCHS = 30
num_steps = len(images_train) // BATCH_SIZE
trainer.train(dataset, EPOCHS, num_steps)
100%|██████████| 30/30 [2:00:15<00:00, 240.52s/it]

Now we can explore the training loss with TensorBoard

%load_ext tensorboard
%tensorboard --logdir log_dir

[TensorBoard: training loss per epoch]

Evaluation

For the evaluation, we will define a function that uses the image feature extractor along with the trained encoder and decoder. It transforms the input image and generates a caption starting from the special <sos> token. The caption generation stops when an <eos> token is produced by the decoder or the maximum caption length is reached.

def evaluate(encoder, decoder, tokenizer, image_path, max_length, attention_shape):
    attention_plot = np.zeros((max_length, attention_shape))
    # initialize hidden state
    hidden = decoder.reset_state(batch_size=1)
    image, _ = load_image_fn(image_path)
    # extract the image feature vector
    features = feature_extractor(tf.expand_dims(image, 0))
    features = tf.reshape(features, (features.shape[0], -1, features.shape[3]))
    # encode the features
    encoder_out = encoder(features)
    start_token_idx = tokenizer.word_index['<sos>']
    decoder_input = tf.expand_dims([start_token_idx], 0)
    result = []

    # generate the caption
    for i in range(max_length):
        (preds, hidden, attention_w) = decoder(decoder_input, encoder_out, hidden)
        attention_plot[i] = tf.reshape(attention_w, (-1,)).numpy()
        pred_id = tf.random.categorical(preds, 1)[0][0].numpy()
        result.append(tokenizer.index_word[pred_id])

        if tokenizer.index_word[pred_id] == '<eos>':
            return result, attention_plot
        decoder_input = tf.expand_dims([pred_id], 0)

    attention_plot = attention_plot[:len(result), :]
    return result, attention_plot

Note: see how for each generated token we compute an attention map and add it to the final attention_plot.

attention_features_shape = 64  # one attention weight per location of the 8x8 InceptionV3 feature map

Pick a random image from the validation split that we set aside earlier

random_id = np.random.randint(0, len(images_val))
image_path = images_val[random_id]

Get and clean the actual image caption

actual_caption = ' '.join([tokenizer.index_word[i] for i in caption_val[random_id] if i != 0])
actual_caption = (actual_caption.replace('<sos>', '').replace('<eos>', ''))

Generate a caption for the image and strip the special tokens from it

result, attention_plot = evaluate(trainer.encoder, trainer.decoder, tokenizer, image_path, max_length, attention_features_shape)
predicted_caption = (' '.join(result).replace('<sos>', '').replace('<eos>', ''))

Let's show, side by side, the image with its original caption and, next to it, two copies of the image overlaid with some of the attention maps (there is one attention map per output token in the caption) together with the predicted caption

figure, axis = plt.subplots(1, 3, figsize=(15, 8))

axis[0].imshow(plt.imread(image_path))
axis[0].set_title(actual_caption)
axis[0].axis('off')

imageshow = axis[1].imshow(plt.imread(image_path))
axis[1].imshow(np.resize(attention_plot[0], (8, 8)), cmap='gray', alpha=0.6, extent=imageshow.get_extent())
axis[1].set_title(predicted_caption)
axis[1].axis('off')

imageshow = axis[2].imshow(plt.imread(image_path))
axis[2].imshow(np.resize(attention_plot[len(attention_plot)-1], (8, 8)), cmap='gray', alpha=0.6, extent=imageshow.get_extent())
axis[2].set_title(predicted_caption)
axis[2].axis('off')

plt.show()

Notice how well the model performed and generated a caption that's close to the actual ground truth

We can display all attention maps and inspect what the model was looking at when it generated each token. For this, let's define a function that receives an image, the caption as a sequence of tokens, and the attention_plot returned by the evaluate function above.

def plot_attention(image_path, result, attention_plot, output_path):
    image_array = plt.imread(image_path)
    fig = plt.figure(figsize=(10, 10))
    # for each token create a sub-plot and display the corresponding attention
    for l in range(len(result)):
        temp_att = np.resize(attention_plot[l], (8, 8))
        ax = fig.add_subplot(len(result) // 2, len(result) // 2, l + 1)
        ax.set_title(result[l])
        image = ax.imshow(image_array)
        ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=image.get_extent())
    # tighten the layout, then save the attention plot before displaying it
    plt.tight_layout()
    plt.savefig(output_path, format='png')
    plt.show()

plot_attention(image_path, result, attention_plot, './attention_plot.png')

Note: the bright square areas in the plots represent the regions of the picture the model paid more attention to when generating the corresponding token. For instance, to produce the word women, the network looked at the heads of the women in the photo. We can also see that when the network generated the word beside, it looked at the chair.