Stable Diffusion is a powerful model that can be used to generate an image from an arbitrary text prompt. Hugging Face's diffusers library provides access to such models in a very convenient way.
In this article we will use this library to load the Flax-based Stable Diffusion model and its weights to generate a sequence of interpolated images.
%%capture
%%bash
pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers
To be able to use Stable Diffusion we need to log in to Hugging Face and accept the terms of the license for this model.
!huggingface-cli login
import os
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
Generating many images with Stable Diffusion in Colab can take time, so to preserve progress in case we have to restart, we mount Google Drive (or another persistent storage).
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
PATH = '/content/drive/My Drive/Colab Notebooks/diffusion'
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxPNDMScheduler
dtype = jax.numpy.float16
We will use the Stable Diffusion checkpoints from CompVis/stable-diffusion-v1-4 in half precision (i.e. revision bf16) to save memory and avoid issues like OOM errors; the full-precision Flax weights are available under the flax revision.
model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16" # "flax"
To have more control over the diffusion loop we will not use FlaxStableDiffusionPipeline as is, but instead instantiate the components of the model individually.
tokenizer = CLIPTokenizer.from_pretrained(model_id, revision=revision, subfolder="tokenizer", dtype=dtype)
text_encoder = FlaxCLIPTextModel.from_pretrained(model_id, revision=revision, subfolder="text_encoder", dtype=dtype)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(model_id, revision=revision, subfolder="vae", dtype=dtype)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(model_id, revision=revision, subfolder="unet", dtype=dtype)
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(model_id, revision=revision, subfolder="scheduler")
Initialize the random seed for JAX.
seed = 123
num_samples = jax.device_count()
prng_seed = jax.random.PRNGKey(seed)
# With several devices we would split the key (one per device) and use flax's
# replicate/shard helpers; here we generate on a single device with a single key.
Set the number of steps in the diffusion loop needed to generate one image, as well as the guidance scale.
You can learn more about the effects of these parameters in this article.
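Concretely, the guidance scale $s$ is used in the classifier-free guidance step further below, where the noise prediction for the prompt is combined with the prediction for an empty prompt:
$$\hat{\epsilon} = \epsilon_{\text{uncond}} + s \, (\epsilon_{\text{text}} - \epsilon_{\text{uncond}})$$
Larger values push the result closer to the text prompt at the expense of diversity.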
guidance_scale = 8 #@param {type:"slider", min:0, max:100, step:0.5}
num_inference_steps = 30 #@param
text_encoder_params = None
We need to convert the input text prompt into tokens and then into embeddings, as expected by the Stable Diffusion model.
def embed(prompt):
# prepare_inputs
text_input = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
prompt_ids = text_input.input_ids
# get prompt text embeddings
text_embeddings = text_encoder(prompt_ids, params=text_encoder_params)[0]
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
uncond_input = tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids
uncond_embeddings = text_encoder(uncond_input, params=text_encoder_params)[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
return context
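As a quick sanity check we can look at the shape of the conditioning tensor returned by embed(); the expected dimensions below are an assumption based on the v1-4 text encoder (CLIP ViT-L/14, maximum sequence length 77, hidden size 768).
# The first row is the unconditional ("") embedding, the second one the prompt's.
context = embed("a realistic photo of a ford model t")
print(context.shape)  # expected: (2, 77, 768)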
The following function defines one diffusion step:
- Predict the noise residual from the text embeddings and the current latents
- Combine this prediction with the prediction for a blank prompt using the guidance scale
- Pass everything to the scheduler to compute the latents of the previous, less noisy step
def diffusion_step(step, args):
latents, context, scheduler_state = args
latents_input = jnp.concatenate([latents] * 2)
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])
latents_input = scheduler.scale_model_input(scheduler_state, latents_input, t)
# predict the noise residual
noise_pred = unet.apply(
{"params": unet_params},
jnp.array(latents_input),
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=context,
).sample
# perform guidance
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents, context, scheduler_state
Because we will generate many images, for simplicity and reuse we group the diffusion logic into the following helper function.
def diffuse(context, init_noise):
latents = init_noise
    # set the timesteps based on the number of steps
scheduler_state = scheduler.set_timesteps(
scheduler_params, num_inference_steps=num_inference_steps, shape=latents.shape
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler.init_noise_sigma
latents, _, _ = jax.lax.fori_loop(0, num_inference_steps, diffusion_step, (latents, context, scheduler_state))
return latents
Set some global parameters, like the expected shape of the latents.
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
height = unet.config.sample_size * vae_scale_factor
width = unet.config.sample_size * vae_scale_factor
batch_size = 1
latents_shape = (
batch_size,
unet.in_channels,
height // vae_scale_factor,
width // vae_scale_factor,
)
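For the v1-4 checkpoint these parameters should work out to 512x512 images and 4-channel 64x64 latents (an assumption based on the standard configuration: a UNet sample size of 64, four latent channels, and a VAE downscaling factor of 8); a quick check:
# Sanity check of the derived sizes (expected values assume the standard v1-4 config).
print(f"image size: {height}x{width}")    # expected: 512x512
print(f"latents shape: {latents_shape}")  # expected: (1, 4, 64, 64)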
The output of Stable Diffusion is a latent tensor; the following function decodes it with the VAE and converts it into a PIL image.
def latents_to_pil(latents):
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
images = vae.apply({"params": vae_params}, latents, method=vae.decode).sample
# convert JAX to numpy
images = (images / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
images = np.asarray(images)
# convert numpy array to PIL
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
The following illustration explains the interpolation approach:
- We generate the embeddings for the prompts of the source (e.g. a photo of a model t) and target (e.g. a photo of a mustang) images
- We interpolate many new embeddings between the previous two
- For each of those embeddings we generate the corresponding image
- We group the resulting images into a video or GIF
First, generate the embeddings of the source and target prompts
%%time
prompt_1 = "a realistic photo of a ford model t"
text_embed_1 = embed(prompt_1)
noise_1 = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
latent_1 = diffuse(text_embed_1, noise_1)
prompt_2 = "a realistic photo of a ford mustang"
text_embed_2 = embed(prompt_2)
noise_2 = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
latent_2 = diffuse(text_embed_2, noise_2)
The outputs for the previous prompts may not match our expectations, and more tweaking may be needed until we reach good looking initial images. So before going further, let's decode the latents and make sure they look good.
image1 = latents_to_pil(latent_1)[0]
image2 = latents_to_pil(latent_2)[0]
f = plt.figure(figsize=(20,10))
images, prompts = [image1, image2], [prompt_1, prompt_2]
for i, image in enumerate(images):
sp = f.add_subplot(1, 2, i + 1)
sp.axis('off')
sp.set_title(prompts[i], fontsize=16)
plt.imshow(image)
To generate the $t^{th}$ embeddings vector we use Spherical Linear Interpolation (SLERP) which is defined by the following formula:
$$\operatorname{Slerp}(p_0, p_1; t) = \frac{\sin[(1-t)\theta]}{\sin \theta}\, p_0 + \frac{\sin[t\theta]}{\sin \theta}\, p_1$$
where $\theta$ is the angle between $p_0$ and $p_1$.
Unlike Linear Interpolation (LERP), SLERP generates embeddings that are uniformly spaced along the arc between the two endpoints.
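For reference, plain LERP is simply
$$\operatorname{Lerp}(p_0, p_1; t) = (1-t)\, p_0 + t\, p_1$$
which is also the fallback used by the implementation below when the two vectors are nearly colinear.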
The following function implements SLERP in JAX:
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
""" helper function to spherically interpolate two arrays v1 v2 """
dot = jnp.sum(v0 * v1 / (jnp.linalg.norm(v0) * jnp.linalg.norm(v1)))
if jnp.abs(dot) > DOT_THRESHOLD:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = jnp.arccos(dot)
sin_theta_0 = jnp.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = jnp.sin(theta_t)
s0 = jnp.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
return v2
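As a minimal sanity check on toy inputs (two orthogonal unit vectors, so the arc is a quarter circle), we can verify that the endpoints are recovered at $t=0$ and $t=1$:
# SLERP between orthogonal unit vectors: theta is pi/2, so t=0.5 lands at 45 degrees.
v0 = jnp.array([1.0, 0.0])
v1 = jnp.array([0.0, 1.0])
print(slerp(0.0, v0, v1))  # ~[1., 0.]
print(slerp(0.5, v0, v1))  # ~[0.7071, 0.7071]
print(slerp(1.0, v0, v1))  # ~[0., 1.]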
Because generating many images with Stable Diffusion takes time, let's save our progress so far so we can skip all the previous steps in case we need to restart this notebook.
!mkdir -p '{PATH}/cars'
Save the embeddings, noises and latents
jnp.save(f'{PATH}/cars/text_embed_1.npy', text_embed_1)
jnp.save(f'{PATH}/cars/text_embed_2.npy', text_embed_2)
jnp.save(f'{PATH}/cars/noise_1.npy', noise_1)
jnp.save(f'{PATH}/cars/noise_2.npy', noise_2)
jnp.save(f'{PATH}/cars/latent_1.npy', latent_1)
jnp.save(f'{PATH}/cars/latent_2.npy', latent_2)
Load the embeddings
text_embed_1 = jnp.load(f'{PATH}/cars/text_embed_1.npy')
text_embed_2 = jnp.load(f'{PATH}/cars/text_embed_2.npy')
noise_1 = jnp.load(f'{PATH}/cars/noise_1.npy')
noise_2 = jnp.load(f'{PATH}/cars/noise_2.npy')
The number of interpolation steps is an important parameter; the right value depends on how different the source and target images are. You may need many more or fewer steps, but let's pick a value.
num_steps = 50 # @param interpolation step
Let's generate num_steps evenly spaced values for the SLERP parameter $t$.
steps = jnp.linspace(0.0, 1.0, num_steps).tolist()
For each $t$ we interpolate the text embeddings and the initial noise with SLERP, then run the diffusion loop to generate the corresponding image.
for i, t in enumerate(tqdm(steps)):
outpath = f'{PATH}/cars/frame{i:06}.jpeg'
if os.path.exists(outpath):
continue
cond_embedding = slerp(t, text_embed_1, text_embed_2)
initial_noise = slerp(t, noise_1, noise_2)
latents = diffuse(cond_embedding, initial_noise)
image = latents_to_pil(latents)[0]
image.save(outpath)
After generating all the intermediate images, we can group them into a video or GIF as follows.
!ffmpeg -r 10 -f image2 -s 512x512 -i '{PATH}'/cars/frame%06d.jpeg -vcodec libx264 -crf 10 -pix_fmt yuv420p evolution_cars.mp4
!ffmpeg -f image2 -framerate 1 -i '{PATH}'/cars/frame%06d.jpeg -loop -1 evolution_cars.gif
!ls -alt evolution_cars.*
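If ffmpeg is not available, the saved frames can also be assembled into a GIF directly with PIL (a minimal sketch; the output filename is arbitrary):
import glob
# Collect the frames in order and write an animated GIF (100 ms per frame, infinite loop).
frames = [Image.open(p) for p in sorted(glob.glob(f'{PATH}/cars/frame*.jpeg'))]
frames[0].save('evolution_cars_pil.gif', save_all=True, append_images=frames[1:], duration=100, loop=0)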
The resulting interpolation is illustrated in the video at the top of this article.
That's all folks
Stable Diffusion is a powerful model with many applications. In this article we saw how to combine it with interpolation techniques to generate a compelling sequence of images.
I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.