In Part I we re-created the functionality of FlaxStableDiffusionPipeline step by step and gained a better understanding of the inner workings of the diffusion loop. In Part II, we will inspect each of the main components of Stable Diffusion.
%%capture
%%bash
pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers
Log in to your Hugging Face account with an access token and accept the terms of the licence for this model (see the model card).
!huggingface-cli login
Import the libraries we will use, e.g. matplotlib for plotting the resulting images.
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxPNDMScheduler
Model and weights
Refer to Part I for more details on the model and the checkpoint we are using. Let's just run those cells again to instantiate the components.
dtype = jnp.bfloat16  # use bfloat16 to match the "bf16" revision below
model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16" # "flax"
tokenizer = CLIPTokenizer.from_pretrained(model_id, revision=revision, subfolder="tokenizer", dtype=dtype)
text_encoder = FlaxCLIPTextModel.from_pretrained(model_id, revision=revision, subfolder="text_encoder", dtype=dtype)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(model_id, revision=revision, subfolder="vae", dtype=dtype)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(model_id, revision=revision, subfolder="unet", dtype=dtype)
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(model_id, revision=revision, subfolder="scheduler")
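Before looking at each piece, it can be useful to get a rough feel for the relative size of each component. The following is just a small sketch that counts the parameters in the pytrees we loaded above:
# Sketch: rough parameter counts for each component
def count_params(params):
    return sum(p.size for p in jax.tree_util.tree_leaves(params))

print("text_encoder:", count_params(text_encoder.params))
print("vae:         ", count_params(vae_params))
print("unet:        ", count_params(unet_params))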
The following diagram illustrates the different components of Stable Diffusion and how they are combined to generate an image from text. For details on how those components work together to make a diffusion loop refer to Part I.
In this section we will play with some of those components to better understand them.
First, let's initialize the random seed so we can reproduce the results.
seed = 123
num_samples = jax.device_count()
prng_seed = jax.random.PRNGKey(seed)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt = "a photo of a car orbiting earth in van Gogh style"
The tokenizer turns the prompt into token IDs.
token_ids = tokenizer.encode(prompt)
token_ids
Some tokens have a special meaning. For instance, the token <|startoftext|> with ID 49406 is added at the beginning of the prompt, and <|endoftext|> with ID 49407 is added at the end of the prompt.
tokenizer.decode(49406), tokenizer.decode(49407)
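As a quick sanity check, those special IDs should match the tokenizer's own bos/eos attributes (a small sketch):
# Sketch: the special IDs correspond to the tokenizer's bos/eos tokens
print(tokenizer.bos_token, tokenizer.bos_token_id)
print(tokenizer.eos_token, tokenizer.eos_token_id)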
When the target max_length is much bigger than the prompt length, padding is added to the end using the <|endoftext|> token.
text_input = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input.input_ids
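To see that padding in action, here is a small sketch that checks the padded length and counts how many <|endoftext|> tokens were appended:
# Sketch: the sequence is padded to model_max_length with <|endoftext|>
ids = np.asarray(text_input.input_ids[0])
print(ids.shape)                              # (tokenizer.model_max_length,)
print((ids == tokenizer.eos_token_id).sum())  # one end-of-text token plus the padding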
Using the text encoder, we convert those token IDs into embeddings like this:
text_embeddings = text_encoder(text_input.input_ids)[0]
text_embeddings.shape
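As a reminder from Part I, the diffusion loop also needs "unconditional" embeddings (from an empty prompt) for classifier-free guidance. Here is a minimal sketch of how those would be produced with the same tokenizer and text encoder:
# Sketch: unconditional (empty prompt) embeddings for classifier-free guidance
uncond_input = tokenizer(
    "", padding="max_length", max_length=tokenizer.model_max_length, return_tensors="np"
)
uncond_embeddings = text_encoder(uncond_input.input_ids)[0]
uncond_embeddings.shape  # same shape as text_embeddings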
The Autoencoder (AE)
The Autoencoder (AE) is a very important component of Stable Diffusion. Its main purpose is to encode an image into a latent representation, and to decode that latent back into the original image. By doing this, it significantly reduces the size of the input image (spatially by a factor of 64, from 512×512 down to 64×64) without losing much information.
This capability allows Stable Diffusion to perform the denoising on the latent representation instead of on the original image, hence the smaller memory footprint and better computational efficiency of this model.
In this section, we will walk through how the AE can encode an image and decode it back without losing much information.
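To make that compression factor concrete, here is a quick back-of-the-envelope calculation, assuming the 512×512×3 input and the 64×64×4 latent we will see below:
# Spatial positions: 512*512 -> 64*64, i.e. 64x fewer positions
print(512 * 512 / (64 * 64))
# Total values including channels: 512*512*3 -> 64*64*4, roughly 48x fewer
print(512 * 512 * 3 / (64 * 64 * 4))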
!curl -s -o flower.jpeg "https://images.unsplash.com/photo-1604085572504-a392ddf0d86a?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=224&q=224"
pil_image = Image.open('flower.jpeg')
pil_image.resize((256, 256))
Let's define a helper function to convert a PIL image into latent embeddings using the AE encoder.
def pil_to_latents(pil_image):
    # Single image -> single latent in a batch (shape 1, 64, 64, 4)
    image = np.asarray(pil_image)
    image = jax.image.resize(image, (512, 512, 3), "bicubic")
    image = (image / 127.5 - 1.0).astype(np.float32)
    input_im = jnp.expand_dims(image, axis=0)
    input_im = jnp.transpose(input_im, (0, 3, 1, 2))
    # encode the image and sample a latent from the resulting distribution
    latents = vae.apply({"params": vae_params}, input_im, method=vae.encode)
    # use a single PRNG key (prng_seed holds one key per device after the split)
    return 0.18215 * latents.latent_dist.sample(prng_seed[0])
Take the image and pass it through the AE encoder to generate its latent embeddings
%%time
flower_latents = pil_to_latents(pil_image)
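A quick check of the sampled latent's shape (it should match the layout discussed next):
flower_latents.shape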
The latent has shape 64 × 64 × 4; let's define the following helper function to plot each of its channels.
def plot_latents(latents, figsize=(10, 5), maintitle=None):
    latents = latents.squeeze()
    rows, cols = 1, latents.shape[-1]
    f = plt.figure(figsize=figsize)
    if maintitle is not None: plt.suptitle(maintitle, fontsize=10)
    for i in range(cols):
        sp = f.add_subplot(rows, cols, i+1)
        sp.axis('Off')
        sp.set_title('Channel '+str(i), fontsize=16)
        img = np.asarray(latents[:, :, i])
        plt.imshow(img)
Plotting the different channels, we can see that the latent preserves a lot of the characteristics of the original flower image.
plot_latents(flower_latents)
Let's define the following helper function to convert latent embeddings output by the model back into an actual PIL image for plotting.
def latents_to_pil(latents):
    # scale and decode the image latents with the VAE
    latents = 1 / 0.18215 * latents
    images = vae.apply({"params": vae_params}, latents, method=vae.decode).sample
    # convert JAX array to numpy in HWC layout
    images = (images / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
    images = np.asarray(images)
    # convert numpy array to PIL
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
From those latents we can go back to our exact original picture (more or less).
%%time
images = latents_to_pil(flower_latents)
See how the AE decoder generates an image very close to the original one.
images[0].resize((256, 256))
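To put a rough number on "very close", here is a small sketch that compares the pixels of the original image (resized to 512×512, matching what pil_to_latents did) with the reconstruction:
# Sketch: mean absolute pixel difference between original and reconstruction
original = np.asarray(pil_image.resize((512, 512)), dtype=np.float32)
reconstruction = np.asarray(images[0], dtype=np.float32)
print(np.abs(original - reconstruction).mean())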
The Scheduler
As part of the diffusion loop, we add some noise to an image and then let the model try to predict a denoised image. If we always add too much noise, the model will have a harder time denoising the image. If we add only a tiny fraction, the model won't be able to do much with the random starting points we use for sampling. The Scheduler is used to regulate the amount of noise to apply at each step of the diffusion loop, according to some distribution.
In this section, we will use the previously created scheduler to add some noise to our test image. The amount of noise to add corresponds to the noise from step number sampling_step out of total_steps.
encoded = flower_latents
This corresponds to step 10 out of 15 of the noise schedule set up with the scheduler we created above:
total_steps = 15 # @param
sampling_step = 10 # @param
scheduler_state = scheduler.set_timesteps(scheduler_params, num_inference_steps=total_steps, shape=encoded.shape)
noise = jax.random.normal(prng_seed[0], shape=encoded.shape, dtype=jnp.float32)  # random noise (single PRNG key)
encoded_and_noised = scheduler.add_noise(encoded, noise, timesteps=scheduler_state.timesteps[sampling_step])
Decode latents into an image
noised = latents_to_pil(encoded_and_noised)[0]
See what the image looks like when the scheduler applies the noise for step number 10:
noised.resize((256, 256)) # Display
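To see how the noise level changes along the schedule, here is a small sketch that reuses encoded, noise, scheduler_state and latents_to_pil from above and applies the noise for a few different steps (the exact noise level at each step depends on how the scheduler orders its timesteps):
# Sketch: noised latents decoded at a few different steps of the schedule
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, step in zip(axes, [0, 5, 10, 14]):
    noisy = scheduler.add_noise(encoded, noise, timesteps=scheduler_state.timesteps[step])
    ax.imshow(latents_to_pil(noisy)[0])
    ax.set_title(f"step {step}")
    ax.axis("off")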
That's all folks
Stable Diffusion is a complex model that comprises many components. In this article, we played with some of those components in Flax.
I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.