Textual Inversion was first introduced in the paper An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion. The technique learns new embeddings that represent user-provided visual concepts (e.g. images of an object, or hand drawings). These embeddings are then linked to new pseudo-words (the paper uses the term $S_{*}$), which can be used in ordinary prompts. Examples of prompts that include a new concept:

  • "A carpet with $S_{*}$ embroidery"
  • "A stained glass window depicting $S_{*}$"
  • "Painting of $S_{*}$ in the style of Monet"
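A pseudo-word is simply a new entry in the tokenizer vocabulary. Its row in the text encoder's token-embedding matrix holds the learned vector, and everything downstream stays frozen. The toy sketch below illustrates that lookup. The four-word vocabulary, the 3-dimensional vectors, and the `<S*>` spelling are all hypothetical; this is not the real CLIP tokenizer.

```python
import numpy as np

# Hypothetical toy vocabulary and embedding table (one row per token id).
vocab = {"a": 0, "painting": 1, "of": 2, "monet": 3}
embedding_table = np.arange(12, dtype=np.float32).reshape(4, 3)

# Textual Inversion learns a single new vector for the pseudo-word S*.
learned_vector = np.array([0.5, -1.0, 2.0], dtype=np.float32)

# Register the pseudo-word: it gets the next free id, and its learned
# vector is appended as the matching row of the embedding table.
vocab["<S*>"] = len(vocab)
embedding_table = np.vstack([embedding_table, learned_vector])

# A prompt containing S* is looked up exactly like any other prompt.
prompt = ["a", "painting", "of", "<S*>"]
ids = [vocab[word] for word in prompt]
prompt_embeddings = embedding_table[ids]
print(ids)  # [0, 1, 2, 4]
print(prompt_embeddings.shape)  # (4, 3)
```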

Surprisingly, this technique needs only 3 to 5 images to teach a model like Stable Diffusion a new concept that it can then use for personalized image generation.
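For reference, the paper trains with the usual latent-diffusion denoising loss, but optimizes only a single embedding vector $v$ for the pseudo-word while all model weights stay frozen. The notation below follows the paper as we understand it. $\mathcal{E}$ is the autoencoder's encoder, $z_t$ is the noised latent at timestep $t$, $c_\theta(y)$ is the text encoder applied to a prompt $y$ containing $S_{*}$, and $\epsilon_\theta$ is the denoising UNet:

$$
v_{*} = \arg\min_{v} \; \mathbb{E}_{z \sim \mathcal{E}(x),\, y,\, \epsilon \sim \mathcal{N}(0, 1),\, t} \Big[ \big\| \epsilon - \epsilon_\theta\big(z_t, t, c_\theta(y)\big) \big\|_2^2 \Big]
$$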

In this article, we will adapt the FlaxStableDiffusionPipeline to use new concepts when generating images from text. The concepts we use are downloaded from the Stable Diffusion Textual Inversion Concepts Library, which hosts many publicly available concepts. Feel free to browse the library and pick other concepts to play with.

Setup and imports

First, we need to install a few libraries, including diffusers and Flax.

%%capture
%%bash

pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers

Note: if you have access to a TPU, skip installing jax[cuda] and instead connect JAX to the TPU as described in this article.

Log into your Hugging Face account with an access token, and accept the terms of the license for this model (see the model card).

!huggingface-cli login

Import the libraries we will be using.

import os
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
%matplotlib inline

Loading the pipeline

Now we can instantiate a Flax Stable Diffusion pipeline and load the weights in bfloat16 precision.

from diffusers import FlaxStableDiffusionPipeline

dtype = jnp.bfloat16
model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16"  # alternatively "flax"

# use the variables above instead of repeating the literal strings
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
    model_id, revision=revision, dtype=dtype
)

To introduce new concepts to our pipeline, we need to add new tokens and link them to their respective embeddings. We do this by modifying the pipeline's tokenizer and text_encoder.

tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder

At this stage, the text encoder's weights, including the token embeddings, live in pipeline_params. Flax models are stateless, so these parameters are passed to the text_encoder explicitly every time it is called.

text_encoder_params = pipeline_params["text_encoder"]

Adding concepts

All of the concept embeddings we use here come from Hugging Face's Concepts Library, so let's define a helper function that downloads the necessary files for a given concept.

Note: you can use your own concepts or get them from a different source. They just need to be available locally when we load them later.

def download_embeddings(repo_id):
    # fetch the learned embedding and the placeholder token for a concept repo
    embeds_path = hf_hub_download(repo_id=repo_id, filename="learned_embeds.bin")
    token_path = hf_hub_download(repo_id=repo_id, filename="token_identifier.txt")
    return embeds_path, token_path
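The helper also returns token_path, which the rest of the article does not use. Below is a hedged sketch for reading it. It assumes token_identifier.txt contains only the placeholder string (e.g. `<cat-toy>`), possibly followed by a newline. The helper `read_token_identifier` is hypothetical, and a temporary file stands in for the downloaded one:

```python
import tempfile
from pathlib import Path


def read_token_identifier(token_path):
    """Return the placeholder token stored in token_identifier.txt.

    Assumption: the file holds only the token string, possibly followed
    by whitespace or a newline.
    """
    return Path(token_path).read_text(encoding="utf-8").strip()


# Self-contained demo with a throwaway file standing in for the download.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "token_identifier.txt"
    demo_path.write_text("<cat-toy>\n", encoding="utf-8")
    demo_token = read_token_identifier(demo_path)

print(demo_token)  # <cat-toy>
```

In the notebook, you could compare read_token_identifier(token_path) with the token returned by load_embeds to catch a mismatched download.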

Let's download the embeddings for the <cat-toy> concept.

repo_id_cat_toy = "sd-concepts-library/cat-toy"
embeds_path, token_path = download_embeddings(repo_id_cat_toy)

The original embeddings were saved with PyTorch, so we cannot load them directly into a JAX array; a conversion step is needed. A simple way to convert the embeddings to JAX is to:

  1. load the embeddings with PyTorch,
  2. convert them to a NumPy array,
  3. read the NumPy array with JAX.

The following helper function does just that:

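import torch

def load_embeds(embeds_path):
    # learned_embeds.bin is a dict mapping the placeholder token to its embedding
    pytorch_params = torch.load(embeds_path, map_location="cpu")
    trained_token = list(pytorch_params.keys())[0]
    pytorch_embeds = pytorch_params[trained_token]
    # PyTorch tensor -> float32 NumPy array -> JAX array in the pipeline's dtype
    embeds = jnp.asarray(pytorch_embeds.detach().float().numpy()).astype(dtype)
    return trained_token, embeds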
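To make the shapes concrete: load_embeds assumes learned_embeds.bin holds a one-entry dictionary that maps the placeholder token to a single embedding vector. The sketch below mimics that layout, with a NumPy array in place of the PyTorch tensor. The width of 768 matches the Stable Diffusion v1 text encoder, and the random values are placeholders. It then applies the same extraction logic:

```python
import numpy as np

# Stand-in for torch.load(embeds_path): {placeholder_token: embedding_vector}.
rng = np.random.default_rng(0)
fake_params = {"<cat-toy>": rng.standard_normal(768).astype(np.float32)}

# Same logic as load_embeds: the only key is the token, its value the vector.
trained_token = list(fake_params.keys())[0]
fake_embeds = fake_params[trained_token]

print(trained_token, fake_embeds.shape)  # <cat-toy> (768,)
```

The vector is 1-D, which is why the embedding helper used later first expands it to shape (1, 768) before appending it as a new row.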

Let's load the embeddings

token, embeds = load_embeds(embeds_path)

Before adding our new token, let's check the current size of the tokenizer vocabulary:

len(tokenizer)

To add the token, we simply call:

tokenizer.add_tokens(token)

Now let's confirm that the vocabulary size increased by one after we added our new token

len(tokenizer)

We can also get the ID associated with our new token, just to confirm everything is fine so far:

tokenizer.convert_tokens_to_ids(token)
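Added tokens are appended at the end of the vocabulary. The first one therefore gets an ID equal to the old vocabulary size. For the CLIP tokenizer used by Stable Diffusion v1.4 that size is 49408, so `<cat-toy>` should map to 49408. The bookkeeping is sketched below in plain Python, with a hypothetical dictionary standing in for the tokenizer. The two style token names are illustrative:

```python
# Hypothetical stand-in for the tokenizer: only the vocabulary size matters here.
BASE_VOCAB_SIZE = 49408  # CLIP tokenizer used by Stable Diffusion v1.4

added_tokens = {}


def add_token(token):
    """Give a new token the next free id after the base vocabulary."""
    added_tokens[token] = BASE_VOCAB_SIZE + len(added_tokens)


for tok in ["<cat-toy>", "<birb-style>", "<midjourney-style>"]:
    add_token(tok)

print(added_tokens["<cat-toy>"])  # 49408
print(BASE_VOCAB_SIZE + len(added_tokens))  # new vocabulary size: 49411
```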

We are not done adding our concept yet. After adding new tokens to the vocabulary, we also need to resize the token embedding matrix of the text_encoder model to match the new vocabulary size.

The following helper function:

  • updates the vocabulary size in the text_encoder model configuration, and
  • appends the concept's embeddings to the text_encoder's token embedding matrix.

def add_token_embeddings(token_embeds, text_encoder_params):
    # a single concept vector has shape (hidden,); make it (1, hidden)
    if token_embeds.ndim == 1:
        token_embeds = jnp.expand_dims(token_embeds, axis=0)
    # keep the model config in sync with the enlarged tokenizer vocabulary
    text_encoder._config.vocab_size = len(tokenizer)
    # retrieve the token embeddings from the encoder parameters
    text_model_embeds = text_encoder_params["text_model"]["embeddings"]
    token_embedding = text_model_embeds["token_embedding"]
    text_encoder_embeds = token_embedding["embedding"]
    # append the new rows; their indices match the ids just added to the tokenizer
    new_text_encoder_embeds = jnp.append(text_encoder_embeds, token_embeds, axis=0)
    # the parameter dict is updated in place, so pipeline_params sees the change
    token_embedding["embedding"] = new_text_encoder_embeds

Let's add the embeddings for the <cat-toy> concept

add_token_embeddings(embeds, text_encoder_params)
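Appending works because the row index of the new vector coincides with the ID the tokenizer just assigned. Here is a small NumPy sketch of the same operation on a toy matrix. The sizes are made up, and np.append stands in for jnp.append, which takes the same arguments here:

```python
import numpy as np

vocab_size, hidden = 5, 4  # toy sizes; SD v1.4 uses 49408 x 768
token_embedding = np.zeros((vocab_size, hidden), dtype=np.float32)

new_vector = np.ones(hidden, dtype=np.float32)  # 1-D, as stored on disk
new_token_id = vocab_size  # id the tokenizer assigns to the first added token

# Mirror add_token_embeddings: promote to (1, hidden), then append as a row.
token_embedding = np.append(
    token_embedding, np.expand_dims(new_vector, axis=0), axis=0
)

print(token_embedding.shape)  # (6, 4)
print(token_embedding[new_token_id])  # [1. 1. 1. 1.]
```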

To make sure nothing is broken, let's check that we can tokenize and encode a text that contains our new concept.

text_input = tokenizer(
    token,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="np",
)
prompt_ids = text_input.input_ids

text_embeddings = text_encoder(prompt_ids, params=text_encoder_params)[0]

text_embeddings.shape
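With padding="max_length", every prompt is padded or truncated to tokenizer.model_max_length, which is 77 for this tokenizer. The encoder returns one 768-dimensional vector per position, so text_embeddings.shape should come out as (1, 77, 768). The padding and truncation step can be sketched in plain Python. The helper `pad_or_truncate` is hypothetical, and its BOS/EOS/pad IDs are the CLIP values as we understand them; treat both as assumptions:

```python
def pad_or_truncate(token_ids, max_length=77, bos_id=49406, eos_id=49407, pad_id=49407):
    """Sketch of CLIP-style framing: BOS + tokens + EOS, then pad to max_length.

    Assumed ids: 49406 = <|startoftext|>, 49407 = <|endoftext|> (also used as pad).
    """
    body = list(token_ids)[: max_length - 2]  # leave room for BOS and EOS
    framed = [bos_id] + body + [eos_id]
    return framed + [pad_id] * (max_length - len(framed))


ids = pad_or_truncate([49408])  # a prompt that is just <cat-toy>
print(len(ids), ids[:3])  # 77 [49406, 49408, 49407]
```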

Generating images

The overall logic for generating images from prompts with the pipeline is essentially the same as in the introduction example, regardless of whether the prompt uses the new concept.

We first use pipeline.prepare_inputs to convert the text prompt into token IDs, which the pipeline's text encoder turns into embeddings. The pipeline then samples an initial latent representation from random noise and denoises it over num_inference_steps steps before decoding it into pixels. Finally, we use pipeline.numpy_to_pil to convert the resulting array into actual PIL images.
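The guidance_scale argument controls classifier-free guidance. At every denoising step the UNet predicts the noise twice: once conditioned on the prompt embeddings and once on the embeddings of an empty prompt. The two predictions are then combined as below. This is a NumPy sketch of that single combination step with made-up arrays, not the pipeline's actual code:

```python
import numpy as np


def classifier_free_guidance(noise_uncond, noise_text, guidance_scale):
    """Push the prediction away from the unconditional one, toward the prompt."""
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


noise_uncond = np.array([0.0, 1.0])
noise_text = np.array([1.0, 1.0])
print(classifier_free_guidance(noise_uncond, noise_text, 7.5))  # [7.5 1. ]
```

A scale of 1 returns the prompt-conditioned prediction unchanged. Larger values follow the prompt more strongly, often at the cost of diversity.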

guidance_scale = 8 #@param {type:"slider", min:0, max:100, step:0.5}
num_inference_steps = 30 #@param
def generate(prompt, seed=0, num_inference_steps=50, guidance_scale=7.5):
    # generate one image per device
    num_samples = jax.device_count()
    if not isinstance(prompt, list):
        prompt = num_samples * [prompt]
    else:
        assert num_samples == len(prompt)
    prng_seed = jax.random.PRNGKey(seed)

    # tokenize the prompts and shard them across devices
    prompt_ids = pipeline.prepare_inputs(prompt)
    prompt_ids = shard(prompt_ids)

    # replicate the parameters on every device and give each device its own rng key
    params = replicate(pipeline_params)
    prng_seed = jax.random.split(prng_seed, num_samples)

    output = pipeline(prompt_ids, params, prng_seed, num_inference_steps, guidance_scale=guidance_scale, jit=True)
    images = output.images
    # merge the device axis back into a flat batch of images
    images = np.asarray(images.reshape((num_samples,) + images.shape[-3:]))
    images = pipeline.numpy_to_pil(images)
    return images
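The shard and reshape calls in generate only move the batch dimension around. The NumPy sketch below shows how the shapes evolve, assuming a machine with 8 devices and 64x64 images for brevity. The function `shard_like` is a hypothetical re-implementation of the leading-axis reshape for illustration, not Flax's own code:

```python
import numpy as np


def shard_like(x, num_devices):
    """Split the leading batch axis into (num_devices, per_device_batch, ...)."""
    return x.reshape((num_devices, -1) + x.shape[1:])


num_devices = 8
prompt_ids = np.zeros((num_devices, 77), dtype=np.int32)  # one prompt per device
print(shard_like(prompt_ids, num_devices).shape)  # (8, 1, 77)

# The jitted pipeline returns images with the same leading device axes ...
images = np.zeros((num_devices, 1, 64, 64, 3), dtype=np.float32)
# ... which generate() folds back into a flat batch before converting to PIL.
flat = images.reshape((num_devices,) + images.shape[-3:])
print(flat.shape)  # (8, 64, 64, 3)
```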

Let's use our <cat-toy> concept in a prompt and see how it looks in the generated image.

%%time
prompt = "an oil painting of a <cat-toy> in a town by the river in Andalucia"
images = generate(prompt, seed=0)
CPU times: user 2min 22s, sys: 1.9 s, total: 2min 24s
Wall time: 2min 19s

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

image_grid(images, 1, 1)

As you can see, getting a prompt to generate exactly the image we want is hard. The result has the silhouette of a cat, but not the cat toy we expected. We need to try something different; in fact, some objects may need the concept at the beginning of the prompt, while styles may work better at the end.
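Because placement matters, one way to explore it systematically is to build a few variants of the same prompt that differ only in where the concept token sits. The helper `prompt_variants` below is hypothetical, and its templates are just examples:

```python
def prompt_variants(concept_token, scene):
    """Hypothetical helper: place the concept token at the start and at the end."""
    return [
        f"{concept_token}, {scene}",
        f"{scene}, {concept_token}",
        f"{scene} in the style of {concept_token}",
    ]


for p in prompt_variants("<cat-toy>", "an oil painting of a town by the river in Andalucia"):
    print(p)
```

Each variant could then be passed to generate with the same seed, so that only the wording changes between images.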

One simple thing we could try is to simplify the prompt as follows:

%%time
prompt = "<cat-toy> in Seville"
images = generate(prompt, seed=0)
CPU times: user 32.1 s, sys: 216 ms, total: 32.3 s
Wall time: 31.5 s
image_grid(images, 1, 1)

The result is somewhat better, as we get something close to our cat toy object, but the image looks boring. We probably need to keep refining the prompt; for inspiration, it helps to browse prompts used on lexica.art or playgroundai.com.

Now, let's try playing with other concepts from the Stable Diffusion Textual Inversion Concepts Library to generate images of different styles:

styles = ['birb-style', 'midjourney-style', 'style-of-marc-allante']

for style in styles:
    embeds_path, _ = download_embeddings(f"sd-concepts-library/{style}")
    token, embeds = load_embeds(embeds_path)
    tokenizer.add_tokens(token)
    add_token_embeddings(embeds, text_encoder_params)

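One pitfall with this loop is re-running the cell. Doing so adds nothing to the tokenizer, because tokenizer.add_tokens returns 0 for a token it already knows. However, add_token_embeddings would still append another row, leaving the embedding matrix out of step with the vocabulary. A guarded version appends only when a token was actually added. The sketch below uses a hypothetical dict-based vocabulary and a NumPy matrix so it runs on its own; `add_concept` is illustrative only:

```python
import numpy as np

vocab = {"<cat-toy>": 0}  # hypothetical vocabulary: token -> id
embedding_matrix = np.zeros((1, 4), dtype=np.float32)  # one row per id


def add_concept(token, vector):
    """Append a concept's row only if the token is new; return tokens added."""
    global embedding_matrix
    if token in vocab:
        return 0  # same signal as tokenizer.add_tokens for a known token
    vocab[token] = embedding_matrix.shape[0]
    embedding_matrix = np.append(embedding_matrix, vector.reshape(1, -1), axis=0)
    return 1


vec = np.ones(4, dtype=np.float32)
add_concept("<birb-style>", vec)
add_concept("<birb-style>", vec)  # simulated re-run of the cell: no-op
print(len(vocab), embedding_matrix.shape)  # 2 (2, 4)
```

In the notebook, the equivalent guard would be `if tokenizer.add_tokens(token) > 0: add_token_embeddings(embeds, text_encoder_params)`.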
%%time
images = []
prompts = [f"an oil painting of a town by the river in Andalucia in the style of <{style}>" for style in styles]
for prompt in prompts:
    images = images + generate(prompt, seed=0)
CPU times: user 2min 25s, sys: 1.27 s, total: 2min 26s
Wall time: 2min 20s
image_grid(images, 1, 3)

That's all folks

Stable Diffusion is a very neat model: it lets us generate fantastic pictures from a simple text prompt. It is also flexible, and it can be customized to produce specific subjects or styles instead of whatever the prompt alone yields. In this article, we saw one way of conditioning the model output, called Textual Inversion. We also saw that this technique is easy to implement in Flax with the diffusers library, while leveraging the many public concepts available in Hugging Face's Concepts Library.

I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.