
Stable Diffusion is a powerful text-to-image model. Its success led many websites and tools to provide easy access to it so that anyone can use it to generate images from text. It is also integrated into the Hugging Face diffusers library, where generating images in Python can be as simple as:

import jax
from diffusers import FlaxStableDiffusionPipeline

pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
)

# seed is any integer of your choice and prompt is your text description
prng_seed = jax.random.split(jax.random.PRNGKey(seed), 1)
prompt_ids = pipeline.prepare_inputs([prompt])
images = pipeline(prompt_ids, pipeline_params, prng_seed).images
images = pipeline.numpy_to_pil(images)

You can refer to this article for a complete walk-through of how to use the FlaxStableDiffusionPipeline API to generate images from your prompt.

In Part I we will dig into the actual code behind such an easy-to-use API to better understand the Stable Diffusion model. First, we will re-create the functionality of FlaxStableDiffusionPipeline step by step, and then we will inspect its main components.

By the end of this notebook we will have a good understanding of how Stable Diffusion works and be able to tweak and modify its inner workings.

Introduction

Stable Diffusion is trained to remove noise from an image. By repeating this process multiple times, it is able to generate images of great quality. How much noise is removed in every step depends on the input textual description. This helps guide the model toward generating an image that matches the input description.

This is how Stable Diffusion generates an image from text:

  1. The textual description of the target image is passed through a text encoder to generate the text embeddings
  2. A random noisy image latent (think of it as a compressed image representation of shape 64 x 64) is generated as a starting point
  3. The noisy latent is passed to a U-Net model along with the text embeddings
  4. The U-Net predicts a denoised image latent conditioned on the text embeddings
  5. The predicted denoised latent is passed through a Scheduler, which adds back a little bit of noise based on the current step number
  6. The output of the scheduler is used as input for the next denoising step
  7. The denoising loop is repeated for a number of steps
  8. The final image latent is passed to a Decoder, which generates the final image with shape 512 x 512.

[Diagram: the components of Stable Diffusion combined to generate an image from text]

The above diagram illustrates the different components of Stable Diffusion and how they are combined to generate an image from text. In the remainder of this article, we will dive deeper into the implementation details.
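
To make the data flow concrete, here is a minimal sketch of the tensor shapes involved (the shapes are for the v1-4 checkpoint with a batch size of 1; the arrays below are just placeholders, the real components are wired up in the rest of the article):

import jax.numpy as jnp

text_embeddings = jnp.zeros((1, 77, 768))  # step 1: text encoder output (77 tokens, 768-dim embeddings)
latents = jnp.zeros((1, 4, 64, 64))        # step 2: random image latent used as the starting point
# steps 3-7: the U-Net + scheduler denoising loop keeps the latent shape unchanged
image = jnp.zeros((1, 3, 512, 512))        # step 8: decoded output image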

Setup and Imports

First, we need to install some libraries, including diffusers and Flax.

%%capture
%%bash

pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers
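
Optionally, we can run a quick sanity check that JAX can see the GPU (this step is not required):

import jax
jax.devices()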

Log in to your Hugging Face account with an access token and accept the terms of the license for this model (see the model card).

!huggingface-cli login
    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` now requires a token generated from https://huggingface.co/settings/tokens .
    
Token: 
Add token as git credential? (Y/n) n
Token is valid.
Your token has been saved to /root/.huggingface/token
Login successful
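
Alternatively, if you are running inside a notebook, you can log in programmatically using huggingface_hub (installed as a dependency of diffusers):

from huggingface_hub import notebook_login
notebook_login()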

Import the libraries we will use, e.g. matplotlib for plotting the resulting images.

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxPNDMScheduler

Model and weights

This will download and set up the relevant models and components we'll be using. Let's just run this for now and move on to the next section to check that it all works before diving deeper.

Let's first instantiate the different components of Stable Diffusion, which are represented by the following classes:

  • CLIPTokenizer used to transform the text prompt into token IDs.
  • FlaxCLIPTextModel used to encode token IDs into the corresponding embeddings.
  • FlaxAutoencoderKL, a Variational Autoencoder used for encoding/decoding images to/from a latent representation.
  • FlaxUNet2DConditionModel, a conditional U-Net model used to denoise the latent representation of an image.
  • FlaxPNDMScheduler used to compute, at each step of the denoising loop, the slightly less noisy latent from the U-Net's noise prediction.

We will download and set up the previous components using checkpoints of Stable Diffusion from CompVis/stable-diffusion-v1-4.

Note: We will use half-precision (i.e. revision bf16) for a faster weights download, to save on memory, and to avoid issues like slower predictions or OOM errors. If you want to use float32 precision instead, use revision flax.

dtype = jax.numpy.bfloat16  # matches the "bf16" revision below
model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16" # "flax"
tokenizer = CLIPTokenizer.from_pretrained(model_id, revision=revision, subfolder="tokenizer", dtype=dtype)
text_encoder = FlaxCLIPTextModel.from_pretrained(model_id, revision=revision, subfolder="text_encoder", dtype=dtype)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(model_id, revision=revision, subfolder="vae", dtype=dtype)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(model_id, revision=revision, subfolder="unet", dtype=dtype)
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(model_id, revision=revision, subfolder="scheduler")
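
Before moving on, it can be useful to peek at a few configuration values we will rely on later. The values in the comments are what you should see with the v1-4 checkpoint:

print(tokenizer.model_max_length)     # 77
print(unet.config.sample_size)        # 64
print(unet.config.in_channels)        # 4
print(vae.config.block_out_channels)  # [128, 256, 512, 512]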

Note: Instead of individually loading those components we could also load a pipeline and then access them using pipe.unet, pipe.vae and so on. Something like this:

from diffusers import FlaxStableDiffusionPipeline

pipe, pipe_params = FlaxStableDiffusionPipeline.from_pretrained(model_id, revision=revision, dtype=jax.numpy.bfloat16)
unet, vae = pipe.unet, pipe.vae

Diffusion Loop

First, let's initialize the random seed so we can reproduce the results.

seed = 123
num_samples = jax.device_count()
prng_seed = jax.random.PRNGKey(seed)
prng_seed = jax.random.split(prng_seed, num_samples)
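
As a quick aside, jax.random.split turns one PRNG key into several independent keys. With the classic uint32 key representation, each key consists of two uint32 values, so splitting a key into n pieces gives an array of shape (n, 2):

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)
print(keys.shape)  # (4, 2)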

The following are two important parameters:

  • num_inference_steps corresponds to the number of steps to take in the diffusion loop
  • guidance_scale regulates how strongly the noise prediction is pushed toward the text prompt: the final estimate is computed as noise_uncond + guidance_scale * (noise_text - noise_uncond) (classifier-free guidance)

guidance_scale = 7.5 #@param {type:"slider", min:0, max:100, step:0.5}
num_inference_steps = 30 #@param

Choose a text prompt that describes the target image

prompt = "a photo of a car orbiting earth in van Gogh style"
text_encoder_params = None

Initialize some parameters

# the VAE downsamples images by a factor of 2 for each of its downsampling blocks
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
# target image resolution: the U-Net operates on latents of size sample_size x sample_size
height = unet.config.sample_size * vae_scale_factor
width = unet.config.sample_size * vae_scale_factor

Now, we need to use the tokenizer and text_encoder to calculate the embeddings for the input text prompt.

%%time

text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="np",
)
prompt_ids = text_input.input_ids

text_embeddings = text_encoder(prompt_ids, params=text_encoder_params)[0]

batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]

We also need to calculate the embeddings for the empty (unconditional) prompt. Its noise prediction will later be combined with the text-conditioned one using the guidance_scale defined above.

uncond_input = tokenizer(
  [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids

uncond_embeddings = text_encoder(uncond_input, params=text_encoder_params)[0]

Here we concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes

context = jnp.concatenate([uncond_embeddings, text_embeddings])
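
At this point it is worth checking the shapes. With the v1-4 text encoder, both embeddings should have shape (1, 77, 768), so the concatenated context has a batch dimension of 2 (unconditional first, then conditional):

print(text_embeddings.shape)    # (1, 77, 768)
print(uncond_embeddings.shape)  # (1, 77, 768)
print(context.shape)            # (2, 77, 768)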

Now we define a diffusion step where we use the unet component to predict the noise conditioned on the prompt embeddings. Then, we combine the conditional and unconditional noise predictions using guidance_scale.

def diffusion_step(step, args):
    latents, scheduler_state = args
    
    latents_input = jnp.concatenate([latents] * 2)
    
    t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
    timestep = jnp.broadcast_to(t, latents_input.shape[0])
    
    latents_input = scheduler.scale_model_input(scheduler_state, latents_input, t)
    
    # predict the noise residual
    noise_pred = unet.apply(
        {"params": unet_params},
        jnp.array(latents_input),
        jnp.array(timestep, dtype=jnp.int32),
        encoder_hidden_states=context,
    ).sample
    # perform guidance
    noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    
    # compute the previous noisy sample x_t -> x_t-1
    latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
    
    return latents, scheduler_state
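
Note that diffusion_step follows the (index, carry) -> carry signature expected by jax.lax.fori_loop, which we will use below to run the denoising loop. Here is a tiny, self-contained illustration of that signature:

# fori_loop(lower, upper, body, init) repeatedly applies body(i, carry) and returns the final carry
total = jax.lax.fori_loop(0, 5, lambda i, acc: acc + i, 0)
print(total)  # 0 + 1 + 2 + 3 + 4 = 10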

The following defines the shape of the starting latents, which are randomly sampled from a normal distribution.

latents_shape = (
    batch_size,
    unet.in_channels,
    height // vae_scale_factor,
    width // vae_scale_factor,
)

# use a single key here since we are generating one sample (not sharding across devices)
latents = jax.random.normal(prng_seed[0], shape=latents_shape, dtype=jnp.float32)

Here is the diffusion loop where we will call the previously defined diffusion_step as many times as num_inference_steps.

# set the timesteps based on the number of steps
scheduler_state = scheduler.set_timesteps(
    scheduler_params, num_inference_steps=num_inference_steps, shape=latents.shape
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler.init_noise_sigma

latents, _ = jax.lax.fori_loop(0, num_inference_steps, diffusion_step, (latents, scheduler_state))
latents.shape
(1, 4, 64, 64)

Let's define the following helper function to convert the model output latent embeddings into an actual PIL image for plotting.

def latents_to_pil(latents):
    # scale and decode the image latents with the VAE
    # (0.18215 is the latent scaling factor used when training Stable Diffusion)
    latents = 1 / 0.18215 * latents
    images = vae.apply({"params": vae_params}, latents, method=vae.decode).sample
    # convert JAX to numpy
    images = (images / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
    images = np.asarray(images)
    # convert numpy array to PIL
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images

Now we use the Autoencoder (through the previous helper function) to generate an image from its latent representation

%%time

images = latents_to_pil(latents)
CPU times: user 1min 4s, sys: 1.44 s, total: 1min 6s
Wall time: 38.4 s

Finally, we can check the resulting image.

images[0]
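
Since we imported matplotlib earlier, we could also display the image with it, or save it to disk with PIL (the filename below is just an example):

plt.imshow(images[0])
plt.axis("off")
plt.show()

images[0].save("car_orbiting_earth.png")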

That's all folks

Stable Diffusion is a complex model that comprises many components and uses clever tricks so it can be trained on relatively cheap hardware compared to other diffusion models (e.g. Google's Imagen or OpenAI's DALL-E 2).

In this article, we dived deep into the Flax implementation of the Stable Diffusion model to better understand how it works.

I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.