In previous posts we saw how to easily generate images from text in a few lines of code using FlaxStableDiffusionPipeline
(see link) and dived deep into the details of the diffusion loop (see link).
Now that we have a more in-depth understanding of Stable Diffusion, we can be a little more adventurous and start experimenting. In this article, we will start the diffusion loop from a noised version of an input image instead of starting from random noise (aka image2image).
Setup and Imports
As in the previous Stable Diffusion articles, let's install the packages, accept the license for using Stable Diffusion, and import the modules.
%%capture
%%bash
pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers
!huggingface-cli login
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxPNDMScheduler
dtype = jax.numpy.bfloat16
model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16" # "flax"
tokenizer = CLIPTokenizer.from_pretrained(model_id, revision=revision, subfolder="tokenizer", dtype=dtype)
text_encoder = FlaxCLIPTextModel.from_pretrained(model_id, revision=revision, subfolder="text_encoder", dtype=dtype)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(model_id, revision=revision, subfolder="vae", dtype=dtype)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(model_id, revision=revision, subfolder="unet", dtype=dtype)
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(model_id, revision=revision, subfolder="scheduler")
Diffusion
For the diffusion we'll use a similar loop as in Stable Diffusion Deep Dive, except we will skip the first start_step
steps. We will use a random choosen image as a starting point, add some noise to it and then do the remaining few denoising steps in the loop.
seed = 123
# a single PRNG key is enough here: we work with one image, don't pmap across
# devices, and jax.random.normal / latent_dist.sample expect a single key
prng_seed = jax.random.PRNGKey(seed)
Set the guidance factor and total number of inference steps
guidance_scale = 7.5 #@param {type:"slider", min:0, max:100, step:0.5}
num_inference_steps = 30 #@param
Define some parameters for the diffusion loop
# mirrors what FlaxStableDiffusionPipeline computes in its __init__
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
# mirrors the default height/width computed in the pipeline's __call__
height = unet.config.sample_size * vae_scale_factor
width = unet.config.sample_size * vae_scale_factor
text_encoder_params = None  # None -> the Flax text encoder falls back to its own loaded params
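For stable-diffusion-v1-4 the VAE has 4 blocks, so the scale factor should be 8, and with a UNet sample size of 64 this gives 512x512 images. A quick sanity check (the printed values are what I expect for this checkpoint):
# sanity check: for stable-diffusion-v1-4 this should print 8 512 512
print(vae_scale_factor, height, width)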
!curl -s -o input_image.jpeg "https://images.unsplash.com/photo-1670139015746-832eaa4460c1?ixlib=rb-4.0.3&dl=peter-thomas-mcV0gUPvGXE-unsplash.jpg&q=80&fm=jpg&crop=entropy&cs=tinysrgb&w=224&q=224"
pil_image = Image.open('input_image.jpeg')
pil_image.resize((512, 512))  # 512x512 preview only; resize() returns a copy, pil_image itself is unchanged
We need a helper function to encode a PIL image into its latent representation
def pil_to_latents(pil_image):
    # Single image -> single latent in a batch; the Flax VAE returns latents
    # in NHWC layout, so the result has shape (1, 64, 64, 4)
    image = np.asarray(pil_image).astype(np.float32)  # cast before bicubic resizing
    image = jax.image.resize(image, (512, 512, 3), "bicubic")
    image = (image / 127.5 - 1.0).astype(np.float32)
    input_im = jnp.expand_dims(image, axis=0)
    input_im = jnp.transpose(input_im, (0, 3, 1, 2))
    # encode the image and sample from the resulting latent distribution
    latents = vae.apply({"params": vae_params}, input_im, method=vae.encode)
    return 0.18215 * latents.latent_dist.sample(prng_seed)
We need a helper function to decode a latent representation into a PIL image
def latents_to_pil(latents):
    # scale and decode the image latents with vae
    latents = 1 / 0.18215 * latents
    images = vae.apply({"params": vae_params}, latents, method=vae.decode).sample
    # convert JAX to numpy
    images = (images / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
    images = np.asarray(images)
    # convert numpy array to PIL
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
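Before moving on, it is worth sanity checking that the two helpers round-trip correctly: encoding the input image and immediately decoding it should give back a close reconstruction. Remember that pil_to_latents returns NHWC latents while latents_to_pil expects NCHW, hence the transpose (a quick check, using only the helpers defined above):
# round-trip sanity check: encode then decode the input image
roundtrip = pil_to_latents(pil_image)                # (1, 64, 64, 4), NHWC
roundtrip = jnp.transpose(roundtrip, (0, 3, 1, 2))   # (1, 4, 64, 64), NCHW for the decoder
latents_to_pil(roundtrip)[0]                         # should look like a slightly blurry copy of the input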
Encode the image and check the shape of its latent representation
%%time
encoded = pil_to_latents(pil_image)
encoded.shape  # (1, 64, 64, 4): NHWC, as returned by the Flax VAE
encoded = jnp.transpose(encoded, (0, 3, 1, 2))
encoded.shape  # (1, 4, 64, 64): NCHW, the layout the UNet and scheduler expect
We cannot use the encoded image as is; we need to add noise to it with the scheduler, up to the level that corresponds to the target start step. As an example, let's visualize what it looks like to add a bit of noise to the image at step 13.
scheduler_state = scheduler.set_timesteps(scheduler_params, num_inference_steps=15, shape=encoded.shape)
noise = jax.random.normal(prng_seed, shape=encoded.shape, dtype=jnp.float32) # Random noise
sampling_step = 13 # Equivalent to step 13 out of 15 in the schedule above
encoded_and_noised = scheduler.add_noise(encoded, noise, timesteps=scheduler_state.timesteps[sampling_step])
latents_to_pil(encoded_and_noised)[0] # Display
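To get a feel for how the noise level grows with the sampling step, we can repeat the same visualization for a few different steps (a small sketch using the helpers and the 15-step schedule defined above):
# visualize increasing noise levels at a few sampling steps
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, step in zip(axes, [3, 7, 11, 14]):
    noised = scheduler.add_noise(encoded, noise, timesteps=scheduler_state.timesteps[step])
    ax.imshow(latents_to_pil(noised)[0])
    ax.set_title(f"step {step}")
    ax.axis("off")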
Choose a prompt
prompt = "a photo of abandoned cars in the desert"
Let's tokenize the prompt, then encode the tokens into embeddings.
%%time
# prepare_inputs
text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="np",
)
prompt_ids = text_input.input_ids
prompt_ids
%%time
# get prompt text embeddings
text_embeddings = text_encoder(prompt_ids, params=text_encoder_params)[0]
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
We also need the embeddings of the empty (unconditional) prompt for classifier-free guidance.
uncond_input = tokenizer(
    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids
uncond_embeddings = text_encoder(uncond_input, params=text_encoder_params)[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
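The unconditional and text embeddings are stacked along the batch dimension so that a single UNet forward pass can produce both predictions. With one prompt and CLIP's 77-token context, the shape should be (2, 77, 768):
context.shape  # expected (2, 77, 768): [unconditional, prompt] stacked on the batch axis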
latents_shape = (
    batch_size,
    unet.in_channels,
    height // vae_scale_factor,
    width // vae_scale_factor,
)
latents_shape
Set the timesteps of the scheduler based on the number of inference steps
scheduler_state = scheduler.set_timesteps(
    scheduler_params, num_inference_steps=num_inference_steps, shape=latents_shape
)
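It can be instructive to look at the actual timestep values the scheduler will iterate over; starting the loop later simply means skipping the first (noisiest) entries of this array:
# the raw timestep schedule; starting later in the loop skips the first entries
scheduler_state.timesteps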
As illustrated earlier, we cannot use the input image as is; we need to prepare its latents by adding the amount of noise that matches the start step.
start_step = 10
noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
# scale the initial noise by the standard deviation required by the scheduler
noise = noise * scheduler.init_noise_sigma
# apply noise to the latents of the input image
latents = scheduler.add_noise(encoded, noise, timesteps=scheduler_state.timesteps[start_step])
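Note that start_step plays roughly the same role as the strength parameter of the diffusers image-to-image pipelines: the later we start, the less noise is added and the closer the result stays to the input. As a rough rule of thumb, strength corresponds to the fraction of steps we actually run:
# rough equivalent of the diffusers img2img "strength" parameter
strength = 1 - start_step / num_inference_steps
strength  # ~0.67 here: we run the last two thirds of the schedule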
The diffusion step is unchanged from the original Stable Diffusion loop. For details, check Stable Diffusion Deep Dive.
def diffusion_step(step, args):
    latents, scheduler_state = args
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    latents_input = jnp.concatenate([latents] * 2)
    t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
    timestep = jnp.broadcast_to(t, latents_input.shape[0])
    latents_input = scheduler.scale_model_input(scheduler_state, latents_input, t)
    # predict the noise residual
    noise_pred = unet.apply(
        {"params": unet_params},
        jnp.array(latents_input),
        jnp.array(timestep, dtype=jnp.int32),
        encoder_hidden_states=context,
    ).sample
    # perform guidance
    noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    # compute the previous noisy sample x_t -> x_t-1
    latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
    return latents, scheduler_state
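Before running the full loop, it can be useful to execute a single step eagerly to confirm that the shapes line up; the outputs are discarded, so this does not affect the actual run below:
# optional sanity check: one eager step (outputs discarded)
test_latents, _ = diffusion_step(start_step, (latents, scheduler_state))
test_latents.shape  # should match latents_shape, i.e. (1, 4, 64, 64)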
The only difference is that instead of starting the diffusion loop from step 0, we start it from step 10 (start_step, or the step of your choice).
%%time
latents, _ = jax.lax.fori_loop(start_step, num_inference_steps, diffusion_step, (latents, scheduler_state))
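jax.lax.fori_loop keeps everything traceable by XLA, but while experimenting it is sometimes handier to use a plain Python loop, which is easier to debug (for example to inspect intermediate latents) at the cost of speed. The equivalent, left commented out:
# equivalent plain Python loop, handy for debugging intermediate steps:
# for step in range(start_step, num_inference_steps):
#     latents, scheduler_state = diffusion_step(step, (latents, scheduler_state))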
Now let's decode the latents and inspect the model output image.
%%time
images = latents_to_pil(latents)
See how the model generates an image that corresponds to the text prompt and stays very close to the input image.
images[0]
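If you want to keep the result, the PIL image can simply be saved to disk (the file name below is just an example):
images[0].save("image2image_output.png")  # arbitrary file name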
That's all folks
Stable Diffusion is a very cool model and can easily be customized to condition how images are generated. In this article we saw how to use a source image and force the model to generate something that looks like it while still matching the prompt.
I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.