Stable Diffusion is a powerful text-to-image model. Its success led many websites and tools to provide easy access to it so that anyone can use it to generate images from text. It is also integrated into the Hugging Face diffusers library, where generating images in Python can be as simple as:
import jax
from diffusers import FlaxStableDiffusionPipeline
# load the pipeline together with its parameters (half-precision weights)
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
)
# seed (an int) and prompt (a str) are assumed to be defined beforehand
prng_seed = jax.random.split(jax.random.PRNGKey(seed), 1)
prompt_ids = pipeline.prepare_inputs([prompt])
images = pipeline(prompt_ids, pipeline_params, prng_seed).images
images = pipeline.numpy_to_pil(images)
You can refer to this article for a complete walk-through of how to use the FlaxStableDiffusionPipeline API to generate images from your prompt.
In Part I we will dig into the actual code behind such an easy-to-use API to better understand the Stable Diffusion model. First, we will re-create the functionality of FlaxStableDiffusionPipeline step by step, and then we will inspect its main components.
By the end of this notebook we will have a good understanding of how Stable Diffusion works and be able to tweak and modify its inner workings.
Stable Diffusion is trained to remove noise from an image. By repeating this process multiple times, it is able to generate images of great quality. How much noise is removed in every step depends on the input textual description. This helps guide the model toward generating an image that matches the input description.
This is how Stable Diffusion generates an image from text:
- The textual description of the target image is passed through a text encoder to generate the text embeddings
- A random noisy image latent (think of it as a compressed image of shape 64 x 64) is generated as a starting point
- The noise is passed to a U-Net model along with the text embeddings
- The U-Net generates a denoised image latent conditioned by the text embeddings
- The denoised predicted image is passed through a scheduler, which adds a little bit of noise based on the current step number
- The output of the scheduler is used as input for the next denoising step
- The denoising loop is repeated for a number of steps
- The final image latent is passed to a decoder, which generates the final image with shape 512 x 512.
The above diagram illustrates the different components of Stable Diffusion and how they are combined to generate an image from text. In the remainder of this article, we will dive deeper into the implementation details.
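Before looking at the real components, the steps listed above can be sketched with toy stand-ins. This is only an illustration of the control flow and tensor shapes: every function below (toy_text_encoder, toy_unet, toy_scheduler_step, toy_decoder) is hypothetical and does trivial NumPy arithmetic instead of running a neural network.

```python
import numpy as np

rng = np.random.default_rng(0)

def toy_text_encoder(prompt):
    # Stand-in for the text encoder: returns a fake (77, 768) embedding.
    return np.full((77, 768), len(prompt) % 7, dtype=np.float32)

def toy_unet(latents, t, text_embeddings):
    # Stand-in for the U-Net: pretends a fixed fraction of the latent is noise.
    return 0.1 * latents

def toy_scheduler_step(latents, noise_pred):
    # Stand-in for the scheduler: produces the next, less noisy latent.
    return latents - noise_pred

def toy_decoder(latents):
    # Stand-in for the decoder: keeps 3 channels and upsamples 64 x 64 to 512 x 512.
    rgb = latents[:, :3]
    return np.repeat(np.repeat(rgb, 8, axis=-2), 8, axis=-1)

def toy_generate(prompt, num_steps=10):
    text_embeddings = toy_text_encoder(prompt)
    # Random noisy image latent used as the starting point.
    latents = rng.standard_normal((1, 4, 64, 64)).astype(np.float32)
    # Denoising loop: the output of one step is the input of the next.
    for t in reversed(range(num_steps)):
        noise_pred = toy_unet(latents, t, text_embeddings)
        latents = toy_scheduler_step(latents, noise_pred)
    return toy_decoder(latents)

toy_image = toy_generate("a photo of a car")
print(toy_image.shape)
```

The rest of the notebook replaces each toy function with the corresponding Flax component.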
%%capture
%%bash
pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers
Log into your Hugging Face account with an access token and accept the terms of the licence for this model - see model card.
!huggingface-cli login
Import the libraries we will use, e.g. matplotlib for plotting the resulting images.
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxPNDMScheduler
Model and weights
This will download and set up the relevant models and components we'll be using. Let's just run this for now and move on to the next section to check that it all works before diving deeper.
Let's first instantiate the different components of Stable Diffusion, which are represented by the following classes:
- CLIPTokenizer: used to transform the text prompt into token IDs
- FlaxCLIPTextModel: used to encode token IDs into the corresponding embeddings
- FlaxAutoencoderKL: a Variational Autoencoder used for encoding/decoding images to/from a latent representation
- FlaxUNet2DConditionModel: a conditional U-Net model used to denoise the latent representation of an image
- FlaxPNDMScheduler: used to add noise to the latent representation of an image
We will download and set up the previous components using checkpoints of Stable Diffusion from CompVis/stable-diffusion-v1-4. We use the half-precision weights (revision bf16) for a faster weights download, to save on memory, and to avoid issues like slower predictions or OOM. If you want to use float32 precision instead, use the revision flax.
dtype = jax.numpy.float16
model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16" # "flax"
tokenizer = CLIPTokenizer.from_pretrained(model_id, revision=revision, subfolder="tokenizer", dtype=dtype)
text_encoder = FlaxCLIPTextModel.from_pretrained(model_id, revision=revision, subfolder="text_encoder", dtype=dtype)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(model_id, revision=revision, subfolder="vae", dtype=dtype)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(model_id, revision=revision, subfolder="unet", dtype=dtype)
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(model_id, revision=revision, subfolder="scheduler")
Note: Instead of individually loading those components we could also load a pipeline and then access them using pipe.unet, pipe.vae and so on. Something like this:
from diffusers import FlaxStableDiffusionPipeline
pipe, pipe_params = FlaxStableDiffusionPipeline.from_pretrained(model_id, revision=revision, dtype=jax.numpy.bfloat16)
unet, vae = pipe.unet, pipe.vae
First, let's initialize the random seed so we can reproduce the results.
seed = 123
num_samples = jax.device_count()
prng_seed = jax.random.PRNGKey(seed)
prng_seed = jax.random.split(prng_seed, num_samples)
The following are two important parameters:
- num_inference_steps corresponds to the number of steps to take in the diffusion loop
- guidance_scale regulates how strongly each noise prediction is pushed toward the text prompt and away from the unconditional prediction
guidance_scale = 7.5 #@param {type:"slider", min:0, max:100, step:0.5}
num_inference_steps = 30 #@param
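To get a feel for guidance_scale, here is a small NumPy sketch of the guidance formula that this notebook applies inside its diffusion step. The helper name apply_guidance and the toy noise vectors are assumptions made for illustration only.

```python
import numpy as np

def apply_guidance(noise_uncond, noise_text, guidance_scale):
    # Start from the unconditional prediction and move toward the text-conditioned one.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

# Toy predictions standing in for the two halves of the U-Net output.
toy_noise_uncond = np.array([0.0, 1.0])
toy_noise_text = np.array([1.0, 3.0])

no_guidance = apply_guidance(toy_noise_uncond, toy_noise_text, 0.0)   # ignores the prompt
text_only = apply_guidance(toy_noise_uncond, toy_noise_text, 1.0)     # plain conditional prediction
amplified = apply_guidance(toy_noise_uncond, toy_noise_text, 7.5)     # extrapolates past it
print(no_guidance, text_only, amplified)
```

A scale of 0 returns the unconditional prediction, a scale of 1 returns the text-conditioned prediction, and larger values extrapolate further in the direction of the prompt.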
Choose a random text prompt that describes the target image
prompt = "a photo of a car orbiting earth in van Gogh style"
text_encoder_params = None
Initialize some parameters
# init
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
# call
height = unet.config.sample_size * vae_scale_factor
width = unet.config.sample_size * vae_scale_factor
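As a worked example of the arithmetic above, the sketch below plugs in the values we would expect for this checkpoint. The concrete numbers (four VAE blocks, a U-Net sample size of 64, 4 latent channels) are assumptions here; check vae.config and unet.config in your own session.

```python
# Assumed config values; verify them against vae.config and unet.config.
example_block_out_channels = (128, 256, 512, 512)
example_sample_size = 64
example_latent_channels = 4

# Each VAE block after the first halves the resolution, hence 2 ** (4 - 1) = 8.
example_scale_factor = 2 ** (len(example_block_out_channels) - 1)
example_height = example_sample_size * example_scale_factor
example_width = example_sample_size * example_scale_factor

# Shape of the latents for a single prompt.
example_latents_shape = (
    1,
    example_latent_channels,
    example_height // example_scale_factor,
    example_width // example_scale_factor,
)
print(example_scale_factor, example_height, example_width, example_latents_shape)
```

Under these assumptions the image is 512 x 512 pixels while the diffusion loop works on much smaller 64 x 64 latents.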
Now, we need to use the tokenizer and text_encoder to calculate the embeddings for the input text prompt.
%%time
text_input = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
prompt_ids = text_input.input_ids
text_embeddings = text_encoder(prompt_ids, params=text_encoder_params)[0]
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
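The padding="max_length" and truncation=True arguments make every prompt come out with the same length. The toy tokenizer below is a hypothetical whitespace tokenizer, not the CLIP one, and its special token IDs are made up. It only illustrates how truncation and padding produce fixed-length inputs.

```python
def toy_tokenize(text, max_length=8, bos_id=0, eos_id=1, pad_id=2):
    # Hypothetical vocabulary: every distinct word gets the next free ID (starting at 3).
    vocab = {}
    word_ids = [vocab.setdefault(w, len(vocab) + 3) for w in text.lower().split()]
    # truncation=True: drop extra words, keeping room for the start and end markers.
    word_ids = word_ids[: max_length - 2]
    ids = [bos_id] + word_ids + [eos_id]
    # padding="max_length": fill the remaining positions with the padding ID.
    return ids + [pad_id] * (max_length - len(ids))

short_ids = toy_tokenize("a photo of a car")
long_ids = toy_tokenize("a photo of a car orbiting earth in van Gogh style")
print(short_ids)
print(long_ids)
```

Both outputs have exactly max_length entries, which is what lets prompts of different lengths share one batch shape.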
We also need to calculate the embeddings for the blank text. These unconditional embeddings are later used, together with the previous guidance_scale, to steer the predicted noise toward the prompt.
uncond_input = tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids
uncond_embeddings = text_encoder(uncond_input, params=text_encoder_params)[0]
Here we concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
context = jnp.concatenate([uncond_embeddings, text_embeddings])
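The batching trick can be checked on a toy example. In the sketch below, toy_model is a hypothetical stand-in for the U-Net that treats each batch element independently, and the (1, 77, 768) embedding shape is an assumption about what the text encoder returns for a single prompt.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_weights = rng.standard_normal((768, 4))

def toy_model(embeddings):
    # Hypothetical model: processes every element of the batch independently.
    return embeddings.mean(axis=1) @ toy_weights

toy_uncond = rng.standard_normal((1, 77, 768))
toy_text = rng.standard_normal((1, 77, 768))

# One forward pass over the concatenated batch ...
toy_context = np.concatenate([toy_uncond, toy_text])
out_uncond, out_text = np.split(toy_model(toy_context), 2, axis=0)

# ... matches two separate forward passes.
assert np.allclose(out_uncond, toy_model(toy_uncond))
assert np.allclose(out_text, toy_model(toy_text))
print(toy_context.shape, out_uncond.shape, out_text.shape)
```

The order of concatenation matters: the unconditional half comes first, so it is also the first half when the output is split again.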
Now we define a diffusion step where we use the unet component to predict the noise conditioned on the prompt embeddings. We then combine this prediction with the one obtained from the blank-text embeddings.
def diffusion_step(step, args):
    latents, scheduler_state = args
    latents_input = jnp.concatenate([latents] * 2)
    t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
    timestep = jnp.broadcast_to(t, latents_input.shape[0])
    latents_input = scheduler.scale_model_input(scheduler_state, latents_input, t)
    # predict the noise residual
    noise_pred = unet.apply(
        {"params": unet_params},
        jnp.array(latents_input),
        jnp.array(timestep, dtype=jnp.int32),
        encoder_hidden_states=context,
    ).sample
    # perform guidance
    noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    # compute the previous noisy sample x_t -> x_t-1
    latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
    return latents, scheduler_state
latents_shape = (
batch_size,
unet.in_channels,
height // vae_scale_factor,
width // vae_scale_factor,
)
The following defines the starting latents, which are randomly sampled from a normal distribution.
# jax.random.normal expects a single key, so we use the first of the split keys
latents = jax.random.normal(prng_seed[0], shape=latents_shape, dtype=jnp.float32)
Here is the diffusion loop where we will call the previously defined diffusion_step as many times as num_inference_steps.
# set the timesteps based on the number of steps
scheduler_state = scheduler.set_timesteps(
scheduler_params, num_inference_steps=num_inference_steps, shape=latents.shape
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler.init_noise_sigma
latents, _ = jax.lax.fori_loop(0, num_inference_steps, diffusion_step, (latents, scheduler_state))
latents.shape
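If jax.lax.fori_loop is unfamiliar, its semantics can be read as the plain Python loop below. This reference version is only a sketch for understanding: the real function traces and compiles the loop body, and it requires the carried value to keep the same structure, shapes and dtypes across iterations. The names fori_loop_reference and toy_step are made up for this example.

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    # Plain-Python reading of the loop: feed the carried value through body_fun each step.
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

def toy_step(step, args):
    # Same (step, args) -> args signature as diffusion_step, with toy arithmetic.
    toy_latents, steps_done = args
    return toy_latents * 0.5, steps_done + 1

final_latents, final_steps = fori_loop_reference(0, 3, toy_step, (8.0, 0))
print(final_latents, final_steps)
```

This is why diffusion_step takes a (latents, scheduler_state) tuple and returns a tuple of the same form: that tuple is the value carried from one iteration to the next.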
Let's define the following helper function to convert the model output latent embeddings into an actual PIL image for plotting.
def latents_to_pil(latents):
    # scale and decode the image latents with vae
    latents = 1 / 0.18215 * latents
    images = vae.apply({"params": vae_params}, latents, method=vae.decode).sample
    # convert JAX to numpy
    images = (images / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
    images = np.asarray(images)
    # convert numpy array to PIL
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
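The post-processing part of this helper can be tried on its own without the VAE. The sketch below repeats the same rescaling, clipping and transposition on a tiny hand-made array. The function name to_uint8_images and the sample values are assumptions for illustration.

```python
import numpy as np

def to_uint8_images(decoded):
    # decoded: (batch, channels, height, width) with values roughly in [-1, 1].
    scaled = np.clip(decoded / 2 + 0.5, 0, 1)      # map to [0, 1] and clip outliers
    channels_last = scaled.transpose(0, 2, 3, 1)   # (N, C, H, W) -> (N, H, W, C)
    return (channels_last * 255).round().astype("uint8")

toy_decoded = np.zeros((1, 3, 2, 2), dtype=np.float32)
toy_decoded[0, 0] = -1.0   # lower bound of the expected range
toy_decoded[0, 1] = 1.5    # out of range, gets clipped
toy_decoded[0, 2] = 0.2    # somewhere in between
toy_pixels = to_uint8_images(toy_decoded)
print(toy_pixels.shape, toy_pixels.dtype, toy_pixels[0, 0, 0])
```

The channels-last uint8 layout is what Image.fromarray expects for an RGB image.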
Now we use the autoencoder (through the previous helper function) to generate an image from its latent representation.
%%time
images = latents_to_pil(latents)
Finally, we can check the resulting image:
images[0]
That's all folks
Stable Diffusion is a complex model that comprises many components and uses clever tricks so it can be trained with relatively cheap hardware compared to other diffusion models (e.g. Google's Imagen or OpenAI’s DALL-E 2).
In this article, we dived deep into the Flax implementation of the Stable Diffusion model to better understand how it works.
I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.