Stable Diffusion is a powerful model that can be used to generate an image from a random text prompt. Hugging Face's diffusers library provides access to such models in a very convenient way.

In this article we will use this library to load the Flax-based Stable Diffusion model and its weights to generate a sequence of interpolated images.

Setup and Imports

First, we need to install some libraries, including diffusers and JAX/Flax.

Note: we will use JAX with the GPU backend; alternatively, you can use the TPU backend by initializing JAX like this:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

%%capture
%%bash

pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers
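
After the installation, a quick sanity check (not part of the original notebook) confirms which backend JAX picked up; on a GPU runtime you should see at least one GPU device listed.

import jax
print(jax.default_backend())  # expected "gpu" on a GPU runtime, "tpu" on TPU
print(jax.devices())          # lists the available accelerator devices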

To be able to use Stable Diffusion we need to log in to Hugging Face and accept the terms of the license for this model.

Note: to log in you will need a token, which you can get from the settings page of your Hugging Face account.

!huggingface-cli login
    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` now requires a token generated from https://huggingface.co/settings/tokens .
    
Token: 
Add token as git credential? (Y/n) n
Token is valid.
Your token has been saved to /root/.huggingface/token
Login successful
import os
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

Generating many images with Stable Diffusion in Colab can take time, so to preserve progress in case we have to restart, we mount Google Drive (or another persistent storage).

from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
Mounted at /content/drive/
PATH = '/content/drive/My Drive/Colab Notebooks/diffusion'

Model and components

Now, we can instantiate the Flax-based Stable Diffusion components and load the proper weights.

from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxPNDMScheduler
dtype = jnp.bfloat16  # matches the bf16 checkpoint revision used below

We will use the Stable Diffusion checkpoints from CompVis/stable-diffusion-v1-4 in half precision (i.e. the bf16 revision) to save memory and avoid out-of-memory (OOM) errors.

Note: if you want to use float32 precision, set revision to "flax" instead (and dtype to jnp.float32).

model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16" # "flax"

To have more control over the diffusion loop, we will not use FlaxStableDiffusionPipeline as is, but instead instantiate the components of the model individually.

tokenizer = CLIPTokenizer.from_pretrained(model_id, revision=revision, subfolder="tokenizer", dtype=dtype)
text_encoder = FlaxCLIPTextModel.from_pretrained(model_id, revision=revision, subfolder="text_encoder", dtype=dtype)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(model_id, revision=revision, subfolder="vae", dtype=dtype)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(model_id, revision=revision, subfolder="unet", dtype=dtype)
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(model_id, revision=revision, subfolder="scheduler")

Initialize the random seed for JAX.

seed = 123
num_samples = jax.device_count()
prng_seed = jax.random.PRNGKey(seed)
prng_seed = jax.random.split(prng_seed, num_samples)

Set the number of steps in the diffusion loop that are needed to generate one image, as well as the guidance factor.

You can learn more about the effects of those parameters in this article.

guidance_scale = 8 #@param {type:"slider", min:0, max:100, step:0.5}
num_inference_steps = 30 #@param
text_encoder_params = None

Helper functions

We need to convert the input text prompt into tokens and then into embeddings, as expected by the Stable Diffusion model.

def embed(prompt):
    # prepare_inputs
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="np",
        )
    prompt_ids = text_input.input_ids

    # get prompt text embeddings
    text_embeddings = text_encoder(prompt_ids, params=text_encoder_params)[0]

    batch_size = prompt_ids.shape[0]

    max_length = prompt_ids.shape[-1]

    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
    ).input_ids

    uncond_embeddings = text_encoder(uncond_input, params=text_encoder_params)[0]
    context = jnp.concatenate([uncond_embeddings, text_embeddings])
    return context
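
As a quick illustration (not part of the original notebook), calling embed on a single prompt returns the unconditional and conditional embeddings stacked along the batch axis; for Stable Diffusion v1.4 the CLIP text encoder pads prompts to 77 tokens and uses 768-dimensional hidden states.

# for one prompt, the context should have shape (2, 77, 768):
# row 0 is the unconditional (empty prompt) embedding, row 1 the conditional one
context = embed("a photo of a cat")
print(context.shape)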

The following function defines one diffusion step:

  1. Predict the noise residual for both the conditional and unconditional (empty prompt) embeddings
  2. Combine the two predictions using the guidance factor (classifier-free guidance)
  3. Pass the result to the scheduler to compute the latents of the previous, slightly less noisy step

def diffusion_step(step, args):
    latents, context, scheduler_state = args
    latents_input = jnp.concatenate([latents] * 2)
    
    t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
    timestep = jnp.broadcast_to(t, latents_input.shape[0])
    
    latents_input = scheduler.scale_model_input(scheduler_state, latents_input, t)
    
    # predict the noise residual
    noise_pred = unet.apply(
        {"params": unet_params},
        jnp.array(latents_input),
        jnp.array(timestep, dtype=jnp.int32),
        encoder_hidden_states=context,
        ).sample
    # perform guidance
    noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

    # compute the previous noisy sample x_t -> x_t-1
    latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

    return latents, context, scheduler_state

Because we will generate many images, for simplicity and reuse we group the diffusion logic in the following helper function:

def diffuse(context, init_noise):
    latents = init_noise
    # set the timesteps based on the number of inference steps
    scheduler_state = scheduler.set_timesteps(
        scheduler_params, num_inference_steps=num_inference_steps, shape=latents.shape
        )
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * scheduler.init_noise_sigma

    latents, _, _ = jax.lax.fori_loop(0, num_inference_steps, diffusion_step, (latents, context, scheduler_state))
    return latents

Set some global parameters, like the expected shape of the latents.

vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

height = unet.config.sample_size * vae_scale_factor
width = unet.config.sample_size * vae_scale_factor

batch_size = 1

latents_shape = (
    batch_size,
    unet.in_channels,
    height // vae_scale_factor,
    width // vae_scale_factor,
)
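
For Stable Diffusion v1.4, this should work out to 512x512 output images and latents of shape (1, 4, 64, 64); a quick, illustrative check makes the values explicit.

print(height, width)   # expected: 512 512
print(latents_shape)   # expected: (1, 4, 64, 64)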

The output of Stable Diffusion is a latent representation; the following function decodes it with the VAE and converts it into a PIL image.

def latents_to_pil(latents):
    # scale and decode the image latents with vae
    latents = 1 / 0.18215 * latents
    images = vae.apply({"params": vae_params}, latents, method=vae.decode).sample
    # convert JAX to numpy
    images = (images / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
    images = np.asarray(images)
    # convert numpy array to PIL
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images

Interpolation

The following illustration explains the interpolation approach:

  1. We generate the embeddings for the prompts of the source (e.g. a photo of a model t) and target (e.g. a photo of a mustang) images
  2. We interpolate many new embeddings between the previous two
  3. For each of those embeddings we generate the corresponding image
  4. We group the resulting images to generate a video or gif

(Figure: the Stable Diffusion interpolation approach)

First, generate the embeddings of the source and target prompts

%%time

prompt_1 = "a realistic photo of a ford model t"
text_embed_1 = embed(prompt_1)
noise_1 = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
latent_1 = diffuse(text_embed_1, noise_1)

prompt_2 = "a realistic photo of a ford mustang"
text_embed_2 = embed(prompt_2)
noise_2 = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
latent_2 = diffuse(text_embed_2, noise_2)
CPU times: user 1h 28min 53s, sys: 1min 24s, total: 1h 30min 17s
Wall time: 47min 42s

The outputs for the previous prompts may not match our expectations, and more tweaking may be needed to reach good-looking initial images. So before going further, let's decode the latents and make sure the resulting images look good.

image1 = latents_to_pil(latent_1)[0]
image2 = latents_to_pil(latent_2)[0]
f = plt.figure(figsize=(20,10))
images, prompts = [image1, image2], [prompt_1, prompt_2]
for i, image in enumerate(images):
    sp = f.add_subplot(1, 2, i + 1)
    sp.axis('off')
    sp.set_title(prompts[i], fontsize=16)
    plt.imshow(image)

To generate the $t^{th}$ embedding vector we use Spherical Linear Interpolation (SLERP), which is defined by the following formula:

$$\operatorname{Slerp}(p_0,p_1; t) = \frac{\sin {[(1-t)\theta}]}{\sin \theta} p_0 + \frac{\sin [t\theta]}{\sin \theta} p_1.$$

Unlike Linear Interpolation (LERP), SLERP produces interpolated embeddings that are spaced uniformly along the arc between the two endpoints.
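
For comparison, LERP is just a weighted average, which pulls the intermediate vectors towards the origin instead of keeping them on that arc:

$$\operatorname{Lerp}(p_0,p_1; t) = (1-t)\,p_0 + t\,p_1.$$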

The following function implements SLERP in JAX:

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """ helper function to spherically interpolate two arrays v1 v2 """

    dot = jnp.sum(v0 * v1 / (jnp.linalg.norm(v0) * jnp.linalg.norm(v1)))
    if jnp.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = jnp.arccos(dot)
        sin_theta_0 = jnp.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = jnp.sin(theta_t)
        s0 = jnp.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    return v2
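
As a quick sanity check (an illustrative sketch, not part of the original notebook), we can compare the norm of a LERP midpoint with that of a SLERP midpoint for two random vectors of similar magnitude: SLERP keeps the interpolant's norm close to the endpoints', while LERP shrinks it.

# two random vectors with roughly equal norms
v0 = jax.random.normal(jax.random.PRNGKey(0), (8,))
v1 = jax.random.normal(jax.random.PRNGKey(1), (8,))
print(jnp.linalg.norm(0.5 * v0 + 0.5 * v1))  # LERP midpoint: noticeably smaller norm
print(jnp.linalg.norm(slerp(0.5, v0, v1)))   # SLERP midpoint: norm close to the endpoints'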

Because generating many images with Stable Diffusion takes time, let's save our progress so far so that we can skip all the previous steps if we need to restart this notebook.

!mkdir -p '{PATH}/cars'

Save the embeddings, initial noise, and latents:

jnp.save(f'{PATH}/cars/text_embed_1.npy', text_embed_1)
jnp.save(f'{PATH}/cars/text_embed_2.npy', text_embed_2)
jnp.save(f'{PATH}/cars/noise_1.npy', noise_1)
jnp.save(f'{PATH}/cars/noise_2.npy', noise_2)
jnp.save(f'{PATH}/cars/latent_1.npy', latent_1)
jnp.save(f'{PATH}/cars/latent_2.npy', latent_2)

Load the embeddings and noise back (e.g. after a restart):

text_embed_1 = jnp.load(f'{PATH}/cars/text_embed_1.npy')
text_embed_2 = jnp.load(f'{PATH}/cars/text_embed_2.npy')
noise_1 = jnp.load(f'{PATH}/cars/noise_1.npy')
noise_2 = jnp.load(f'{PATH}/cars/noise_2.npy')

The number of interpolation steps is an important parameter: the right value depends on how visually different the source and target images are. You may need many more steps, or fewer; let's pick a value.

num_steps = 50 # @param interpolation step

Let's generate num_steps values for the SLERP parameter $t$.

steps = jnp.linspace(0.0, 1.0, num_steps).tolist()

For each $t$, we interpolate both the text embeddings and the initial noise with SLERP, then run the diffusion loop to generate the corresponding image.

for i, t in enumerate(tqdm(steps)):
  outpath = f'{PATH}/cars/frame{i:06}.jpeg'
  if os.path.exists(outpath):
    continue
  cond_embedding = slerp(t, text_embed_1, text_embed_2)
  initial_noise = slerp(t, noise_1, noise_2)
  latents = diffuse(cond_embedding, initial_noise)
  image = latents_to_pil(latents)[0]
  image.save(outpath)

After generating all the intermediate images, we can group them into a video or GIF as follows.

!ffmpeg -r 10 -f image2 -s 512x512 -i '{PATH}'/cars/frame%06d.jpeg -vcodec libx264 -crf 10 -pix_fmt yuv420p evolution_cars.mp4
!ffmpeg -f image2 -framerate 1 -i '{PATH}'/cars/frame%06d.jpeg -loop -1 evolution_cars.gif
!ls -alt evolution_cars.*
-rw-r--r-- 1 root root 3871184 Dec 14 10:19 evolution_cars.gif
-rw-r--r-- 1 root root 2466461 Dec 14 10:19 evolution_cars.mp4
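
To preview the result directly inside the notebook (an optional sketch, not in the original), you can embed the generated GIF with IPython:

from IPython.display import Image as IPyImage
IPyImage(filename='evolution_cars.gif')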

The resulting interpolation is illustrated in the video at the top of this article.

That's all folks

Stable Diffusion is a powerful model with many applications. In this article we saw how to combine it with interpolation techniques to generate a compelling sequence of images.

I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.