CLIP (Contrastive Language-Image Pre-training) was created by OpenAI. The model is trained on image-caption pairs to learn a joint embedding space for images and text. CLIP is composed of a text encoder and an image encoder that produce embeddings in this shared space: embeddings of related concepts (whether image or text) end up close together, while embeddings of unrelated concepts end up far apart.

Furthermore, these embeddings capture rich semantic information, which turns out to be very useful for many downstream tasks such as image generation or zero-shot image classification.

In this article, we will see how to evaluate the performance of CLIP on the task of zero-shot image classification on three datasets of varying difficulty.

As depicted in the diagram above, CLIP is used for classification as follows:

  1. For each label in our classification task, we generate a text prompt of the form 'a photo of a {label}'.
  2. We embed these prompts with the CLIP text encoder.
  3. We embed the images with the CLIP image encoder.
  4. Using cosine similarity, we find the best matching prompt embedding for each image embedding.
  5. Optionally, we apply a softmax to convert the similarities into probabilities (a short code sketch of these steps follows the list).
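
Below is a minimal sketch of these five steps. It assumes that model and processor are loaded as in the Setup section below, that image is a PIL image, and that the two labels are just placeholders; the factor of 100 is a rough stand-in for CLIP's learned temperature (logit_scale).

labels = ['dog', 'cat']                                   # placeholder labels
prompts = [f'a photo of a {label}' for label in labels]   # step 1: build the prompts

# step 2: embed the prompts with the CLIP text encoder
text_inputs = processor(text=prompts, return_tensors="np", padding=True)
text_emb = model.get_text_features(**text_inputs)

# step 3: embed the image with the CLIP image encoder
image_inputs = processor(images=image, return_tensors="np")
image_emb = model.get_image_features(**image_inputs)

# step 4: cosine similarity between the image embedding and each prompt embedding
text_emb = text_emb / jnp.linalg.norm(text_emb, axis=-1, keepdims=True)
image_emb = image_emb / jnp.linalg.norm(image_emb, axis=-1, keepdims=True)
similarity = image_emb @ text_emb.T

# step 5: softmax over the scaled similarities to get per-label probabilities
probs = jax.nn.softmax(similarity * 100.0, axis=-1)
predicted = labels[int(jnp.argmax(probs, axis=-1)[0])]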

Setup

We will use the Flax implementation of CLIP available in the transformers library. So let's install our dependencies.

%%capture
%%bash

pip install --upgrade flax transformers
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
import glob, io, random, requests  # stdlib/third-party modules used by the helper functions below

import jax
import jax.numpy as jnp
from sklearn import metrics
from tqdm.auto import tqdm
from fastai.vision.all import *  # also provides Image, np and plt used below
from transformers import CLIPProcessor, FlaxCLIPModel

Set a seed for reproducibility

seed = 123
random.seed(seed)

We will use the checkpoints for the CLIP model and input processing of OpenAI's clip-vit-base-patch32, available on Hugging Face.

model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
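
As a quick sanity check, we can classify a single image end to end with the model and processor we just loaded. This is only an illustrative sketch: some_image.jpg is a placeholder path and the two prompts are arbitrary.

image = Image.open('some_image.jpg')  # placeholder: any local image file
inputs = processor(
    text=['a photo of a dog', 'a photo of a cat'],
    images=image, return_tensors="np", padding=True
)
outputs = model(**inputs)
# logits_per_image holds the image-text similarity score for each prompt
probs = jax.nn.softmax(outputs.logits_per_image, axis=-1)
print(probs)  # probability assigned to each prompt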

Helper functions

We need a couple of helper functions that we will use in the classification sections.

First, the following function scans for image files on disk using a glob pattern. It optionally takes a label dictionary in case we want to use something other than the labels extracted from the file's path. For instance, Imagenet folder names are synset IDs, and it is better to map them to readable class text.

def scan_images(pattern, class_to_label=None):
    files = glob.glob(pattern)
    random.shuffle(files)
    print(f'Found {len(files)} files')
    # the label of each file is the name of its parent folder
    labels = [f.split('/')[-2] for f in files]
    if class_to_label:
        # optionally map folder names (e.g. Imagenet synset IDs) to readable labels
        labels = [class_to_label[cls] for cls in labels]
    print('First few labels:', labels[:3])
    print('First few filenames:', files[:3])
    return files, labels

The following function performs classification on batches of image files. It first creates text prompts of the form "a photo of a {label}" from the class names. Then it uses the CLIPProcessor to prepare those prompts and the images before passing them to the FlaxCLIPModel for prediction. The model's output is passed through a softmax to compute the probability of each label, and the label with the highest probability is picked as the prediction.

def predict(files, classes, batch_size=256):
  texts = [f'a photo of a {cl}' for cl in classes]
  y_pred = jnp.asarray([])
  
  for start in tqdm(range(0, len(files), batch_size)):
    end = min(start+batch_size, len(files))
    # read each image
    images = [Image.open(f) for f in files[start:end]]
    # pre-process the texts and images
    inputs = processor(
      text=texts, images=images, return_tensors="np", padding=True
    )
    # run CLIP
    outputs = model(**inputs)
    # get the image-text similarity score
    logits_per_image = outputs.logits_per_image
    # apply softmax to get the label probabilities
    probs = jax.nn.softmax(logits_per_image, axis=1)
    y_pred = jnp.append(y_pred, np.argmax(probs, axis=-1))

  return y_pred

There will be some misclassified images; the following helper function selects a few of them at random and loads them from disk as PIL images.

def select_misclassified(y_true, y_pred, classes, files, k=9):
    # indices of the misclassified examples
    indices = jnp.where(y_true != y_pred)[0]
    indices = [int(idx) for idx in list(indices)]
    # pick k of them at random
    sample_indices = random.sample(indices, k=k)
    pil_files = [files[i] for i in sample_indices]
    pil_labels = [f'{classes[int(y_pred[i])]}/{classes[int(y_true[i])]}' for i in sample_indices]
    pil_images = [Image.open(f) for f in pil_files]
    return pil_images, pil_labels

The following helper function will be used to plot a collection of images along with their descriptions.

def ceildiv(a, b):
  return -(-a // b)

def plots_pil_images(pil_images, figsize=(10,5), rows=1, cols=None, titles=None, maintitle=None):
  f = plt.figure(figsize=figsize)
  if maintitle is not None: plt.suptitle(maintitle, fontsize=10)
  cols = cols if cols else ceildiv(len(pil_images), rows)
  for i in range(len(pil_images)):
    sp = f.add_subplot(rows, cols, i+1)
    sp.axis('Off')
    if titles is not None: sp.set_title(titles[i], fontsize=16)
    img = np.asarray(pil_images[i])
    plt.imshow(img)

Finally, a helper function to download an image from the internet.

def get_image(url):
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
    resp = requests.get(url, headers = headers)
    return Image.open(io.BytesIO(resp.content))

Image Classification

In this section, we will try zero-shot classification on the following datasets, starting from the easiest and moving to the more challenging ones. For each one we will evaluate the model's performance.

Dogs vs Cats

Dogs vs Cats is a binary classification dataset. Let's download it and use the images from the validation set to run our predictions.

dogs_path = untar_data(URLs.DOGS)

Load the files and true labels of this dataset

files1, labels1 = scan_images(f'{dogs_path}/valid/*/*.jpg')
Found 2000 files
First few labels: ['dogs', 'cats', 'dogs']
First few filenames: ['/root/.fastai/data/dogscats/valid/dogs/dog.6312.jpg', '/root/.fastai/data/dogscats/valid/cats/cat.6086.jpg', '/root/.fastai/data/dogscats/valid/dogs/dog.6427.jpg']

Replace the true labels with the respective class indices: dogs -> 0 and cats -> 1.

classes1 = ['dogs', 'cats']
y_true1 = jnp.asarray([classes1.index(l) for l in labels1])
y_pred1 = predict(files1, classes1)

Let's calculate the accuracy score of our predictions.

metrics.accuracy_score(y_true1, y_pred1)
0.995

Let's calculate the confusion matrix of our predictions to find out where our model got the labels wrong.

metrics.ConfusionMatrixDisplay.from_predictions(y_true1, y_pred1, display_labels=classes1, cmap='Blues');

Note: see how CLIP performed zero-shot classification on this dataset with 99.5% accuracy!

Let's see some of the images that the model got wrong.

pil_images1, pil_labels1 = select_misclassified(y_true1, y_pred1, classes1, files1)
plots_pil_images(pil_images1, figsize=(13, 7), rows=3, titles=pil_labels1, maintitle='classifications y_pred/y_true')

IMAGENETTE_160

As the next dataset to try, we pick Imagenette, which is a subset of the Imagenet dataset containing only 10 classes. Let's download the dataset.

path2 = untar_data(URLs.IMAGENETTE_160)
imagenet_class_to_class_label = {
    'n01440764': 'tench',
    'n02102040': 'English springer',
    'n02979186': 'cassette player',
    'n03000684': 'chain saw',
    'n03028079': 'church',
    'n03394916': 'French horn',
    'n03417042': 'garbage truck',
    'n03425413': 'gas pump',
    'n03445777': 'golf ball',
    'n03888257': 'parachute'
}
classes2 = list(imagenet_class_to_class_label.values())
print(classes2)
['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']

Next, we scan for the images and their labels from the validation set.

files2, labels2 = scan_images(f'{path2}/val/*/*.JPEG', imagenet_class_to_class_label)
Found 3925 files
First few labels: ['golf ball', 'parachute', 'English springer']
First few filenames: ['/root/.fastai/data/imagenette2-160/val/n03445777/n03445777_11822.JPEG', '/root/.fastai/data/imagenette2-160/val/n03888257/n03888257_37950.JPEG', '/root/.fastai/data/imagenette2-160/val/n02102040/n02102040_2890.JPEG']

Then, we get the predictions from CLIP

y_true2 = jnp.asarray([classes2.index(l) for l in labels2])
y_pred2 = predict(files2, classes2)

Let's calculate the model accuracy

metrics.accuracy_score(y_true2, y_pred2)
0.9859872611464968

Then let's find out which labels got wrongly classified using the confusion matrix.

metrics.ConfusionMatrixDisplay.from_predictions(
    y_true2,
    y_pred2,
    display_labels=classes2,
    xticks_rotation='vertical',
    cmap='Blues'
    );

Note: you can see that CLIP did a good job on this dataset too, with an accuracy of 98.6%

Let's visualize some of the images the model classified wrongly.

pil_images2, pil_labels2 = select_misclassified(y_true2, y_pred2, classes2, files2)
plots_pil_images(pil_images2, figsize=(11, 8), rows=3, titles=pil_labels2, maintitle='classifications y_pred/y_true')

You can see that those images are confusing; for instance, the picture of the man playing the French horn in a church. The model picked church over French horn, but both are in fact accurate classifications.

Dog breeds

Moving on to a harder dataset, Imagewoof consists of images of 10 dog breeds that are hard to tell apart.

path3 = untar_data(URLs.IMAGEWOOF_160)

imagenet_class_to_dogbreed = {
    'n02086240': 'Shih-Tzu',
    'n02087394': 'Rhodesian ridgeback',
    'n02088364': 'Beagle',
    'n02089973': 'English foxhound',
    'n02093754': 'Border terrier',
    'n02096294': 'Australian terrier',
    'n02099601': 'Golden retriever',
    'n02105641': 'Old English sheepdog',
    'n02111889': 'Samoyed',
    'n02115641': 'Dingo'
}
classes3 = list(imagenet_class_to_dogbreed.values())
print(classes3)
['Shih-Tzu', 'Rhodesian ridgeback', 'Beagle', 'English foxhound', 'Border terrier', 'Australian terrier', 'Golden retriever', 'Old English sheepdog', 'Samoyed', 'Dingo']

After downloading the dataset, let's scan the images in the validation set folder.

files3, labels3 = scan_images(f'{path3}/val/*/*.JPEG', imagenet_class_to_dogbreed)
Found 3929 files
First few labels: ['Samoyed', 'Beagle', 'Beagle']
First few filenames: ['/root/.fastai/data/imagewoof2-160/val/n02111889/n02111889_6962.JPEG', '/root/.fastai/data/imagewoof2-160/val/n02088364/n02088364_12710.JPEG', '/root/.fastai/data/imagewoof2-160/val/n02088364/n02088364_6092.JPEG']

Run CLIP to predict a label for each image

y_true3 = jnp.asarray([classes3.index(l) for l in labels3])
y_pred3 = predict(files3, classes3)

Next we get the accuracy

metrics.accuracy_score(y_true3, y_pred3)
0.8819037923135657

Then we calculate the confusion matrix to look for any interesting mis-classifications.

metrics.ConfusionMatrixDisplay.from_predictions(
    y_true3,
    y_pred3,
    display_labels=classes3,
    xticks_rotation='vertical',
    cmap='Blues'
    );

Notice how CLIP had a harder time distinguishing Beagles from English foxhounds. In fact, looking at images of these two breeds, it is really hard to tell which one is which, as you can see in the following pictures.

img1 = get_image('https://upload.wikimedia.org/wikipedia/commons/b/b7/Beagle_Faraon.JPG')
img2 = get_image('https://d17fnq9dkz9hgj.cloudfront.net/breed-uploads/2018/08/english-foxhound-detail.jpg')

plots_pil_images([img1, img2], rows=1, titles=['Beagle', 'English foxhound'])

Overall, CLIP did not do too badly on this harder dataset, achieving an accuracy of 88.2%. Let's visualize some of the images the model misclassified.

pil_images3, pil_labels3 = select_misclassified(y_true3, y_pred3, classes3, files3)
plots_pil_images(pil_images3, figsize=(13, 8), rows=3, titles=pil_labels3, maintitle='classifications y_pred/y_true')

That's all folks

CLIP is very powerful and can be used for many different tasks from data filtering/search to image generation. In this post, we saw how to use CLIP to perform image classification without having to fine-tune the model.

I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.