Diffusion models are among the most recent disruptive models in deep learning. Outperforming previous generative models, they were popularized by the success of DALL-E 2 and Imagen at generating photorealistic images from text prompts.
Thanks to Hugging Face, diffusion models are available for anyone to use via the diffusers library. In this post, we will explore the diffusers Flax API to generate images from a prompt.
%%capture
%%bash
pip install --upgrade diffusers transformers scipy
pip install --upgrade flax
The diffusers library contains the Flax implementation of the model, but we also need weights to initialize it. We will use checkpoints from CompVis/stable-diffusion-v1-4, but first we need to accept the terms of the license for this model.
On running the following cell, you will be asked for an access token, which you can get from the settings page of your Hugging Face account - link.
!huggingface-cli login
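Alternatively, if you prefer to stay in Python, the huggingface_hub package ships a notebook_login helper that does the same thing (a minimal sketch):
from huggingface_hub import notebook_login

# Prompts for the access token inside the notebook instead of the terminal.
notebook_login()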
We need to set up JAX to use the TPU, see link.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
We need to make sure that JAX is using TPU as its backend before proceeding further.
assert jax.device_count() == 8  # a Colab TPU exposes 8 cores
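Optionally, as an extra sanity check, you can also assert that the devices are actual TPU cores:
# Each JAX device reports its platform; on a TPU runtime this is 'tpu'.
assert jax.devices()[0].platform == 'tpu'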
Import the libraries we will use, e.g. matplotlib for plotting the resulting images.
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
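Now we can load the pipeline and its parameters. We use the bf16 revision, i.e. bfloat16 half-precision weights, which are lighter and run well on TPU: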
from diffusers import FlaxStableDiffusionPipeline
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
)
If you want to use the original precision (i.e. float32), then you need to change how the pipeline is loaded to this:
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.float32
)
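Next, let's define a helper that prepares the model inputs: it replicates the prompt once per sample, tokenizes it, and shards the resulting token IDs across the TPU devices: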
def prepare_inputs(prompt, num_samples):
    prompt = num_samples * [prompt]
    prompt_ids = pipeline.prepare_inputs(prompt)
    prompt_ids = shard(prompt_ids)
    return prompt_ids
The next helper function will run the pipeline with the input prompt text. This will generate a JAX representation of the image corresponding to that prompt, which we convert at the end into actual PIL images.
def generate1(prompt, num_samples, seed=0, num_inference_steps=50, guidance_scale=7.5):
    prng_seed = jax.random.PRNGKey(seed)
    prompt_ids = prepare_inputs(prompt, num_samples)
    # shard inputs and rng
    params = replicate(pipeline_params)
    prng_seed = jax.random.split(prng_seed, num_samples)
    output = pipeline(prompt_ids, params, prng_seed, num_inference_steps, guidance_scale=guidance_scale, jit=True)
    images = output.images
    # Collapse the (devices, batch-per-device) leading axes into a single batch,
    # then convert the JAX array into PIL images.
    images = np.asarray(images.reshape((num_samples,) + images.shape[-3:]))
    images = pipeline.numpy_to_pil(images)
    return images
Our final helper function plots the generated images in a grid.
def ceildiv(a, b):
    return -(-a // b)

def plots_pil_images(pil_images, figsize=(10,5), rows=1, cols=None, titles=None, maintitle=None):
    f = plt.figure(figsize=figsize)
    if maintitle is not None: plt.suptitle(maintitle, fontsize=10)
    cols = cols if cols else ceildiv(len(pil_images), rows)
    for i in range(len(pil_images)):
        sp = f.add_subplot(rows, cols, i+1)
        sp.axis('Off')
        if titles is not None: sp.set_title(titles[i], fontsize=16)
        img = np.asarray(pil_images[i])
        plt.imshow(img)
The main input to the Stable Diffusion pipeline is the text describing what we want the model to render, a.k.a. the prompt.
prompt = "A road across trees with snow and moutains in the horizon in fresco style"
You can try your own prompts, for instance:
A city, morning sunrise, clouds, beautiful, summer, calm
Paris by night, studio ghibli, art by hayao miyazaki
Hyperrealist photo of a ford mustang
...
You can also try other painting styles like Expressionist, Oil, or Surrealism. See more here - link
The only limit is our imagination. Note that you may need many iterations on your prompt to end up with an image close to what you actually want.
num_samples = jax.device_count()  # one image per TPU core
Let's pass our prompt to the pipeline and examine the different images that the model generated.
%%time
images = generate1(prompt, num_samples, seed=0)
plots_pil_images(images, figsize=(16, 8), rows=2, cols=num_samples//2)
We can change the seed, which will result in the model returning completely different images for the same prompt.
%%time
images = generate1(prompt, num_samples, seed=13)
plots_pil_images(images, figsize=(16, 8), rows=2, cols=num_samples//2)
Another parameter that controls the look and quality of the resulting image is guidance_scale. To better understand what this value does, we need to understand how the model generates an image.
Stable Diffusion is a multi-step model. At each step, the model predicts some noise using the input prompt and combines it with noise generated from a blank input (i.e. an empty string) as follows:
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
In this case, noise_pred_uncond is the noise predicted from a blank input, and noise_pred_text is the noise predicted from the input prompt. The parameter guidance_scale controls how much of the difference (we can think of it as a distance) between noise_pred_uncond and noise_pred_text will be incorporated into the final image. The value of guidance_scale can be anything, but the usual 7.5 seems to provide good results.
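To make the effect concrete, here is a toy sketch of the formula with made-up numbers (hypothetical values, not real model outputs); notice how a larger guidance_scale pushes the final prediction further in the direction of the text-conditioned noise:
import jax.numpy as jnp

# Hypothetical noise predictions, just to illustrate the guidance formula.
noise_pred_uncond = jnp.array([0.2, -0.1, 0.4])
noise_pred_text = jnp.array([0.5, 0.3, -0.2])

for guidance_scale in (1.1, 7.5):
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(guidance_scale, noise_pred)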
Let's try 1.1 as a value for our guidance scale and examine the resulting images.
%%time
images = generate1(prompt, num_samples, seed=0, guidance_scale=1.1)
plots_pil_images(images, figsize=(16, 8), rows=2, cols=num_samples//2)
You may notice that the resulting images are not as close to our prompt as the ones generated earlier using the default value of 7.5.
Negative prompt
The Stable Diffusion pipeline accepts a parameter called neg_prompt_ids. This is basically the token IDs of a negative prompt.
In simple terms, a negative prompt instructs the Stable Diffusion model to not include certain things in the generated image. This allows us to remove objects, styles, or abnormalities from the generated image.
def generate2(prompt, neg_prompt, num_samples, seed=0, num_inference_steps=50, guidance_scale=7.5):
    prng_seed = jax.random.PRNGKey(seed)
    prompt_ids = prepare_inputs(prompt, num_samples)
    neg_prompt_ids = prepare_inputs(neg_prompt, num_samples)
    # shard inputs and rng
    params = replicate(pipeline_params)
    prng_seed = jax.random.split(prng_seed, num_samples)
    output = pipeline(prompt_ids, params, prng_seed, num_inference_steps, guidance_scale=guidance_scale, jit=True,
                      neg_prompt_ids=neg_prompt_ids)
    images = output.images
    images = np.asarray(images.reshape((num_samples,) + images.shape[-3:]))
    images = pipeline.numpy_to_pil(images)
    return images
Let's tell Stable Diffusion not to include the sun:
%%time
images = generate2(prompt, 'shiny sun', num_samples, seed=0)
plots_pil_images(images, figsize=(16, 8), rows=2, cols=num_samples//2)
Let's make sure Stable Diffusion does not generate images of a sunset:
%%time
images = generate2(prompt, 'sunset', num_samples, seed=0)
plots_pil_images(images, figsize=(16, 8), rows=2, cols=num_samples//2)
Let's try excluding snow storms, and notice how the model removed snow from the resulting images.
%%time
images = generate2(prompt, 'snow storm', num_samples, seed=0)
plots_pil_images(images, figsize=(16, 8), rows=2, cols=num_samples//2)
That's all folks
Stable Diffusion is a very neat model, and Hugging Face's diffusers library makes it very easy to play with such models. Furthermore, with the Flax implementation and the use of TPUs, we can generate images and play with the model in an almost interactive way.
I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.