
In Part I we re-created the functionality of FlaxStableDiffusionPipeline step by step and gained a better understanding of the inner workings of the diffusion loop. In Part II, we will inspect each of the main components of Stable Diffusion.

Setup and Imports

First, we need to install some libraries, including diffusers and Flax.

%%capture
%%bash

pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers

Log into your Hugging Face account with an access token and accept the terms of the licence for this model (see the model card).

!huggingface-cli login

    To login, `huggingface_hub` now requires a token generated from https://huggingface.co/settings/tokens .
    
Token: 
Add token as git credential? (Y/n) n
Token is valid.
Your token has been saved to /root/.huggingface/token
Login successful

Import the libraries we will use, e.g. matplotlib for plotting the resulting images.

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxPNDMScheduler

Model and weights

Refer to Part I for more details on the model and the checkpoint we are using. Let's just rerun those cells to instantiate the components.

dtype = jnp.bfloat16  # match the "bf16" weights revision below
model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16" # "flax"
tokenizer = CLIPTokenizer.from_pretrained(model_id, revision=revision, subfolder="tokenizer", dtype=dtype)
text_encoder = FlaxCLIPTextModel.from_pretrained(model_id, revision=revision, subfolder="text_encoder", dtype=dtype)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(model_id, revision=revision, subfolder="vae", dtype=dtype)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(model_id, revision=revision, subfolder="unet", dtype=dtype)
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(model_id, revision=revision, subfolder="scheduler")
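
As a quick sanity check (not part of the original walkthrough), we can count the parameters of each module we just loaded; count_params below is just a throwaway helper that flattens each parameter pytree with jax.tree_util.

# Count the parameters of each loaded Flax module by flattening its pytree of arrays
def count_params(params):
    return sum(x.size for x in jax.tree_util.tree_leaves(params))

print(f"UNet parameters:         {count_params(unet_params):,}")
print(f"VAE parameters:          {count_params(vae_params):,}")
print(f"Text encoder parameters: {count_params(text_encoder.params):,}")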

Components of Stable Diffusion

The following diagram illustrates the different components of Stable Diffusion and how they are combined to generate an image from text. For details on how those components work together to form the diffusion loop, refer to Part I.

In this section we will play with some of those components to better understand them.

[Diagram: the components of Stable Diffusion and how they are combined to generate an image from text]

First, let's initialize the random seed so we can reproduce the results.

seed = 123
# a single PRNG key is enough here since we run these components on one device
prng_seed = jax.random.PRNGKey(seed)

The Text Encoder

The first component of the Stable Diffusion model is the Text Encoder that turns the prompt text into embeddings.

prompt = "a photo of a car orbiting earth in van Gogh style"

The tokenizer turns the prompt into token IDs.

token_ids = tokenizer.encode(prompt)
token_ids
[49406, 320, 1125, 539, 320, 1615, 523, 23016, 3475, 530, 2451, 19697, 1844, 49407]

Some tokens have special meaning, for instance:

  • <|startoftext|> with ID 49406 is added to the beginning of the prompt
  • <|endoftext|> with ID 49407 is added to the end of the prompt

tokenizer.decode(49406), tokenizer.decode(49407)
('<|startoftext|>', '<|endoftext|>')
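
If you are curious how the prompt was split into word pieces, the standard tokenizer API can map each ID back to its string; a small extra check, not in the original flow:

# Map each token ID back to its string piece
# (convert_ids_to_tokens is part of the standard transformers tokenizer API)
list(zip(token_ids, tokenizer.convert_ids_to_tokens(token_ids)))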

When the target max_length is much bigger than the prompt length, padding is added to the end using the <|endoftext|> token.

text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="np",
)
text_input.input_ids
array([[49406,   320,  1125,   539,   320,  1615,   523, 23016,  3475,
          530,  2451, 19697,  1844, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407]])

Using those token IDs, we convert them into embeddings like this:

text_embeddings = text_encoder(text_input.input_ids)[0]
text_embeddings.shape
(1, 77, 768)
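
As a side note, the diffusion loop of Part I also needs "unconditional" embeddings for classifier-free guidance; they are produced the same way, just from an empty prompt. A minimal sketch:

# Sketch: unconditional embeddings (empty prompt) used for classifier-free guidance
uncond_input = tokenizer(
    "", padding="max_length", max_length=tokenizer.model_max_length, return_tensors="np"
)
uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
uncond_embeddings.shape  # (1, 77, 768), same shape as the prompt embeddings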

The Autoencoder (AE)

The Autoencoder (AE) is a very important component of Stable Diffusion. Its main purpose is to 'encode' an image into a latent representation and to decode that latent back into the original image. By doing this, it significantly reduces the size of the input image (each spatial dimension is shrunk by a factor of 8, i.e. 64 times fewer pixels) without losing much information.

This capability allows Stable Diffusion to perform the denoising on the latent representation instead of on the full-resolution image, which is what gives the model its small memory footprint and computational efficiency.
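
Some rough arithmetic makes the saving concrete: a 512 x 512 RGB image holds 786,432 values, while the corresponding 64 x 64 x 4 latent holds only 16,384, i.e. 48 times fewer.

# Rough arithmetic behind the compression claim above
pixels = 512 * 512 * 3   # 786,432 values in the input image
latent = 64 * 64 * 4     # 16,384 values in the latent
pixels / latent          # 48.0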

In this section, we will walk through how the AE can encode an image and decode it back without losing much information.

Encoding

!curl -s -o flower.jpeg "https://images.unsplash.com/photo-1604085572504-a392ddf0d86a?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=224&q=224"
pil_image = Image.open('flower.jpeg')
pil_image.resize((256, 256))

Let's define a helper function to convert a PIL image into latent embeddings using the AE encoder.

def pil_to_latents(pil_image):
    # Single image -> single latent in a batch; the Flax VAE returns
    # channels-last latents, so the shape is (1, 64, 64, 4)
    image = np.asarray(pil_image)
    # resize to the 512x512 resolution expected by the VAE and scale to [-1, 1]
    image = jax.image.resize(image, (512, 512, 3), "bicubic")
    image = (image / 127.5 - 1.0).astype(np.float32)
    input_im = jnp.expand_dims(image, axis=0)
    input_im = jnp.transpose(input_im, (0, 3, 1, 2))
    # encode the image and sample a latent, scaled by the factor used throughout Stable Diffusion
    latents = vae.apply({"params": vae_params}, input_im, method=vae.encode)
    return 0.18215 * latents.latent_dist.sample(prng_seed)

Now take the image and pass it through the AE encoder to generate its latent embeddings.

%%time

flower_latents = pil_to_latents(pil_image)
CPU times: user 35.1 s, sys: 1.2 s, total: 36.3 s
Wall time: 26.5 s
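
A quick shape check: since the Flax VAE returns channels-last latents, for a 512 x 512 input this should be a batch of one 64 x 64 x 4 latent.

flower_latents.shape  # expected: (1, 64, 64, 4)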

Latent

The latent has shape 64 x 64 x 4. Let's define a helper function to plot each of its channels.

def plot_latents(latents, figsize=(10, 5), maintitle=None):
    latents = latents.squeeze()
    rows, cols = 1, latents.shape[-1]
    f = plt.figure(figsize=figsize)
    if maintitle is not None: plt.suptitle(maintitle, fontsize=10)
    for i in range(cols):
        sp = f.add_subplot(rows, cols, i+1)
        sp.axis('Off')
        sp.set_title('Channel '+str(i), fontsize=16)
        img = np.asarray(latents[:, :, i])
        plt.imshow(img)

Plotting the different channels, we can see that the latent preserves a lot of the characteristics of the original flower image.

plot_latents(flower_latents)

Decoding

Let's define a helper function to convert the model's latent embeddings back into an actual PIL image for plotting.

def latents_to_pil(latents):
    # scale and decode the image latents with vae
    latents = 1 / 0.18215 * latents
    images = vae.apply({"params": vae_params}, latents, method=vae.decode).sample
    # convert JAX to numpy
    images = (images / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
    images = np.asarray(images)
    # convert numpy array to PIL
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images

From those latents we can get back to (more or less) our original picture.

%%time

images = latents_to_pil(flower_latents)
CPU times: user 1min 5s, sys: 1.41 s, total: 1min 7s
Wall time: 54.2 s

See how the AE decoder generates an image very close to the original one.

images[0].resize((256, 256))
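
For a rough quantitative check of how lossy the round trip is (an extra step, not in the original article), we can compare the reconstruction against the original image resized to the VAE's 512 x 512 input resolution:

# Mean absolute pixel difference (0-255 scale) between the original image,
# resized to 512x512, and the encode/decode round trip
original = np.asarray(pil_image.resize((512, 512))).astype(np.float32)
reconstructed = np.asarray(images[0]).astype(np.float32)
np.abs(original - reconstructed).mean()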

The Scheduler

As part of the diffusion loop, we add some noise to an image and then let the model try to predict the denoised image. If we always add too much noise, the model will have a harder time denoising the image. If we only ever add a tiny fraction, the model won't be able to do much with the random starting points we use for sampling. The Scheduler regulates the amount of noise to apply at each step of the diffusion loop, according to some distribution.

In this section, we will use the previously created scheduler to add some noise to our test image. The amount of noise added corresponds to step sampling_step out of total_steps.

encoded = flower_latents

Here we add the noise corresponding to step 10 of a 15-step schedule.

total_steps = 15 # @param
sampling_step = 10 # @param
scheduler_state = scheduler.set_timesteps(scheduler_params, num_inference_steps=total_steps, shape=encoded.shape)
noise = jax.random.normal(prng_seed, shape=encoded.shape, dtype=jnp.float32) # Random noise
encoded_and_noised = scheduler.add_noise(encoded, noise, timesteps=scheduler_state.timesteps[sampling_step])

Decode the noised latents back into an image.

noised = latents_to_pil(encoded_and_noised)[0]

See what the image looks like when the scheduler applies the noise for step number 10:

noised.resize((256, 256)) # Display
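
To visualize how the noise grows along the schedule, we can repeat the same add_noise call for a few (arbitrarily chosen) steps and decode each result; a sketch reusing only the helpers defined above:

# Sketch: progressively noisier versions of the same latent at different steps
f = plt.figure(figsize=(16, 4))
for i, step in enumerate([2, 6, 10, 14]):
    noisy = scheduler.add_noise(encoded, noise, timesteps=scheduler_state.timesteps[step])
    sp = f.add_subplot(1, 4, i + 1)
    sp.axis('Off')
    sp.set_title(f'step {step} / {total_steps}', fontsize=12)
    plt.imshow(latents_to_pil(noisy)[0])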

That's all folks

Stable Diffusion is a complex model made up of many components. In this article, we played with some of those components in Flax.

I hope you enjoyed this article; feel free to leave a comment or reach out on Twitter @bachiirc.