In previous posts we saw how to easily generate images from text in a few lines of code using FlaxStableDiffusionPipeline (see link) and took a deep dive into the details of the diffusion loop (see link).

Now that we have a more in-depth understanding of Stable Diffusion, we can be more dangerous and start experimenting. In this article, we will start the diffusion loop from a noised version of an input image instead of from pure random noise (aka image2image).

image2image.png

Setup and Imports

As in the previous Stable Diffusion articles, let's install the packages, accept the license for using Stable Diffusion, and import the modules.

%%capture
%%bash

pip install --upgrade scipy flax
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade diffusers transformers     
!huggingface-cli login
The login prompt asks for an access token, which you can generate at https://huggingface.co/settings/tokens.
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxPNDMScheduler

Loading the model

Download the checkpoints for Stable Diffusion and instantiate the model components.

dtype = jax.numpy.bfloat16
model_id = "CompVis/stable-diffusion-v1-4"
revision = "bf16" # "flax"
tokenizer = CLIPTokenizer.from_pretrained(model_id, revision=revision, subfolder="tokenizer", dtype=dtype)
text_encoder = FlaxCLIPTextModel.from_pretrained(model_id, revision=revision, subfolder="text_encoder", dtype=dtype)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(model_id, revision=revision, subfolder="vae", dtype=dtype)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(model_id, revision=revision, subfolder="unet", dtype=dtype)
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(model_id, revision=revision, subfolder="scheduler")

Diffusion

For the diffusion we'll use a loop similar to the one in Stable Diffusion Deep Dive, except that we skip the first start_step steps. We take an arbitrarily chosen image as the starting point, add noise to it, and then run only the remaining denoising steps in the loop.
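Other img2img implementations expose the same choice as a "strength" in [0, 1] rather than a start step. The strength argument of diffusers' StableDiffusionImg2ImgPipeline is one example. The sketch below converts a strength into a start step. It assumes the convention that strength is the fraction of the schedule that actually gets run, which is roughly how diffusers defines it. start_step_from_strength is a hypothetical helper and is not part of this notebook or of diffusers.

```python
def start_step_from_strength(strength: float, num_inference_steps: int) -> int:
    """Hypothetical helper: map an img2img strength in [0, 1] to a start step.

    strength = 1.0 runs the whole schedule (the input is fully noised);
    strength = 0.0 runs no denoising steps at all.
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be in [0, 1]")
    # Number of denoising steps that will actually be executed.
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    # Everything before that point in the schedule is skipped.
    return num_inference_steps - steps_to_run


print(start_step_from_strength(0.5, 30))  # 15
```

Later in this article the loop starts at step 10 of 30. That runs 20 of the 30 steps, which corresponds to a strength of 2/3.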

Setup

First, let's set a random seed for reproducibility.

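seed = 123
# No pmap is used in this notebook, so a single PRNG key is enough. Splitting it
# into one key per device (jax.random.split(key, jax.device_count())) would give
# an array of keys, which jax.random.normal rejects in recent JAX versions.
prng_seed = jax.random.PRNGKey(seed)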

Set the guidance scale and the total number of inference steps.

guidance_scale = 7.5 #@param {type:"slider", min:0, max:100, step:0.5}
num_inference_steps = 30 #@param 

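Define some parameters for the diffusion loop.

# the VAE halves the spatial resolution in every block except the last one
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

# output image size in pixels: UNet latent size times the VAE scale factor
height = unet.config.sample_size * vae_scale_factor
width = unet.config.sample_size * vae_scale_factor
# None means "use the parameters loaded with the text encoder"
text_encoder_params = None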
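Here is the arithmetic with concrete numbers. The sketch assumes the configuration values of the CompVis/stable-diffusion-v1-4 checkpoint: a VAE with block_out_channels = (128, 256, 512, 512) and a UNet sample_size of 64. Under those values the VAE downsamples by a factor of 8 and the default output is 512x512 pixels. The values are hard-coded for illustration; in the notebook they come from vae.config and unet.config.

```python
# Assumed config values for CompVis/stable-diffusion-v1-4 (hard-coded for illustration).
block_out_channels = (128, 256, 512, 512)  # stands in for vae.config.block_out_channels
sample_size = 64                           # stands in for unet.config.sample_size

# Every VAE block except the last halves the spatial resolution.
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
height = width = sample_size * vae_scale_factor

print(vae_scale_factor, height, width)  # 8 512 512
```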

Image

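Get an image to use as the input.

# quote the URL so the shell does not treat each "&" as a command separator
!curl -s -o input_image.jpeg "https://images.unsplash.com/photo-1670139015746-832eaa4460c1?ixlib=rb-4.0.3&dl=peter-thomas-mcV0gUPvGXE-unsplash.jpg&q=80&fm=jpg&crop=entropy&cs=tinysrgb&w=224&q=224"
pil_image = Image.open('input_image.jpeg')
# preview at 512x512 (resize returns a new image; pil_image itself is unchanged)
pil_image.resize((512, 512))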

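We need a helper function to encode a PIL image into its latent representation.

def pil_to_latents(pil_image):
    # Single image -> single latent in a batch, returned channels-last: (1, 64, 64, 4)
    image = np.asarray(pil_image.convert("RGB"))  # drop any alpha channel
    image = jax.image.resize(image, (512, 512, 3), "bicubic")
    # map pixels from [0, 255] to [-1, 1], the range the VAE expects
    image = (image / 127.5 - 1.0).astype(np.float32)
    input_im = jnp.expand_dims(image, axis=0)         # add batch dimension: (1, 512, 512, 3)
    input_im = jnp.transpose(input_im, (0, 3, 1, 2))  # to channels-first: (1, 3, 512, 512)
    # encode the image, then sample from the latent distribution
    latents = vae.apply({"params": vae_params}, input_im, method=vae.encode)
    # 0.18215 is Stable Diffusion's latent scaling factor
    return 0.18215 * latents.latent_dist.sample(prng_seed)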

We need a helper function to decode a latent representation into a PIL image.

def latents_to_pil(latents):
    # scale and decode the image latents with vae
    latents = 1 / 0.18215 * latents
    images = vae.apply({"params": vae_params}, latents, method=vae.decode).sample
    # convert JAX to numpy
    images = (images / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
    images = np.asarray(images)
    # convert numpy array to PIL
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
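Apart from the lossy VAE itself, the two helpers are inverses. pil_to_latents maps pixels from [0, 255] to [-1, 1] and multiplies the latents by 0.18215, and latents_to_pil undoes both steps. The NumPy sketch below checks only the pixel-range part of that round trip, with no VAE involved. to_model_range and to_pixels are hypothetical names.

```python
import numpy as np

def to_model_range(pixels: np.ndarray) -> np.ndarray:
    # uint8 [0, 255] -> float32 [-1, 1], as in pil_to_latents
    return (pixels / 127.5 - 1.0).astype(np.float32)

def to_pixels(images: np.ndarray) -> np.ndarray:
    # float [-1, 1] -> uint8 [0, 255], as in latents_to_pil
    images = np.clip(images / 2 + 0.5, 0, 1)
    return (images * 255).round().astype("uint8")

rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)
restored = to_pixels(to_model_range(pixels))
print(np.array_equal(pixels, restored))  # True
```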

Encode the image and check the shape of its latent representation.

%%time

encoded = pil_to_latents(pil_image)
CPU times: user 33.1 s, sys: 2.52 s, total: 35.6 s
Wall time: 24.6 s
encoded.shape
(1, 64, 64, 4)
encoded = jnp.transpose(encoded, (0, 3, 1, 2))
encoded.shape
(1, 4, 64, 64)
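The Flax VAE returns channels-last latents (NHWC). The rest of this notebook passes latents channels-first (NCHW), as in latents_shape. The two transposes used in this notebook are inverse permutations, which a quick NumPy check confirms. The random array is a stand-in for encoded.

```python
import numpy as np

rng = np.random.default_rng(0)
nhwc = rng.standard_normal((1, 64, 64, 4)).astype(np.float32)  # stand-in for the VAE output
nchw = np.transpose(nhwc, (0, 3, 1, 2))  # channels first, like latents_shape
back = np.transpose(nchw, (0, 2, 3, 1))  # inverse permutation

print(nchw.shape, back.shape)  # (1, 4, 64, 64) (1, 64, 64, 4)
```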

We cannot use the encoded image as is. First we need to add noise to it with the scheduler, at the level that corresponds to the target start step. As an example, let's visualize what the image looks like after adding a bit of noise at step 13 of a 15-step schedule.

Note: try a different step number to get a sense of how much the image deteriorates. The scheduler's timesteps are in descending order, so an earlier step adds more noise and a later step adds less.

scheduler_state = scheduler.set_timesteps(scheduler_params, num_inference_steps=15, shape=encoded.shape)
noise = jax.random.normal(prng_seed, shape=encoded.shape, dtype=jnp.float32) # Random noise
sampling_step = 13 # Equivalent to step 13 out of 15 in the schedule above
encoded_and_noised = scheduler.add_noise(encoded, noise, timesteps=scheduler_state.timesteps[sampling_step])
latents_to_pil(encoded_and_noised)[0] # Display
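For PNDM and the other DDPM-style schedulers in diffusers, add_noise implements the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. Here alpha_bar_t is the cumulative product of (1 - beta) up to timestep t. The NumPy sketch below reproduces that formula. It assumes the "scaled_linear" beta schedule from the Stable Diffusion v1 scheduler config, with beta_start=0.00085, beta_end=0.012 and 1000 training timesteps. make_alphas_cumprod and add_noise are hypothetical stand-ins, not the diffusers code. The printed weights show why a later index in the descending inference schedule means less noise: the weight of the original latents shrinks as the timestep t grows.

```python
import numpy as np

def make_alphas_cumprod(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    # "scaled_linear" schedule: linear in sqrt(beta), then squared
    betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps) ** 2
    # alpha_bar_t = product over s <= t of (1 - beta_s)
    return np.cumprod(1.0 - betas)

def add_noise(x0, noise, t, alphas_cumprod):
    # Closed-form forward diffusion: jump straight from x_0 to timestep t
    alpha_bar = alphas_cumprod[t]
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * noise

alphas_cumprod = make_alphas_cumprod()
signal_weight = np.sqrt(alphas_cumprod)  # weight of x_0 in the noised sample

rng = np.random.default_rng(0)
x0 = rng.standard_normal((1, 4, 64, 64))
noise = rng.standard_normal((1, 4, 64, 64))
noised = add_noise(x0, noise, 500, alphas_cumprod)

for t in (1, 250, 500, 999):
    print(t, round(float(signal_weight[t]), 3))
```

The two coefficients' squares sum to 1, so unit-variance latents stay roughly unit-variance after noising.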

Prompt

Choose a prompt.

prompt = "a photo of abandoned cars in the desert"

Let's tokenize the prompt, then encode the tokens into embeddings.

%%time 

# prepare_inputs
text_input = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
prompt_ids = text_input.input_ids
prompt_ids
CPU times: user 1.4 ms, sys: 0 ns, total: 1.4 ms
Wall time: 3.66 ms
array([[49406,   320,  1125,   539, 11227,  3346,   530,   518,  7301,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407]])
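The prompt fills only the first few of the 77 positions (tokenizer.model_max_length). In the output above, 49406 is the start-of-text token. 49407 is the end-of-text token, which this tokenizer also uses for padding. The sketch below rebuilds the IDs printed above and counts how many positions hold actual prompt tokens.

```python
import numpy as np

BOS, EOS = 49406, 49407  # start- and end-of-text IDs seen in the output above
prompt_ids = np.array([[BOS, 320, 1125, 539, 11227, 3346, 530, 518, 7301] + [EOS] * 68])

# Index of the first end-of-text token = 1 (BOS) + number of prompt tokens
first_eos = int(np.argmax(prompt_ids[0] == EOS))
num_prompt_tokens = first_eos - 1

print(prompt_ids.shape, num_prompt_tokens)  # (1, 77) 8
```

The eight tokens match the eight words of "a photo of abandoned cars in the desert".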
%%time

# get prompt text embeddings
text_embeddings = text_encoder(prompt_ids, params=text_encoder_params)[0]
CPU times: user 2.1 s, sys: 97.1 ms, total: 2.2 s
Wall time: 4.21 s
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]

We also need the embeddings of an empty prompt, which provide the unconditional prediction for classifier-free guidance.

uncond_input = tokenizer(
  [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids

uncond_embeddings = text_encoder(uncond_input, params=text_encoder_params)[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (
            batch_size,
            unet.in_channels,
            height // vae_scale_factor,
            width // vae_scale_factor,
        )

latents_shape
(1, 4, 64, 64)
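The order of the concatenation matters. context puts the unconditional embeddings first. The diffusion step duplicates the latents and splits the UNet output with jnp.split(noise_pred, 2, axis=0), so the first half must be the unconditional prediction. The NumPy sketch below shows that bookkeeping with dummy arrays. The hidden size of 768 is that of the CLIP text encoder used by Stable Diffusion v1, and the "fake UNet" is a hypothetical stand-in that just echoes its conditioning.

```python
import numpy as np

batch_size, seq_len, hidden = 1, 77, 768
uncond_embeddings = np.zeros((batch_size, seq_len, hidden))
text_embeddings = np.ones((batch_size, seq_len, hidden))

# Same ordering as `context` above: unconditional first, then text
context = np.concatenate([uncond_embeddings, text_embeddings])
latents = np.zeros((batch_size, 4, 64, 64))
latents_input = np.concatenate([latents] * 2)  # one copy per half of the batch

# Fake UNet output that echoes the conditioning, so each half is identifiable
fake_noise_pred = context.mean(axis=(1, 2))  # shape (2,)
pred_uncond, pred_text = np.split(fake_noise_pred, 2, axis=0)

print(context.shape, latents_input.shape)  # (2, 77, 768) (2, 4, 64, 64)
```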

Diffusion

Set the timesteps of the scheduler based on the number of inference steps.

scheduler_state = scheduler.set_timesteps(
    scheduler_params, num_inference_steps=num_inference_steps, shape=latents_shape
    )

As illustrated earlier, we cannot use the input image as is. We need to prepare its latents by adding the amount of noise that matches the start step.

start_step = 10
noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
# add_noise expects unit-variance noise. PNDM's init_noise_sigma is 1.0, so this
# scaling is a no-op here; it only matters when starting from pure noise.
noise = noise * scheduler.init_noise_sigma
# apply noise to the latents of the input image, at the level of the start step
latents = scheduler.add_noise(encoded, noise, timesteps=scheduler_state.timesteps[start_step])

The diffusion step is unchanged from the original Stable Diffusion loop. For details, see Stable Diffusion Deep Dive.

def diffusion_step(step, args):
    latents, scheduler_state = args
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    latents_input = jnp.concatenate([latents] * 2)
    
    t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
    timestep = jnp.broadcast_to(t, latents_input.shape[0])
    
    latents_input = scheduler.scale_model_input(scheduler_state, latents_input, t)
    
    # predict the noise residual
    noise_pred = unet.apply(
        {"params": unet_params},
        jnp.array(latents_input),
        jnp.array(timestep, dtype=jnp.int32),
        encoder_hidden_states=context,
        ).sample
    # perform guidance
    noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    
    # compute the previous noisy sample x_t -> x_t-1
    latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
    
    return latents, scheduler_state
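The guidance part of diffusion_step is one line of arithmetic. It starts from the unconditional prediction and moves guidance_scale times along the direction towards the text-conditioned prediction. A scale of 0 gives the unconditional prediction and a scale of 1 gives the text prediction. Values above 1, such as the 7.5 used here, extrapolate past the text prediction. The NumPy check below confirms these properties; apply_guidance is a hypothetical name.

```python
import numpy as np

def apply_guidance(noise_pred_uncond, noise_pred_text, guidance_scale):
    # Same formula as in diffusion_step above
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

uncond = np.array([0.0, 1.0])
text = np.array([1.0, 1.5])

guided = apply_guidance(uncond, text, 7.5)
print(guided)
```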

The only difference is that instead of starting the diffusion loop at step 0, we start at step 10 (or the step of your choice).
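jax.lax.fori_loop(lower, upper, body_fun, init_val) has the same semantics as the Python loop below, except that JAX compiles it instead of running it eagerly. Starting at start_step = 10 with num_inference_steps = 30 therefore runs the diffusion step for indices 10 through 29, which is 20 of the 30 steps. fori_loop_reference is a hypothetical pure-Python stand-in used only to show those semantics.

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    # Pure-Python semantics of jax.lax.fori_loop, as described in the JAX docs
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

def record_step(step, visited):
    # Toy body: record which step indices are visited instead of denoising
    return visited + [step]

visited = fori_loop_reference(10, 30, record_step, [])
print(len(visited), visited[0], visited[-1])  # 20 10 29
```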

%%time

latents, _ = jax.lax.fori_loop(start_step, num_inference_steps, diffusion_step, (latents, scheduler_state))
CPU times: user 17min 29s, sys: 11 s, total: 17min 40s
Wall time: 9min 36s

Now let's decode the latents and inspect the model output image.

%%time

images = latents_to_pil(latents)
CPU times: user 1min, sys: 1.82 s, total: 1min 2s
Wall time: 33.9 s

See how the model generates an image that matches the text prompt while staying close to the input image.

images[0]

That's all folks

Stable Diffusion is a very cool model and can easily be customized to control how images are generated. In this article we saw how to use a source image to steer the model towards an output that resembles it while still matching the prompt.

I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.