In this article, we will implement a simplified diffusion model by reducing each of its main components to a minimal version. In particular, we will implement a minimal version of:

  • The process of generating noisy data
  • A UNet model
  • The training of a Diffusion model
  • The sampling of images during inference

Setup

Let's import the necessary modules (note: you may need to install them).

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, MaxPool2D, UpSampling2D
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt

Set the seed for reproducibility

seed = 123
tf.random.set_seed(seed)
np.random.seed(seed)

Data

We will be building a small model that's easy to understand, so we will also train it on a small dataset. MNIST is a very basic computer vision dataset and a perfect candidate for this task. It consists of 28x28 greyscale images of the handwritten digits 0 to 9, with pixel values ranging from 0 to 255.

This dataset is hosted in TensorFlow Datasets, so let's download it:

ds = tfds.load('mnist')

We can merge the train and test sets to get more images for training; this is optional, though.

merge_ds = ds['train'].concatenate(ds['test'])

The following helper functions will come in handy when we plot the original data vs. noisy or predicted images.

def make_grid(xs, rows=1, cols=8):
    # Convert a (N, H, W, 1) batch with values in [0, 1] into 8-bit greyscale PIL images
    xs = np.clip(xs.numpy().squeeze(-1), 0.0, 1.0)
    images = [Image.fromarray((x * 255).astype(np.uint8)) for x in xs]
    return image_grid(images, rows, cols, 'L')

def image_grid(imgs, rows, cols, mode='RGB'):
    # Paste the images left to right, top to bottom into a single rows x cols grid
    assert len(imgs) == rows*cols

    w, h = imgs[0].size
    grid = Image.new(mode, size=(cols*w, rows*h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
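`image_grid` pastes image `i` at column `i % cols` and row `i // cols`, so the grid fills left to right, then top to bottom. Here is a pure-Python sketch of the same offset arithmetic. `grid_offsets` is a hypothetical helper and not part of the article's code.

```python
def grid_offsets(n_images, cols, w, h):
    """Return the (left, upper) paste position image_grid uses for each image."""
    return [(i % cols * w, i // cols * h) for i in range(n_images)]

# Two rows of 8 MNIST-sized (28x28) images
offsets = grid_offsets(16, cols=8, w=28, h=28)
print(offsets[:2], offsets[7:9])  # [(0, 0), (28, 0)] [(196, 0), (0, 28)]
```

Image 8 wraps to the start of the second row, 28 pixels down.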

Let's have a look at a few samples from the dataset:

batch = merge_ds.map(lambda x: x['image']).take(8)
xs = tf.convert_to_tensor(list(batch), dtype=tf.float32) / 255.  # Scale pixels to [0, 1]
plt.imshow(make_grid(xs), cmap='Greys');

The training dataset is built by scaling pixel values to [0, 1], binarizing them with a 0.5 threshold, and then shuffling and batching:

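def preprocess_fn(entry):
    # Scale pixels to [0, 1] (Python's float() does not work on graph tensors inside map)
    image = tf.cast(entry['image'], tf.float32) / 255.
    # Binarize: pixels above 0.5 become 1.0, the rest 0.0
    image = tf.where(image > .5, 1.0, 0.0)
    return image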

bs = 256

train_ds = merge_ds.map(preprocess_fn).shuffle(1024).batch(bs).prefetch(tf.data.AUTOTUNE)
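Here is a NumPy-only sketch of what `preprocess_fn` does to a few example pixel values, along with the resulting number of batches per epoch. It assumes the merged dataset (MNIST's 60,000 training and 10,000 test images) and `bs = 256`. Note that `batch()` keeps the final partial batch by default (`drop_remainder=False`).

```python
import math
import numpy as np

# A few raw 8-bit pixel values
pixels = np.array([0, 100, 127, 128, 200, 255], dtype=np.uint8)

# Same steps as preprocess_fn: scale to [0, 1], then binarize at 0.5
scaled = pixels.astype(np.float32) / 255.0
binary = np.where(scaled > 0.5, 1.0, 0.0).astype(np.float32)
print(binary)  # [0. 0. 0. 1. 1. 1.]

# Batches per epoch for the merged dataset with bs = 256
n_images, bs = 60_000 + 10_000, 256
n_batches = math.ceil(n_images / bs)
last_batch = n_images - (n_batches - 1) * bs
print(n_batches, last_batch)  # 274 112
```

So each epoch runs 274 optimizer steps, and the last one sees only 112 images.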

Noise

One of the main components of a diffusion model is the process that adds noise to images, along with the reverse process, which removes noise from an image. The component that controls this is called a noise scheduler, because it schedules how much noise to add (or remove) over many steps. Implementing a proper noise scheduler is out of scope because it is more complex. Instead, we will implement a simple function that mixes an image with uniform random noise, like this:

noise = tf.random.uniform(tf.shape(image))
noisy_x = (1 - amount) * image + amount * noise

The basic idea is that:

  • If we use an amount close to 0, we get back the input almost unchanged.
  • If instead the amount is close to 1, the output is almost pure noise, with little information retained from the input.
  • If we use an amount between 0 and 1, the input is mixed with noise. Because both the image and the noise lie in [0, 1], the output stays in the same range (0 to 1).

This is easily implemented by the following function that corrupts the input by mixing it with noise:

def corrupt(xs, amount):
    noise = tf.random.uniform(xs.shape)
    amount = tf.reshape(amount, (-1, 1, 1, 1)) # One amount per image, shaped to broadcast over height, width and channels
    return xs*(1-amount) + noise * amount
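Here is a NumPy mirror of `corrupt` that checks the broadcasting and the two extremes without TensorFlow. The name `corrupt_np` and the explicit `rng` argument are additions for this sketch. They make the noise reproducible and let the function return it for comparison.

```python
import numpy as np

def corrupt_np(xs, amount, rng):
    """NumPy mirror of corrupt(): blend each image with uniform noise."""
    noise = rng.uniform(size=xs.shape).astype(np.float32)
    # One amount per image, shaped (batch, 1, 1, 1) to broadcast over H, W, C
    amount = np.reshape(amount, (-1, 1, 1, 1)).astype(np.float32)
    return xs * (1 - amount) + noise * amount, noise

xs = np.ones((3, 28, 28, 1), dtype=np.float32)  # three all-white "images"
amount = np.array([0.0, 0.5, 1.0], dtype=np.float32)
out, noise = corrupt_np(xs, amount, np.random.default_rng(0))

print(np.allclose(out[0], xs[0]))          # True: amount 0 keeps the input
print(np.allclose(out[2], noise[2]))       # True: amount 1 is pure noise
print(out.min() >= 0.0, out.max() <= 1.0)  # True True: output stays in [0, 1]
```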

Let's visualize the output of this corruption process to better understand what the above logic is actually doing:

# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(make_grid(xs), cmap='Greys')

# Adding noise
amount = tf.linspace(0.0, 1.0, xs.shape[0]) # Left to right -> more corruption
noised_xs = corrupt(xs, amount)

# Plotting the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(make_grid(noised_xs), cmap='Greys');

We can confirm that an amount close to 0 keeps the digit as is, while an amount close to 1 results in a very noisy output. Any amount in between mixes the image with some noise.

Model

Our goal is to denoise an image, i.e. to generate a version of the input image with some of its noise removed. We can use the UNet architecture for this purpose. UNet has the following components:

  • A 'contracting path' through which data is compressed
  • An 'expanding path' through which data is expanded to the original dimension
  • A set of skip connections at different levels to pass information and gradients.

The following class implements a very basic UNet model that takes in a one-channel image and passes it through:

  • Three convolutional layers on the down path
  • A max pooling layer for downsampling
  • An UpSampling2D layer for upsampling
  • Three convolutional layers on the up path
  • Skip connections between the down and up layers.

class BasicUNet(Model):
    def __init__(self, in_channels=1, out_channels=1):
        super(BasicUNet, self).__init__(name='basic-unet')
        self.down_layers = [
            Conv2D(32, kernel_size=5, padding="same"),
            Conv2D(64, kernel_size=5, padding="same"),
            Conv2D(64, kernel_size=5, padding="same"),
        ]
        self.up_layers = [
            Conv2D(64, kernel_size=5, padding="same"),
            Conv2D(32, kernel_size=5, padding="same"),
            Conv2D(out_channels, kernel_size=5, padding="same"), 
        ]
        self.act = tf.keras.activations.swish # also known as SiLU
        self.downscale = MaxPool2D(2)
        self.upscale = UpSampling2D(size=2)

    def call(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))
            # Store x for skip connection and downscale for all but the third (final) down layer
            if i < 2:
                h.append(x)
                x = self.downscale(x)

        for i, l in enumerate(self.up_layers):
            # Upscale and add the stored skip connection for all except the first up layer
            if i > 0:
                x = self.upscale(x)
                x += h.pop()
            x = self.act(l(x))
            
        return x
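To see why the skip connections line up, this sketch traces tensor shapes through `BasicUNet.call` by hand, without TensorFlow. It assumes square inputs, 'same' padding (each convolution keeps the spatial size), 2x2 max pooling that floors odd sizes, and 2x upsampling. `trace_unet_shapes` is a hypothetical helper.

```python
def trace_unet_shapes(size, down=(32, 64, 64), up=(64, 32, 1)):
    """Trace (H, W, C) output shapes through BasicUNet.call for a square input."""
    trace, skips = [], []
    side = size
    for i, ch in enumerate(down):
        shape = (side, side, ch)        # 'same' padding keeps the spatial size
        trace.append(("down", shape))
        if i < len(down) - 1:           # all but the final down layer
            skips.append(shape)         # stored for the skip connection
            side //= 2                  # MaxPool2D(2) halves the size, flooring odd sizes
    ch_in = down[-1]
    for i, ch in enumerate(up):
        if i > 0:                       # all but the first up layer
            side *= 2                   # UpSampling2D(2) doubles the size
            skip = skips.pop()
            if skip != (side, side, ch_in):  # x += h.pop() needs identical shapes
                raise ValueError(f"skip {skip} does not match {(side, side, ch_in)}")
        ch_in = ch
        trace.append(("up", (side, side, ch)))
    return trace

for stage, shape in trace_unet_shapes(28):
    print(stage, shape)
```

For a 28x28 input, the shapes shrink to 7x7 and grow back to (28, 28, 1). An input whose side is not divisible by 4, such as 30, raises the `ValueError`: 15 pools down to 7 and upsamples back to only 14, so the skip connection no longer matches.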

Let's verify that the output shape of the model is the same as its input shape

net = BasicUNet()
x = tf.random.uniform((8, 28, 28, 1))
net(x).shape
TensorShape([8, 28, 28, 1])

Let's examine the number of trainable parameters in this network

net.summary()
Model: "basic-unet"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_24 (Conv2D)          multiple                  832       
                                                                 
 conv2d_25 (Conv2D)          multiple                  51264     
                                                                 
 conv2d_26 (Conv2D)          multiple                  102464    
                                                                 
 conv2d_27 (Conv2D)          multiple                  102464    
                                                                 
 conv2d_28 (Conv2D)          multiple                  51232     
                                                                 
 conv2d_29 (Conv2D)          multiple                  801       
                                                                 
 max_pooling2d_4 (MaxPooling  multiple                 0         
 2D)                                                             
                                                                 
 up_sampling2d_4 (UpSampling  multiple                 0         
 2D)                                                             
                                                                 
=================================================================
Total params: 309,057
Trainable params: 309,057
Non-trainable params: 0
_________________________________________________________________
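The counts in the summary follow from the usual Conv2D formula: kernel_height × kernel_width × input_channels × filters weights, plus one bias per filter. Here is a quick pure-Python check. It assumes the layer sizes defined in `BasicUNet`: 5x5 kernels, and up-path inputs of 64, 64 and 32 channels, because the skip connections are added rather than concatenated.

```python
def conv2d_params(kernel, in_ch, out_ch):
    """Weights (k * k * in * out) plus one bias per output channel."""
    return kernel * kernel * in_ch * out_ch + out_ch

# (in_channels, out_channels) for each Conv2D, in the order of the summary
layers = [(1, 32), (32, 64), (64, 64),   # down path
          (64, 64), (64, 32), (32, 1)]   # up path
counts = [conv2d_params(5, i, o) for i, o in layers]
print(counts)       # [832, 51264, 102464, 102464, 51232, 801]
print(sum(counts))  # 309057
```

Adding skips keeps the channel count unchanged. Concatenating them instead would double the input channels of the second and third up layers and raise their parameter counts.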

Training

To train this model to denoise images, we run it for a few epochs, and at each step we:

  • Fetch a batch of images
  • Apply random noise to each image in this batch
  • Pass the noisy images through the model
  • Compare the predicted denoised images to the original ones
  • Calculate the loss using MeanSquaredError
  • Update the model's parameters based on this loss.

This is implemented as follows:

epochs = 5

# Define a loss function
loss_fn = tf.keras.losses.MeanSquaredError()

# Define an optimizer
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Record the per-step losses and the per-epoch averages
losses, avg_losses = [], []

# Iterate over epochs.
for epoch in tqdm(range(epochs)):

    # Iterate over the batches of the dataset.
    for step, xb in enumerate(train_ds):
        with tf.GradientTape() as tape:
            # Create a noisy version of the input, with one random amount per image
            noise_amount = tf.random.uniform((xb.shape[0],))
            noisy_xb = corrupt(xb, noise_amount)

            # Get the model prediction
            pred = net(noisy_xb)

            # Calculate how close the output is to the clean input (Keras losses take y_true, y_pred)
            loss = loss_fn(xb, pred)

        grads = tape.gradient(loss, net.trainable_weights)
        opt.apply_gradients(zip(grads, net.trainable_weights))

        # Store the loss
        losses.append(loss.numpy())

    # Average the loss over the step + 1 batches of this epoch (not over len(xb), the last batch's size)
    avg_loss = sum(losses[-(step + 1):]) / (step + 1)
    avg_losses.append(avg_loss)
    print(f'Epoch {epoch}: average loss {avg_loss:.5f}')

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);
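The loss curve above can be compared against a simple reference point: a hypothetical "model" that returns its noisy input unchanged. Since the training target is the clean image, this model's error is amount × (noise − x). For binarized pixels and uniform noise, the expected squared error is amount²/3 whatever the pixel value, which averages to 1/9 ≈ 0.11 over uniformly sampled amounts. The NumPy sketch below checks this, using random binary pixels as a stand-in for MNIST to keep it self-contained. `identity_baseline_mse` is a hypothetical helper.

```python
import numpy as np

def identity_baseline_mse(xb, amount, rng):
    """MSE of a hypothetical 'model' that returns the corrupted batch unchanged,
    measured against the clean batch (the training target)."""
    noise = rng.uniform(size=xb.shape)
    a = np.reshape(amount, (-1, 1, 1, 1))   # one amount per image, as in corrupt()
    noisy = xb * (1 - a) + noise * a
    return float(np.mean((noisy - xb) ** 2))

rng = np.random.default_rng(0)
# Random 0/1 pixels stand in for a batch of binarized MNIST images
xb = (rng.uniform(size=(64, 28, 28, 1)) > 0.5).astype(np.float32)

for a in [0.0, 0.5, 1.0]:
    amount = np.full(xb.shape[0], a)
    mse = identity_baseline_mse(xb, amount, rng)
    print(f"amount={a}: measured {mse:.4f}, expected {a**2 / 3:.4f}")
```

Comparing the training curve with this rough baseline gives a sense of how much the network has learned beyond copying its input.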

Let's see the result of training by visualizing the predictions and comparing them to the original images:

# Fetch some data (using the first 8 for easy plotting)
batch = merge_ds.map(lambda x: x['image']).take(8)
xs = tf.convert_to_tensor(list(batch), dtype=tf.float32) / 255.  # Scale to [0, 1], the range the model was trained on

# Corrupt the images with a range of amounts
amount = tf.linspace(0.0, 1.0, xs.shape[0])
noised_xs = corrupt(xs, amount)

# Get the model predictions
preds = net(noised_xs)  

# Plot
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(make_grid(xs), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(make_grid(noised_xs), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(make_grid(preds), cmap='Greys');

The model seems to reconstruct images corrupted with a small amount of noise (close to 0) well, but it does less well when the amount is high (close to 1).

Sampling

Denoising an image with a high amount of noise in one go is a hard task. We can simplify it by breaking the process into a few steps and moving only part of the way toward the model's prediction at each step. Concretely, we start from random noise. At each step i, we feed x to the model and move x a fraction 1/(n_steps - i) of the way toward its prediction. This fraction starts small (1/n_steps) and grows until the last step, where x is replaced by the prediction. This way, the model can pick up hints about the structure of the image and improve the output at each step.
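Each step keeps a fraction 1 - mix_factor of the current x. With mix_factor = 1/(n_steps - i), the product of these fractions telescopes, so after k of n steps the starting noise keeps a weight of (n - k)/n. It fades out linearly and reaches 0 at the final step, where mix_factor is 1. Here is a sketch with exact fractions that ignores the clipping step. `noise_weights` is a hypothetical helper.

```python
from fractions import Fraction

def noise_weights(n_steps):
    """Weight of the initial random x after each update x = x*(1-m) + pred*m,
    with m = 1/(n_steps - i) and clipping ignored."""
    weight, history = Fraction(1), []
    for i in range(n_steps):
        mix_factor = Fraction(1, n_steps - i)
        weight *= 1 - mix_factor
        history.append(weight)
    return history

print([str(w) for w in noise_weights(5)])  # ['4/5', '3/5', '2/5', '1/5', '0']
```

After the final step, x is just the last prediction (clipped to [0, 1]). The starting noise influences the result only through the model's earlier predictions.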

n_steps = 5

# Start from random noise
x = tf.random.uniform((8, 28, 28, 1))
step_history = [x]
pred_output_history = []

for i in range(n_steps):
    # Predict the denoised x0
    pred = net(x)
    # Store model output for plotting
    pred_output_history.append(pred)
    # Move part of the way towards the prediction: 1/n_steps at first, growing to 1 at the last step
    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor
    # Keep x in [0, 1], then store the step for plotting
    x = tf.clip_by_value(x, 0, 1)
    step_history.append(x)

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(make_grid(step_history[i]), cmap='Greys')
    axs[i, 1].imshow(make_grid(pred_output_history[i]), cmap='Greys')

In the above visualization, the left column shows the input x at each of the 5 steps, and the right column shows the corresponding model prediction. The predictions seem to take on a clearer shape as we run more steps. Let's try splitting the process into more steps:

n_steps = 50
x = tf.random.uniform((64, 28, 28, 1))

for i in range(n_steps):
    pred = net(x)
    # Move part of the way towards the prediction, as before
    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor
    x = tf.clip_by_value(x, 0, 1)

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(make_grid(x, rows=8), cmap='Greys');

This is not a great result, but some digits are recognizable (e.g. 5). A few things we could try to improve the model's performance: training for longer, scheduling the learning rate instead of using a fixed value, using a more complex architecture, etc.

That's all folks

Often, to better understand how something complex works, it helps to abstract away its complex components and consider a simpler version. In this article, we implemented simplified versions of the different components of a diffusion model in TensorFlow and trained the model on MNIST.

I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.