Textual Inversion was first introduced in An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion (see the project website). The technique learns new embeddings that represent user-provided visual concepts (e.g. images of an object, or hand drawings). These embeddings are then linked to new pseudo-words (the paper uses the notation $S_{*}$) which can be incorporated into regular prompts. Examples of prompts that include a new concept:
- "A carpet with $S{*}$ embroidery"_
- "A stained glass window depicting $S{*}$"_
- "Painting of $S{*}$ in the style of Monet"_
Surprisingly, training requires only 3 to 5 example images to teach models like Stable Diffusion to use the new concept for personalized image generation.
In this article, we will adapt the FlaxStableDiffusionPipeline
to use new concepts when generating images from text. The concepts we will use are downloaded from the Stable Diffusion Textual Inversion Concepts Library, which hosts a large number of publicly available concepts. Feel free to browse through this library and choose other concepts to play with.
First, we need to install some libraries, including diffusers
and Flax
.
%%capture
%%bash
pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers
Note: if you are running on a TPU, you do not need jax[cuda]
and can instead just connect JAX to the TPU machine as described in this article.
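As a rough sketch (assuming a Colab TPU runtime and a JAX version that still ships the colab_tpu helper), connecting JAX to the TPU looks like this:
# connect JAX to the Colab TPU runtime instead of installing jax[cuda]
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

import jax
print(jax.devices())  # should list 8 TpuDevice entries on a TPU v2/v3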
Log into your Hugging Face account with an access token and accept the terms of the license for this model (see the model card).
!huggingface-cli login
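Alternatively, if you prefer not to leave the notebook, huggingface_hub provides a notebook_login helper (a minimal sketch):
# programmatic alternative to the CLI login
from huggingface_hub import notebook_login
notebook_login()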
Import the libraries we will be using.
import os
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
%matplotlib inline
Now, we can instantiate a Flax Stable Diffusion pipeline and load the weights in the appropriate precision (bfloat16).
from diffusers import FlaxStableDiffusionPipeline

dtype = jax.numpy.bfloat16
model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16"  # alternatively "flax" for the full-precision weights

pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
    model_id, revision=revision, dtype=dtype
)
To introduce new concepts to our pipeline, we need to add new tokens and link them to their respective embeddings. This is done by acting on the pipeline's tokenizer
and text_encoder
.
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
At this stage, the text encoder parameters (including the token embedding matrix) are stored in pipeline_params
and will be passed to the text_encoder
by the pipeline at generation time.
text_encoder_params = pipeline_params["text_encoder"]
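If you are curious, you can peek at the token embedding table inside this parameter tree; the nested keys below are the same ones we will use later when extending the matrix (the concrete sizes in the comment are what I would expect for this model, not guaranteed):
# the raw token embedding table of the CLIP text encoder
embedding_table = text_encoder_params["text_model"]["embeddings"]["token_embedding"]["embedding"]
print(embedding_table.shape)  # (vocab_size, hidden_size), e.g. (49408, 768) for SD v1-4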
All of the concept embeddings we will be using here come from Hugging Face's Concepts Library, so let's define a helper function that downloads the necessary files for a given concept.
def download_embeddings(repo_id):
    embeds_path = hf_hub_download(repo_id=repo_id, filename="learned_embeds.bin")
    token_path = hf_hub_download(repo_id=repo_id, filename="token_identifier.txt")
    return embeds_path, token_path
Let's download the embeddings for the <cat-toy>
concept.
repo_id_cat_toy = "sd-concepts-library/cat-toy"
embeds_path, token_path = download_embeddings(repo_id_cat_toy)
Because the original embeddings were saved with PyTorch, we cannot load them directly into a JAX array; a conversion step is needed. A simple way to convert the embeddings to JAX is to:
- load the embeddings with PyTorch,
- convert them to a NumPy array,
- read the NumPy array with JAX.
The following helper function does just that:
import torch

def load_embeds(embeds_path):
    # the file holds a single entry mapping the trained token to its embedding tensor
    pytorch_params = torch.load(embeds_path, map_location='cpu')
    trained_token = list(pytorch_params.keys())[0]
    pytorch_embeds = pytorch_params[trained_token]
    # PyTorch tensor -> NumPy array -> JAX array, cast to the pipeline dtype
    embeds = jnp.asarray(pytorch_embeds.numpy()).astype(dtype)
    return trained_token, embeds
Let's load the embeddings
token, embeds = load_embeds(embeds_path)
First, let's check the current size of the tokenizer vocabulary before adding our new token
len(tokenizer)
To add a token we simply do
tokenizer.add_tokens(token)
Now let's confirm that the vocabulary size increased by one after we added our new token
len(tokenizer)
We can get the ID associated with our new token just to confirm everything is fine so far
tokenizer.convert_tokens_to_ids(token)
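Since freshly added tokens are appended at the end of the vocabulary, the returned ID should simply be the new vocabulary size minus one (a quick sanity check, assuming nothing else was added to the vocabulary in between):
# newly added tokens are appended at the end of the vocabulary
assert tokenizer.convert_tokens_to_ids(token) == len(tokenizer) - 1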
We are not done yet with adding our concept. After adding new tokens to the vocabulary, we also need to resize the token embedding matrix of the text_encoder
model to match the new vocabulary size.
The following helper function:
- updates the vocabulary size in the text_encoder model configuration, and
- appends the embeddings of the concept to the text_encoder's token embedding matrix.
def add_token_embeddings(token_embeds, text_encoder_params):
if len(token_embeds.shape) == 1:
token_embeds = jnp.expand_dims(token_embeds, axis=0)
# update vocab size
text_encoder._config.vocab_size = len(tokenizer)
    # retrieve the token embeddings from the encoder parameters
text_model_embeds = text_encoder_params['text_model']['embeddings']
token_embedding = text_model_embeds['token_embedding']
text_encoder_embeds = token_embedding['embedding']
# append the new embeddings
new_text_encoder_embeds = jnp.append(text_encoder_embeds, token_embeds, axis=0)
token_embedding['embedding'] = new_text_encoder_embeds
Let's add the embeddings for the <cat-toy>
concept
add_token_embeddings(embeds, text_encoder_params)
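Optionally, we can check that the in-place update took effect (this assumes the parameter tree consists of plain Python dicts, so the mutation inside the helper is visible from the outside):
# the embedding matrix should now have as many rows as the enlarged vocabulary
new_table = text_encoder_params["text_model"]["embeddings"]["token_embedding"]["embedding"]
print(new_table.shape[0], len(tokenizer))  # both values should match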
To make sure nothing is broken, let's check that we can successfully tokenize and encode a text that contains our new concept.
text_input = tokenizer(
token,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
prompt_ids = text_input.input_ids
text_embeddings = text_encoder(prompt_ids, params=text_encoder_params)[0]
text_embeddings.shape
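The shape should be (1, 77, 768) here: one prompt, padded to the tokenizer's maximum length of 77 tokens, each encoded as a 768-dimensional vector (these sizes assume the default CLIP text encoder of Stable Diffusion v1-4):
# 1 prompt x 77 tokens x 768-dimensional hidden states (for SD v1-4)
assert text_embeddings.shape == (1, tokenizer.model_max_length, 768)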
Generating images
The overall logic for using the pipeline to generate images from prompts is pretty much the same as the one we used in the introductory example, regardless of whether the prompt contains the new concept or not.
We first use pipeline.prepare_inputs
to convert the text prompt into token IDs. The pipeline then computes their embeddings, samples an initial latent representation from random noise, and denoises it over num_inference_steps
steps. Finally, we use pipeline.numpy_to_pil
to convert the resulting JAX array into actual PIL images.
guidance_scale = 8 #@param {type:"slider", min:0, max:100, step:0.5}
num_inference_steps = 30 #@param
def generate(prompt, seed=0, num_inference_steps=50, guidance_scale=7.5):
num_samples = jax.device_count()
if not isinstance(prompt, list):
prompt = num_samples * [prompt]
else:
assert num_samples == len(prompt)
prng_seed = jax.random.PRNGKey(seed)
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids = shard(prompt_ids)
# shard inputs and rng
params = replicate(pipeline_params)
prng_seed = jax.random.split(prng_seed, num_samples)
output = pipeline(prompt_ids, params, prng_seed, num_inference_steps, guidance_scale=guidance_scale, jit=True)
images = output.images
images = np.asarray(images.reshape((num_samples,) + images.shape[-3:]))
images = pipeline.numpy_to_pil(images)
return images
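Since generate produces one image per accelerator device, it is worth checking how many devices JAX can see before running it (the counts in the comment are typical values, not guarantees):
# one image is generated per device: 1 on a single GPU, typically 8 on a Colab TPU
print(jax.device_count())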
Let's use our <cat-toy>
concept in a prompt and see how it looks in the generated image.
%%time
prompt = "an oil painting of a <cat-toy> in a town by the river in Andalucia"
images = generate(prompt, seed=0)
def image_grid(imgs, rows, cols):
assert len(imgs) == rows*cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
image_grid(images, 1, 1)
As you can see, crafting a prompt that generates exactly the image we want is hard. The result has the silhouette of a cat, but not the cat toy we expected. We have to try something different; in fact, some object concepts work better at the beginning of the prompt, while style concepts often work better at the end.
One simple thing we could try is to simplify the prompt as follows:
%%time
prompt = "<cat-toy> in Seville"
images = generate(prompt, seed=0)
image_grid(images, 1, 1)
The result is somewhat better, as we get something close to our cat toy object, but the image looks boring. We probably need to keep tweaking the prompt; for inspiration, it is good to browse the prompts used on lexica.art or playgroundai.com.
Now, let's try playing with other concepts from the Stable Diffusion Textual Inversion Concepts Library to generate images of different styles:
styles = ['birb-style', 'midjourney-style', 'style-of-marc-allante']
for style in styles:
    embeds_path, token_path = download_embeddings(f'sd-concepts-library/{style}')
    token, embeds = load_embeds(embeds_path)
    tokenizer.add_tokens(token)
    add_token_embeddings(embeds, text_encoder_params)
%%time
images = []
prompts = [f"an oil painting of a town by the river in Andalucia in the style of <{style}>" for style in styles]
for prompt in prompts:
images = images + generate(prompt, seed=0)
image_grid(images, 1, 3)
That's all folks
Stable Diffusion is a very neat model: it allows us to generate fantastic pictures from a simple text prompt. It is also very flexible and can be customized to generate specific, personalized concepts rather than generic images. In this article, we saw one way of conditioning the model output, called Textual Inversion, and we saw that this technique is easy to implement in Flax with the diffusers
library, leveraging the many public concepts available in Hugging Face's Concepts Library.
I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.