CLIP (Contrastive Language-Image Pre-training) was created by OpenAI. It is trained on image-caption pairs to learn a mapping of images and text into a shared embedding space. CLIP is composed of a text encoder and an image encoder that produce embeddings in this shared space: the embeddings of an image and a text are similar when they describe related concepts, and dissimilar when the concepts are unrelated.
Furthermore, these embeddings capture rich semantic information, which turns out to be very useful for many downstream tasks such as image generation or zero-shot image classification.
In this article, we will see how to evaluate the performance of CLIP on zero-shot image classification using 3 datasets of varying difficulty.
As depicted in the diagram above (credit), CLIP is used for classification as follows:
- For each class label, we generate a text prompt of the form 'a photo of a {label}'.
- We embed these prompts with the CLIP text encoder.
- We embed the images with the CLIP image encoder.
- For each image embedding, we pick the prompt whose embedding has the highest cosine similarity; that prompt's label is the prediction.
- Optionally, we apply a softmax over the similarities to turn them into probabilities.
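The steps above can be sketched in a few lines of JAX. This is a minimal illustration, not the transformers implementation: zero_shot_probs is a hypothetical helper, the 3-dimensional embeddings are made up to stand in for the encoder outputs, and logit_scale=100.0 is an assumed temperature (CLIP multiplies the cosine similarities by a learned scale before the softmax).

```python
import jax
import jax.numpy as jnp

def zero_shot_probs(image_embeds, text_embeds, logit_scale=100.0):
    """Return (probabilities, predicted class indices) for each image.

    image_embeds: (num_images, dim); text_embeds: (num_classes, dim).
    logit_scale is an assumed temperature; CLIP learns its own value.
    """
    # L2-normalize so that dot products equal cosine similarities
    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
    text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
    # (num_images, num_classes) matrix of cosine similarities
    cos_sim = image_embeds @ text_embeds.T
    # softmax over the classes turns the scaled similarities into probabilities
    probs = jax.nn.softmax(logit_scale * cos_sim, axis=-1)
    return probs, jnp.argmax(probs, axis=-1)

# Made-up embeddings for two prompts and two images
text_embeds = jnp.array([[1.0, 0.0, 0.0],   # stands in for 'a photo of a dog'
                         [0.0, 1.0, 0.0]])  # stands in for 'a photo of a cat'
image_embeds = jnp.array([[0.9, 0.1, 0.0],
                          [0.2, 0.8, 0.1]])
probs, preds = zero_shot_probs(image_embeds, text_embeds)
print(preds)  # [0 1]
```

Normalizing both sides first is what makes the matrix product a cosine similarity; without it, longer embedding vectors would dominate the scores.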
We will use the Flax implementation of CLIP available in the Hugging Face transformers library. Let's install our dependencies.
%%capture
%%bash
pip install --upgrade flax transformers
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
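import glob, io, random
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image
from sklearn import metrics
from tqdm.auto import tqdm
from fastai.data.external import untar_data, URLs  # fastai is only used to download the datasets
from transformers import CLIPProcessor, FlaxCLIPModel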
Set a seed for reproducibility (this seeds Python's random module, which we use to shuffle the files and to sample images).
seed = 123
random.seed(seed)
We will use the model and input-processing checkpoints of OpenAI's clip-vit-base-patch32, available on the Hugging Face Hub.
model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
First, the following function scans for image files on disk using a glob pattern and extracts each image's label from the name of its parent directory. An optional dictionary lets us map those directory names to other labels. For instance, ImageNet directory names are WordNet IDs such as n01440764, so it is better to map them to readable class names.
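def scan_images(pattern, class_to_label=None):
    # sort first: glob order is arbitrary, so sorting makes the seeded shuffle reproducible
    files = sorted(glob.glob(pattern))
    random.shuffle(files)
    print(f'Found {len(files)} files')
    # the label is the name of the image's parent directory
    labels = [f.split('/')[-2] for f in files]
    if class_to_label:
        # optionally map directory names to readable labels
        labels = [class_to_label[cls] for cls in labels]
    print('First few labels:', labels[:3])
    print('First few filenames:', files[:3])
    return files, labels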
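To make the path-to-label logic concrete, here is a small standalone sketch that builds a made-up folder tree in a temporary directory (the folder and file names are hypothetical) and extracts labels the same way. It uses os.path instead of split('/') as a more portable way to get the parent directory name, and the two ImageNet IDs are taken from the Imagenette mapping used later in the article.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Made-up layout mirroring the datasets: <root>/val/<class folder>/<image file>
    for folder, name in [('n03028079', 'b.JPEG'), ('n01440764', 'a.JPEG')]:
        os.makedirs(os.path.join(root, 'val', folder))
        open(os.path.join(root, 'val', folder, name), 'w').close()  # empty placeholder file

    # '*' is a glob wildcard matching any single path component (not a regex)
    files = sorted(glob.glob(f'{root}/val/*/*.JPEG'))
    # the class is the name of each file's parent directory
    folders = [os.path.basename(os.path.dirname(f)) for f in files]

# optional mapping from folder names (ImageNet WordNet IDs) to readable labels
class_to_label = {'n01440764': 'tench', 'n03028079': 'church'}
labels = [class_to_label[c] for c in folders]
print(labels)  # ['tench', 'church']
```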
The following function predicts a class for each image file, processing the files in batches. It first creates text prompts of the form "a photo of a {label}" from the class labels. Then it uses CLIPProcessor to prepare the prompts and the images before passing them to FlaxCLIPModel. The model's image-text similarity logits are passed through a softmax to get the probability of each label, and the label with the highest probability is the prediction.
def predict(files, classes, batch_size=256):
  # one text prompt per class
  texts = [f'a photo of a {cl}' for cl in classes]
  batch_preds = []

  for start in tqdm(range(0, len(files), batch_size)):
    end = min(start + batch_size, len(files))
    # read each image in the batch
    images = [Image.open(f) for f in files[start:end]]
    # tokenize the prompts and resize/normalize the images
    inputs = processor(
      text=texts, images=images, return_tensors="np", padding=True
    )
    # run CLIP
    outputs = model(**inputs)
    # image-text similarity scores, shape (batch, num_classes)
    logits_per_image = outputs.logits_per_image
    # apply softmax to get the label probabilities
    probs = jax.nn.softmax(logits_per_image, axis=1)
    # the most probable prompt gives the predicted class index
    batch_preds.append(jnp.argmax(probs, axis=-1))
  # concatenate once at the end instead of copying the array on every batch
  return jnp.concatenate(batch_preds)
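Since softmax is strictly increasing within each row, it never changes which prompt scores highest, so for picking the label the argmax could equally be taken on the raw logits_per_image; the softmax only matters if you want the probabilities themselves. A quick check with made-up logits:

```python
import jax
import jax.numpy as jnp

# Made-up similarity logits for 3 images and 4 class prompts
logits = jnp.array([[21.0, 25.5, 19.0, 22.0],
                    [30.1, 18.2, 29.9, 10.0],
                    [15.0, 15.5, 16.0, 14.0]])
probs = jax.nn.softmax(logits, axis=1)

# softmax is monotonic per row, so the winning class index is unchanged
print(jnp.argmax(logits, axis=1))  # [1 0 2]
print(jnp.argmax(probs, axis=1))   # [1 0 2]
# each row of probabilities sums to 1
print(probs.sum(axis=1))
```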
Some images will be misclassified. The following helper function randomly selects some of them, loads them from disk with PIL, and builds a 'predicted/true' title for each.
def select_misclassified(y_true, y_pred, classes, files, k=9):
    # indices of the images whose prediction differs from the true label
    indices = [int(idx) for idx in jnp.nonzero(y_true != y_pred)[0]]
    # randomly pick up to k of them (fewer if there are not enough mistakes)
    indices_k = random.sample(indices, k=min(k, len(indices)))
    pil_files = [files[i] for i in indices_k]
    # title each image as 'predicted/true'
    pil_labels = [f'{classes[int(y_pred[i])]}/{classes[int(y_true[i])]}' for i in indices_k]
    pil_images = [Image.open(f) for f in pil_files]
    return pil_images, pil_labels
The following helper functions plot a collection of images in a grid along with their titles; ceildiv computes how many columns are needed for a given number of rows.
def ceildiv(a, b):
  return -(-a // b)
def plots_pil_images(pil_images, figsize=(10,5), rows=1, cols=None, titles=None, maintitle=None):
  f = plt.figure(figsize=figsize)
  if maintitle is not None: plt.suptitle(maintitle, fontsize=10)
  cols = cols if cols else ceildiv(len(pil_images), rows)
  for i in range(len(pil_images)):
    sp = f.add_subplot(rows, cols, i+1)
    sp.axis('Off')
    if titles is not None: sp.set_title(titles[i], fontsize=16)
    img = np.asarray(pil_images[i])
    plt.imshow(img)
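The -(-a // b) expression in ceildiv is integer ceiling division: Python's // rounds toward negative infinity, so -a // b equals -ceil(a / b), and negating it gives the ceiling without going through floating point. A small check against math.ceil, using a standalone copy of the helper:

```python
import math

def ceildiv(a, b):
    # floor(-a / b) == -ceil(a / b), so negating it gives ceil(a / b)
    return -(-a // b)

# 9 images in 3 rows -> 3 columns; 10 images in 3 rows -> 4 columns
for n_images in (9, 10, 1):
    print(n_images, ceildiv(n_images, 3), math.ceil(n_images / 3))
```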
Finally, a helper function to download an image from the internet:
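def get_image(url):
    # some servers reject requests that lack a browser-like User-Agent
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
    resp = requests.get(url, headers=headers, timeout=30)
    resp.raise_for_status()  # fail loudly on HTTP errors instead of passing an error page to PIL
    return Image.open(io.BytesIO(resp.content))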
dogs_path = untar_data(URLs.DOGS)
Load the files and true labels of the validation set.
files1, labels1 = scan_images(f'{dogs_path}/valid/*/*.jpg')
Map the true labels to class indices (dogs -> 0 and cats -> 1), then run CLIP to predict a label for each image.
classes1 = ['dogs', 'cats']
y_true1 = jnp.asarray([classes1.index(l) for l in labels1])
y_pred1 = predict(files1, classes1)
Let's calculate the accuracy score of our predictions.
metrics.accuracy_score(y_true1, y_pred1)
Let's plot the confusion matrix of our predictions to find out where the model got the labels wrong.
metrics.ConfusionMatrixDisplay.from_predictions(y_true1, y_pred1, display_labels=classes1, cmap='Blues');
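The accuracy and the confusion matrix are directly related: correct predictions sit on the diagonal (rows are true classes, columns are predicted classes), so accuracy is the trace divided by the total count, and each class's recall is its diagonal cell divided by its row sum. A small sketch with made-up labels:

```python
import numpy as np
from sklearn import metrics

# Made-up labels for a 2-class problem: 0 = dogs, 1 = cats
y_true = np.array([0, 0, 0, 1, 1, 1, 1, 0])
y_pred = np.array([0, 1, 0, 1, 1, 0, 1, 0])

cm = metrics.confusion_matrix(y_true, y_pred)  # rows: true class, columns: predicted class
print(cm)
# correct predictions are on the diagonal
acc_from_cm = np.trace(cm) / cm.sum()
print(acc_from_cm, metrics.accuracy_score(y_true, y_pred))  # 0.75 0.75
# per-class recall: diagonal divided by the row sums
print(cm.diagonal() / cm.sum(axis=1))
```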
Let's see some of the images that the model got wrong.
pil_images1, pil_labels1 = select_misclassified(y_true1, y_pred1, classes1, files1)
plots_pil_images(pil_images1, figsize=(13, 7), rows=3, titles=pil_labels1, maintitle='classifications y_pred/y_true')
IMAGENETTE_160
Next, we try Imagenette, a subset of ImageNet containing only 10 easily classified classes. Let's download its 160 px version.
path2 = untar_data(URLs.IMAGENETTE_160)
imagenet_class_to_class_label = {
    'n01440764': 'tench',
    'n02102040': 'English springer',
    'n02979186': 'cassette player',
    'n03000684': 'chain saw',
    'n03028079': 'church',
    'n03394916': 'French horn',
    'n03417042': 'garbage truck',
    'n03425413': 'gas pump',
    'n03445777': 'golf ball',
    'n03888257': 'parachute'
}
classes2 = list(imagenet_class_to_class_label.values())
print(classes2)
Next, we scan for the images and their labels from the validation set.
files2, labels2 = scan_images(f'{path2}/val/*/*.JPEG', imagenet_class_to_class_label)
Then, we convert the labels to class indices and get the predictions from CLIP.
y_true2 = jnp.asarray([classes2.index(l) for l in labels2])
y_pred2 = predict(files2, classes2)
Let's calculate the model accuracy.
metrics.accuracy_score(y_true2, y_pred2)
Then, we use the confusion matrix to find out which labels were misclassified.
metrics.ConfusionMatrixDisplay.from_predictions(
    y_true2,
    y_pred2,
    display_labels=classes2,
    xticks_rotation='vertical',
    cmap='Blues'
    );
Let's visualize some of the images the model misclassified.
pil_images2, pil_labels2 = select_misclassified(y_true2, y_pred2, classes2, files2)
plots_pil_images(pil_images2, figsize=(11, 8), rows=3, titles=pil_labels2, maintitle='classifications y_pred/y_true')
You can see that these images are confusing, for instance the picture of a man playing the French horn in a church. The model picked church over French horn, but both labels actually describe the image correctly.
Dog breeds
Next, we move to a harder dataset, Imagewoof, which consists of images of 10 dog breeds that are hard to tell apart.
path3 = untar_data(URLs.IMAGEWOOF_160)
imagenet_class_to_dogbread = {
    'n02086240': 'Shih-Tzu',
    'n02087394': 'Rhodesian ridgeback',
    'n02088364': 'Beagle',
    'n02089973': 'English foxhound',
    'n02093754': 'Border terrier',
    'n02096294': 'Australian terrier',
    'n02099601': 'Golden retriever',
    'n02105641': 'Old English sheepdog',
    'n02111889': 'Samoyed',
    'n02115641': 'Dingo'
}
classes3 = list(imagenet_class_to_dogbread.values())
print(classes3)
After downloading the dataset, let's scan the images in the validation set folder.
files3, labels3 = scan_images(f'{path3}/val/*/*.JPEG', imagenet_class_to_dogbread)
Run CLIP to predict a label for each image.
y_true3 = jnp.asarray([classes3.index(l) for l in labels3])
y_pred3 = predict(files3, classes3)
Next, we compute the accuracy.
metrics.accuracy_score(y_true3, y_pred3)
Then, we plot the confusion matrix to look for interesting misclassifications.
metrics.ConfusionMatrixDisplay.from_predictions(
    y_true3,
    y_pred3,
    display_labels=classes3,
    xticks_rotation='vertical',
    cmap='Blues'
    );
Notice how CLIP had a harder time telling Beagles from English foxhounds. In fact, when looking at images of these two breeds, it is really hard to tell which one is which, as you can see in the following pictures.
img1 = get_image('https://upload.wikimedia.org/wikipedia/commons/b/b7/Beagle_Faraon.JPG')
img2 = get_image('https://d17fnq9dkz9hgj.cloudfront.net/breed-uploads/2018/08/english-foxhound-detail.jpg')
plots_pil_images([img1, img2], rows=1, titles=['Beagle', 'English foxhound'])
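To find the most confused pairs programmatically rather than by eye, one option is to zero out the diagonal of the confusion matrix and take the largest remaining cells. most_confused below is a hypothetical helper, and the matrix values are made up for illustration; they are not the article's results.

```python
import numpy as np

def most_confused(cm, class_names, n=3):
    """Return the n largest off-diagonal cells as (true, predicted, count)."""
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)  # ignore correct predictions
    # flat indices of the largest counts, in descending order
    flat = np.argsort(off_diag, axis=None)[::-1][:n]
    rows, cols = np.unravel_index(flat, cm.shape)
    return [(class_names[r], class_names[c], int(off_diag[r, c])) for r, c in zip(rows, cols)]

# Made-up 3-class confusion matrix (rows: true class, columns: predicted class)
cm = np.array([[50, 12,  1],
               [ 9, 40,  0],
               [ 2,  0, 60]])
print(most_confused(cm, ['Beagle', 'English foxhound', 'Samoyed'], n=2))
# [('Beagle', 'English foxhound', 12), ('English foxhound', 'Beagle', 9)]
```

With this section's variables it should be callable as most_confused(metrics.confusion_matrix(y_true3, y_pred3, labels=list(range(len(classes3)))), classes3); passing labels keeps the matrix rows aligned with classes3 even if some class is never predicted.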
Overall, CLIP did not do badly on this harder dataset, reaching an accuracy of 88.2%. Let's visualize some of the images the model misclassified.
pil_images3, pil_labels3 = select_misclassified(y_true3, y_pred3, classes3, files3)
plots_pil_images(pil_images3, figsize=(13, 8), rows=3, titles=pil_labels3, maintitle='classifications y_pred/y_true')
That's all folks
CLIP is very powerful and can be used for many different tasks from data filtering/search to image generation. In this post, we saw how to use CLIP to perform image classification without having to fine-tune the model.
I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.