In this article, we will implement a more complete image captioning system on the Flickr8k dataset. We will use an attention mechanism to give the model the ability to search for the parts of the source image that are most relevant for predicting the next word. Attention also lets us see, in an intuitive way, where the network looks while producing each caption.
Let's start by importing the needed dependencies:
import pandas as pd
import numpy as np
import json
import os
import time
from string import punctuation
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import *
from tensorflow.keras.layers import *
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import get_file
Set a seed for reproducibility
SEED = 31
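The constant above only defines the seed value; to make runs actually repeatable we also need to hand it to the libraries that draw random numbers (this extra step is an assumption on my part, not shown in the original listing):
# Seed NumPy and TensorFlow so that shuffling, weight initialization and sampling are repeatable
np.random.seed(SEED)
tf.random.set_seed(SEED)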
We will use the Flickr8k dataset (available on Kaggle here), so we will install the Kaggle CLI and make sure the credentials are properly configured.
%%capture
%%bash
pip install kaggle --upgrade
mkdir -p ~/.kaggle
echo '{"username":"dzlabs","key":"0abda977ffcfb11ea3726d7c0a6a802e"}' > ~/.kaggle/kaggle.json
chmod 600 ~/.kaggle/kaggle.json
Note: replace KAGGLE_USER with your actual Kaggle username and KAGGLE_KEY with your API key for the download to work.
Download the dataset, unzip the files into a flickr8k folder, and create a proper folder structure
%%capture
%%bash
kaggle datasets download adityajn105/flickr8k
mkdir -p flickr8k
unzip flickr8k.zip -d flickr8k
mkdir -p flickr8k/features
Now we can set variables with the paths to the images and annotations
BASE_PATH = 'flickr8k'
IMAGES_PATH = f'{BASE_PATH}/Images'
FEATURES_PATH = f'{BASE_PATH}/features'
CAPTIONS_PATH = f'{BASE_PATH}/captions.txt'
Let's read the captions file into a Pandas dataframe and have a look at it
captions_df = pd.read_csv(CAPTIONS_PATH)
# captions_df = captions_df.groupby('image').first().reset_index()
captions_df.head()
With the images and captions loaded, we can take a random sample and display some images with their respective captions
samples = captions_df.sample(6).reset_index()
figure, axis = plt.subplots(2, 3, figsize=(18, 8))
for index, sample in samples.iterrows():
image = plt.imread(f'{IMAGES_PATH}/{sample["image"]}')
title = sample['caption'][:50] + '\n' + sample['caption'][50:]
row, col = int(index / 3), index % 3
axis[row, col].imshow(image)
axis[row, col].set_title(title)
axis[row, col].axis('off')
We need to clean the text in the captions (e.g. removing punctuation) to simplify training, and also add special tokens: an <sos> (start of sequence) token at the beginning of the text and an <eos> (end of sequence) token at the end.
def clean_caption(caption, start_token='<sos>', end_token='<eos>'):
def remove_punctuation(word):
translation = str.maketrans('', '', punctuation)
return word.translate(translation)
def is_valid_word(word):
return len(word) > 1 and word.isalpha()
caption = caption.lower().split(' ')
caption = map(remove_punctuation, caption)
caption = filter(is_valid_word, caption)
cleaned_caption = f'{start_token} {" ".join(caption)} {end_token}'
return cleaned_caption
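A quick check on a made-up caption (the example string is mine) shows what the cleaning does:
# Lowercase, strip punctuation, drop single-character words, add the special tokens
print(clean_caption('A dog runs, quickly!'))
# expected output: '<sos> dog runs quickly <eos>'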
Now we apply the cleaning function and update the text of all captions
captions_df['caption'] = captions_df.caption.apply(lambda x: clean_caption(x))
Here we define a helper function that we will use later to get the maximum sequence length
def get_max_length(tensor):
return max(len(t) for t in tensor)
The model that we will use for captioning, as illustrated by the diagram below, is composed of three smaller models:
- Image feature extraction: a CNN previously trained on image classification but without the classification head. The weights of this model are not trainable.
- CNN Encoder: takes the image features and produces an embedding that is learned as the model is trained.
- RNN Decoder with Attention: uses the image embedding as well as the hidden state propagated as the decoder processes the caption tokens (a rough sketch of the tensor shapes flowing through these pieces is shown below).
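The shapes flow roughly as follows (batch dimension omitted; the 256 and 512 sizes are the ones used later in this article):
# (299, 299, 3) input image
#   -> InceptionV3 without its head  -> (8, 8, 2048) feature map
#   -> reshape                       -> (64, 2048)   one vector per spatial location
#   -> CNNEncoder (Dense + ReLU)     -> (64, 256)    image embedding
#   -> RNNDecoder with attention     -> (vocab_size,) logits for the next token, one step at a time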
The first part of the model is the CNN Encoder, which learns a dense layer as it processes the image feature vectors and generates the representation consumed by the decoder's attention layer
class CNNEncoder(Model):
def __init__(self, embedding_dim):
super(CNNEncoder, self).__init__()
self.fc = Dense(embedding_dim)
def call(self, x):
x = self.fc(x)
x = tf.nn.relu(x)
return x
The next component of our model is based on Bahdanau's attention, which was a breakthrough when first introduced as an improvement to encoder-decoder models. It addresses the problem encoders face when they try to squash the information extracted from very long sequences into a single fixed-size vector.
To learn more about this algorithm, you can read the original paper on arxiv.org or a detailed explanation on machinelearningmastery.com
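Concretely, writing the encoder feature vectors as $h_i$ and the previous decoder hidden state as $s_{t-1}$ (notation is mine, following the paper), the layer implemented below computes
$$e_{t,i} = v^\top \tanh(W_1 h_i + W_2 s_{t-1}), \qquad \alpha_{t,i} = \mathrm{softmax}_i(e_{t,i}), \qquad c_t = \sum_i \alpha_{t,i} h_i$$
where $W_1$, $W_2$ and $v$ correspond to the three Dense layers in the code, $\alpha_{t,i}$ are the attention weights and $c_t$ is the context vector handed back to the decoder.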
class BahdanauAttention(Model):
def __init__(self, units):
super(BahdanauAttention, self).__init__()
self.W1 = Dense(units)
self.W2 = Dense(units)
self.V = Dense(1)
def call(self, features, hidden):
hidden_with_time_axis = tf.expand_dims(hidden, 1)
score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
attention_w = tf.nn.softmax(self.V(score), axis=1)
ctx_vector = attention_w * features
ctx_vector = tf.reduce_sum(ctx_vector, axis=1)
return ctx_vector, attention_w
Next, we define the decoder, which is an RNN that uses a GRU and attention to learn how to produce captions from the input text sequences and the visual feature vectors extracted from the input images:
class RNNDecoder(Model):
def __init__(self, embedding_size, units, vocab_size):
super(RNNDecoder, self).__init__()
self.units = units
self.embedding = Embedding(vocab_size, embedding_size)
self.gru = GRU(self.units, return_sequences=True, return_state=True, recurrent_initializer='glorot_uniform')
self.fc1 = Dense(self.units)
self.fc2 = Dense(vocab_size)
self.attention = BahdanauAttention(self.units)
def call(self, x, features, hidden):
# calculate the attention context vector
context_vector, attention_weights = self.attention(features, hidden)
# calculate the embeddings of the input token
x = self.embedding(x)
expanded_context = tf.expand_dims(context_vector, 1)
x = Concatenate(axis=-1)([expanded_context, x])
# pass context vector and input embedding through GRU
output, state = self.gru(x)
x = self.fc1(output)
x = tf.reshape(x, (-1, x.shape[2]))
x = self.fc2(x)
return x, state, attention_weights
def reset_state(self, batch_size):
return tf.zeros((batch_size, self.units))
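As a quick sanity check (toy tensors, with dimensions chosen to match the training setup used later), a single decoder step maps a batch of token ids plus the 64 image feature vectors to one set of vocabulary logits:
# Hypothetical smoke test: one decoding step on random inputs
_decoder = RNNDecoder(embedding_size=256, units=512, vocab_size=5001)
_tokens = tf.zeros((4, 1), dtype=tf.int32)    # batch of 4, one input token each
_features = tf.random.normal((4, 64, 256))    # encoded image features
_hidden = _decoder.reset_state(batch_size=4)  # initial GRU state
_preds, _state, _attn = _decoder(_tokens, _features, _hidden)
print(_preds.shape, _state.shape, _attn.shape)  # (4, 5001) (4, 512) (4, 64, 1)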
Finally, we put the different pieces of the model together in the Trainer class. It wires together the encoder, the decoder, the tokenizer, the optimizer and the loss function needed to train the whole system. This class also defines a function that performs a single training step, as well as the training loop over many epochs. During training, the training loss is recorded so it can later be visualized with TensorBoard.
class Trainer(object):
def __init__(self, embedding_size, units, vocab_size, tokenizer):
self.tokenizer = tokenizer
self.encoder = CNNEncoder(embedding_size)
self.decoder = RNNDecoder(embedding_size, units, vocab_size)
self.optimizer = Adam()
self.loss = SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def loss_function(self, real, predicted):
"""Calculate the loss based on the ground truth caption and the predicted one"""
mask = tf.math.logical_not(tf.math.equal(real,0))
_loss = self.loss(real, predicted)
mask = tf.cast(mask, dtype=_loss.dtype)
_loss *= mask
return tf.reduce_mean(_loss)
@tf.function
def train_step(self, image_tensor, target):
"""Perform one training step"""
loss = 0
hidden = self.decoder.reset_state(target.shape[0])
start_token_idx = self.tokenizer.word_index['<sos>']
init_batch = [start_token_idx] * target.shape[0]
decoder_input = tf.expand_dims(init_batch, 1)
with tf.GradientTape() as tape:
features = self.encoder(image_tensor)
for i in range(1, target.shape[1]):
preds, hidden, _ = self.decoder(decoder_input, features, hidden)
loss += self.loss_function(target[:, i], preds)
decoder_input = tf.expand_dims(target[:, i],1)
total_loss = loss / int(target.shape[1])
trainable_vars = (self.encoder.trainable_variables + self.decoder.trainable_variables)
gradients = tape.gradient(loss, trainable_vars)
self.optimizer.apply_gradients(zip(gradients,trainable_vars))
return loss, total_loss
def train(self, dataset, epochs, num_steps):
"""Train and log metrics"""
writer = tf.summary.create_file_writer('log_dir')
for epoch in tqdm(range(epochs)):
start = time.time()
total_loss = 0
for batch, (image_tensor, target) in enumerate(dataset):
batch_loss, step_loss = self.train_step(image_tensor, target)
total_loss += step_loss
epoch_time = time.time() - start
# write the loss value
with writer.as_default():
tf.summary.scalar('training loss', total_loss / num_steps, step=epoch+1)
tf.summary.scalar('Epoch time (s)', epoch_time, step=epoch+1)
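To make the padding mask in loss_function concrete, here is a small standalone check with made-up values: positions where the target is the <pad> id 0 contribute nothing to the loss.
# Toy check of the masking logic used in Trainer.loss_function
_loss_fn = SparseCategoricalCrossentropy(from_logits=True, reduction='none')
_real = tf.constant([2, 0])                # second position is padding
_logits = tf.random.normal((2, 5))         # fake predictions over a 5-word vocabulary
_per_token = _loss_fn(_real, _logits)
_mask = tf.cast(tf.math.logical_not(tf.math.equal(_real, 0)), _per_token.dtype)
print((_per_token * _mask).numpy())        # the second entry is exactly 0.0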
Before we can start the training, we need to perform some data pre-processing. Let's first get an array of image paths and the corresponding captions
train_images = captions_df.image.apply(lambda image: f'{IMAGES_PATH}/{image}').values
train_captions = captions_df.caption.values
We download an instance of Inception V3 pre-trained on the ImageNet dataset and use it as the image feature extractor after removing the model's classification head
feature_extractor = InceptionV3(include_top=False, weights='imagenet')
feature_extractor = Model(feature_extractor.input, feature_extractor.layers[-1].output)
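For a 299x299 input, this truncated InceptionV3 outputs an 8x8 grid of 2048-dimensional vectors; a quick check on a random tensor (just to confirm the shape) looks like this:
# Each image becomes 8x8 = 64 spatial locations with 2048 features each
dummy = tf.random.normal((1, 299, 299, 3))
print(feature_extractor(dummy).shape)  # (1, 8, 8, 2048)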
We need to define a function that will load images based on their path and resize them as expected by the feature extractor model
def load_image_fn(image_path):
image = tf.io.read_file(image_path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, (299, 299))
image = preprocess_input(image)
return image, image_path
We can use the previous function to create a tf.data.Dataset of the images
BATCH_SIZE = 8
image_dataset = tf.data.Dataset.from_tensor_slices(train_images)\
.map(load_image_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
.batch(BATCH_SIZE)
We iterate over all the images in the dataset and pass them through the feature extractor. As the extracted feature vectors cannot all fit in memory, we store them under the FEATURES_PATH folder
for image, path in tqdm(image_dataset):
batch_features = feature_extractor.predict(image)
batch_features = tf.reshape(batch_features, (batch_features.shape[0], -1, batch_features.shape[3]))
for batch_feature, p in zip(batch_features, path):
feature_path = Path(p.numpy().decode('UTF-8'))
image_name = feature_path.stem
np.save(f'{FEATURES_PATH}/{image_name}', batch_feature.numpy())
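Once the loop finishes, each .npy file under FEATURES_PATH should hold a (64, 2048) array, one row per spatial location. A quick check on an arbitrary cached file (assuming at least one was written):
# Inspect one cached feature file
sample_file = next(Path(FEATURES_PATH).glob('*.npy'))
print(sample_file.name, np.load(sample_file).shape)  # expected shape: (64, 2048)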
We need to create tokens from the captions, so we train a tokenizer on the top 5,000 words in our captions, which become our vocabulary. Note that we could keep more words, but that would lead to a bigger memory footprint, since the decoder's output layer has one unit per vocabulary word.
After that, we apply this tokenizer to each caption text to generate a numeric sequence. We limit the sequence size and pad any short caption by appending the special <pad> token at the end.
top_k = 5000
filters = '!"#$%&()*+.,-/:;=?@[\]^_`{|}~ '
tokenizer = Tokenizer(num_words=top_k, oov_token='<unk>', filters=filters)
tokenizer.fit_on_texts(train_captions)
tokenizer.word_index['<pad>'] = 0
tokenizer.index_word[0] = '<pad>'
train_seqs = tokenizer.texts_to_sequences(train_captions)
captions_seqs = pad_sequences(train_seqs, padding='post')
max_length = get_max_length(train_seqs)
captions_seqs.shape
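To see what the tokenizer produces, we can round-trip one of the cleaned captions (the exact indices depend on the fitted vocabulary):
# Text -> sequence of word indices -> back to text
example_seq = tokenizer.texts_to_sequences([train_captions[0]])[0]
print(train_captions[0])
print(example_seq)
print(' '.join(tokenizer.index_word[idx] for idx in example_seq))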
Let's split the dataset into 80% for the actual training and 20% for evaluating the trained model later
(images_train, images_val, caption_train, caption_val) = train_test_split(train_images, captions_seqs, test_size=0.2, random_state=SEED)
We need a function that will load an image feature vector and the associated caption
def load_example_fn(image_name, caption):
image_name = image_name.decode('utf-8')
image_name = Path(image_name).stem
image_tensor = np.load(f'{FEATURES_PATH}/{image_name}.npy')
return image_tensor, caption
To create the training dataset, we batch the images with their captions into batches of BATCH_SIZE examples, and use load_example_fn to load the cached feature vectors. For performance, we shuffle the dataset and prefetch batches ahead of time to speed up training.
BATCH_SIZE = 64
BUFFER_SIZE = 1000
dataset = tf.data.Dataset.from_tensor_slices((images_train, caption_train))\
.map(lambda i1, i2: tf.numpy_function(load_example_fn, [i1, i2], [tf.float32, tf.int32]), num_parallel_calls=tf.data.experimental.AUTOTUNE)\
.shuffle(BUFFER_SIZE)\
.batch(BATCH_SIZE)\
.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
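A quick look at a single batch (the second dimension of the captions depends on the maximum caption length computed above) confirms that the pipeline returns image features and padded caption sequences together:
# Take one batch to check the shapes
for image_batch, caption_batch in dataset.take(1):
    print(image_batch.shape, caption_batch.shape)  # (64, 64, 2048) and (64, max caption length)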
Now we are ready for the actual training. Let's create the helper Trainer, setting the token embedding size to 256 and the number of units for the decoder as well as the attention model to 512. We also pass the tokenizer, and the vocabulary size is 5000 + 1 (for the padding token).
trainer = Trainer(embedding_size=256, units=512, vocab_size=top_k + 1, tokenizer=tokenizer)
Now we train for a few epochs (note this may take quite some time to finish)
EPOCHS = 30
num_steps = len(images_train) // BATCH_SIZE
trainer.train(dataset, EPOCHS, num_steps)
Now we can explore the training loss with TensorBoard
%load_ext tensorboard
%tensorboard --logdir log_dir
For the evaluation, we will define a function that takes the trained encoder and decoder along with the tokenizer and an image path (the image feature extractor defined earlier is used internally). It transforms the image and generates a caption starting with the <sos> special token. Generation stops when an <eos> token is produced by the decoder or the maximum caption length is reached.
def evaluate(encoder, decoder, tokenizer, image_path, max_length, attention_shape):
attention_plot = np.zeros((max_length, attention_shape))
# initialize hidden state
hidden = decoder.reset_state(batch_size=1)
image, _ = load_image_fn(image_path)
# extract the image feature vector
features = feature_extractor(tf.expand_dims(image, 0))
features = tf.reshape(features, (features.shape[0], -1, features.shape[3]))
# encode the features
encoder_out = encoder(features)
start_token_idx = tokenizer.word_index['<sos>']
decoder_input = tf.expand_dims([start_token_idx], 0)
result = []
# generate the caption
for i in range(max_length):
(preds, hidden, attention_w) = decoder(decoder_input, encoder_out, hidden)
attention_plot[i] = tf.reshape(attention_w, (-1,)).numpy()
pred_id = tf.random.categorical(preds, 1)[0][0].numpy()
result.append(tokenizer.index_word[pred_id])
if tokenizer.index_word[pred_id] == '<eos>':
return result, attention_plot
decoder_input = tf.expand_dims([pred_id], 0)
attention_plot = attention_plot[:len(result), :]
return result, attention_plot
The function returns the generated tokens along with the attention weights for each step in attention_plot. We set the attention features shape to 64, which matches the 8x8 grid of image feature locations produced by the extractor.
attention_features_shape = 64
Pick a random image from the evaluation dataset that we saved earlier
random_id = np.random.randint(0, len(images_val))
image_path = images_val[random_id]
Get and clean the actual image caption
actual_caption = ' '.join([tokenizer.index_word[i] for i in caption_val[random_id] if i != 0])
actual_caption = (actual_caption.replace('<sos>', '').replace('<eos>', ''))
Generate a caption for the image and clean it from special tokens
result, attention_plot = evaluate(trainer.encoder, trainer.decoder, tokenizer, image_path, max_length, attention_features_shape)
predicted_caption = (' '.join(result).replace('<sos>', '').replace('<eos>', ''))
Let's show, side by side, the image with its original caption, and next to it two copies of the image overlaid with some of the attention plots (there is one attention plot per output token) together with the predicted caption
figure, axis = plt.subplots(1, 3, figsize=(15, 8))
axis[0].imshow(plt.imread(image_path))
axis[0].set_title(actual_caption)
axis[0].axis('off')
imageshow = axis[1].imshow(plt.imread(image_path))
axis[1].imshow(np.resize(attention_plot[0], (8, 8)), cmap='gray', alpha=0.6, extent=imageshow.get_extent())
axis[1].set_title(predicted_caption)
axis[1].axis('off')
imageshow = axis[2].imshow(plt.imread(image_path))
axis[2].imshow(np.resize(attention_plot[len(attention_plot)-1], (8, 8)), cmap='gray', alpha=0.6, extent=imageshow.get_extent())
axis[2].set_title(predicted_caption)
axis[2].axis('off')
plt.show()
Notice how well the model performed, generating a caption that is close to the actual ground truth.
We can display all the attention plots and inspect what the model was looking at when it generated the corresponding token. For this, let's define a function that receives an image, the caption as a sequence of tokens, and the attention_plot returned by the previous evaluation function.
def plot_attention(image_path, result, attention_plot, output_path):
image_array = plt.imread(image_path)
fig = plt.figure(figsize=(10, 10))
# for each token create a sub-plot and display the corresponding attention
for l in range(len(result)):
temp_att = np.resize(attention_plot[l], (8, 8))
ax = fig.add_subplot(len(result) // 2, len(result) // 2, l + 1)
ax.set_title(result[l])
image = ax.imshow(image_array)
ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=image.get_extent())
# save the attention plot
plt.savefig(output_path, format='png')
plt.tight_layout()
plt.show()
plot_attention(image_path, result, attention_plot, './attention_plot.png')
Notice how the attention shifts across the image as each token is generated, for example highlighting a different region for a word like 'beside' than when it looked at the chair.