In this article, we will implement a simplified version of a diffusion model by abstracting away the complexity of its main components. In particular, we will implement a minimal version of:
- The process of generating noisy data
- A UNet model
- The training of a Diffusion model
- The sampling of images during inference
Let's import the necessary modules (note: you may need to install them).
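If any are missing, they can typically be installed with pip:
pip install tensorflow tensorflow-datasets tqdm Pillow matplotlib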
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, MaxPool2D, UpSampling2D
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
Set the seed for reproducibility:
seed = 123
tf.random.set_seed(seed)
np.random.seed(seed)
Data
We will be building a small model that's easy to understand, so to train it we will also use a small dataset. MNIST is a very basic computer vision dataset and a perfect candidate for this task. This dataset consists of greyscale images of 28x28 pixels, with values ranging from 0 to 255, representing the digits 0 to 9.
This dataset is hosted on TensorFlow Datasets; let's download it:
ds = tfds.load('mnist')
We can merge the train and test sets to make use of more images for training, though this is optional.
merge_ds = ds['train'].concatenate(ds['test'])
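As a quick check, the merged dataset should now contain all 70,000 MNIST images (60,000 train + 10,000 test):
print(merge_ds.cardinality().numpy())  # 70000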
The following helper functions will come in handy when we plot the original data vs. noisy or predicted images.
def make_grid(xs, rows=1, cols=8):
    xs = xs.numpy().squeeze()
    if xs.max() <= 1.0:  # scale [0, 1] floats up to [0, 255]
        xs = xs * 255.0
    images = [Image.fromarray(np.clip(x, 0, 255).astype(np.uint8)) for x in xs]
    return image_grid(images, rows, cols, 'L')

def image_grid(imgs, rows, cols, mode='RGB'):
    assert len(imgs) == rows*cols
    w, h = imgs[0].size
    grid = Image.new(mode, size=(cols*w, rows*h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
Let's have a look at a few samples from the dataset:
batch = merge_ds.map(lambda x: x['image']).take(8)
# Normalize to [0, 1]; this is the scale we work with throughout
xs = tf.convert_to_tensor(list(batch), dtype=tf.float32) / 255.
plt.imshow(make_grid(xs), cmap='Greys');
The train dataset is constructed as follows:
def preprocess_fn(entry):
    # Scale to [0, 1], then binarize the pixels
    image = tf.cast(entry['image'], tf.float32) / 255.
    image = tf.where(image > .5, 1.0, 0.0)
    return image
bs = 256
train_ds = merge_ds.map(preprocess_fn).shuffle(1024).batch(bs).prefetch(tf.data.AUTOTUNE)
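As an optional sanity check, every pixel should now be exactly 0.0 or 1.0 and each batch should have the expected shape:
xb = next(iter(train_ds))
print(xb.shape)  # (256, 28, 28, 1)
print(np.unique(xb.numpy()))  # [0. 1.]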
Noise
One of the main components in diffusion is the generation of noisy images and the reverse process, which consists of removing noise from an image. This is also known as a Scheduler, because it schedules over many steps the amount of noise to add (or remove). Implementing a real noise scheduler is out of scope as it is complex; instead we will implement a simple function that modifies an image with random noise, like this:
noise = tf.random.uniform(tf.shape(image))
noisy_x = (1 - amount) * image + amount * noise
The basic idea is that:
- If we use an amount close to 0, we get back the input without modification.
- If instead the amount is close to 1, the output will be complete noise with no information retained from the input.
- If we use an amount between 0 and 1, the input is mixed with noise and the output stays in the same range (0 to 1).
This is easily implemented by the following function that corrupts the input by mixing it with noise:
def corrupt(xs, amount):
    noise = tf.random.uniform(xs.shape)
    amount = tf.reshape(amount, (-1, 1, 1, 1))  # Sort shape so broadcasting works
    return xs*(1-amount) + noise * amount
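As a quick endpoint check (a small sketch, reusing the xs batch from above): an amount of 0 should return the input unchanged, while an amount of 1 would return pure noise:
tf.debugging.assert_near(corrupt(xs, tf.zeros(xs.shape[0])), xs)  # amount=0 -> unchanged input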
Let's visualize the output of this corruption process to better understand what the above logic is actually doing:
# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(make_grid(xs), cmap='Greys')
# Adding noise
amount = tf.linspace(0.0, 1.0, xs.shape[0]) # Left to right -> more corruption
noised_xs = corrupt(xs, amount)
# Plotting the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(make_grid(noised_xs), cmap='Greys');
We can easily confirm that an amount close to 0 keeps the digit as is, while an amount close to 1 results in a very noisy output. Anything in between mixes the image with some noise.
Model
Our goal is to denoise an image, i.e. generate a version of the input image with some amount of noise removed. We can use the UNet architecture for this purpose. UNet has the following components:
- A 'contracting path' through which data is compressed
- An 'expanding path' through which data is expanded back to the original dimension
- A set of skip connections at different levels to pass information and gradients.
The following class implements a very basic UNet model that takes in a one-channel image and passes it through:
- Three convolutional layers on the down path
- A max pooling layer for downsampling
- An UpSampling2D layer for upsampling
- Three convolutional layers on the up path
- Skip connections between the down and up layers.
class BasicUNet(Model):
    def __init__(self, in_channels=1, out_channels=1):
        super(BasicUNet, self).__init__(name='basic-unet')
        self.down_layers = [
            Conv2D(32, kernel_size=5, padding="same"),
            Conv2D(64, kernel_size=5, padding="same"),
            Conv2D(64, kernel_size=5, padding="same"),
        ]
        self.up_layers = [
            Conv2D(64, kernel_size=5, padding="same"),
            Conv2D(32, kernel_size=5, padding="same"),
            Conv2D(out_channels, kernel_size=5, padding="same"),
        ]
        self.act = tf.keras.activations.swish  # also known as SiLU
        self.downscale = MaxPool2D(2)
        self.upscale = UpSampling2D(size=2)

    def call(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))
            # Store x for skip connection and downscale for all but the third (final) down layer
            if i < 2:
                h.append(x)
                x = self.downscale(x)
        for i, l in enumerate(self.up_layers):
            # Fetch x for skip connection and upscale for all except the first up layer
            if i > 0:
                x = self.upscale(x)
                x += h.pop()
            x = self.act(l(x))
        return x
Let's verify that the output shape of the model is the same as its input shape
net = BasicUNet()
x = tf.random.uniform((8, 28, 28, 1))
net(x).shape  # TensorShape([8, 28, 28, 1])
Let's examine the number of trainable parameters in this network
net.summary()
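As a quick sanity check of the summary, the first convolution alone has 1x5x5x32 weights plus 32 biases, i.e. 832 parameters; summing all six convolutional layers the same way gives 309,057 trainable parameters in total.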
Training
To train the previous model to denoise images, we will train it for a few epochs and, at each step:
- Fetch a batch of images
- Apply random noise to each image in this batch
- Pass the noisy images through the model
- Compare the predicted denoised images to the original ones
- Calculate the loss using MeanSquaredError
- Update the model's parameters based on this loss.
This is implemented as follows:
epochs = 5
# Define a loss function
loss_fn = tf.keras.losses.MeanSquaredError()
# Define an optimizer
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Record the losses
losses, avg_losses = [], []
# Iterate over epochs.
for epoch in tqdm(range(epochs)):
    # Iterate over the batches of the dataset.
    for step, xb in enumerate(train_ds):
        with tf.GradientTape() as tape:
            # Create a noisy version of the input
            noise_amount = tf.random.uniform((xb.shape[0],))
            noisy_xb = corrupt(xb, noise_amount)
            # Get the model prediction
            pred = net(noisy_xb)
            # Calculate the loss to determine how close the output is to the input
            loss = loss_fn(xb, pred)
        grads = tape.gradient(loss, net.trainable_weights)
        opt.apply_gradients(zip(grads, net.trainable_weights))
        # Store the loss
        losses.append(loss.numpy())
    # Calculate the average loss over the batches of this epoch
    n_batches = step + 1
    avg_losses.append(sum(losses[-n_batches:]) / n_batches)
# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);
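Since we also recorded the per-epoch averages, we can optionally plot those for a smoother view of the training progress:
plt.plot(avg_losses)
plt.xlabel('epoch')
plt.ylabel('average loss');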
Let's see the result of training our model by visualizing the predictions and comparing them to the original images:
# Fetch some data (using the first 8 for easy plotting)
batch = merge_ds.map(lambda x: x['image']).take(8)
xs = tf.convert_to_tensor(list(batch), dtype=tf.float32) / 255.  # normalize to [0, 1] as in training
# Corrupt the images with a range of amounts
amount = tf.linspace(0.0, 1.0, xs.shape[0])
noised_xs = corrupt(xs, amount)
# Get the model predictions
preds = net(noised_xs)
# Plot
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(make_grid(xs), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(make_grid(noised_xs), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(make_grid(preds), cmap='Greys');
The model seems to be able to reconstruct images with a small amount of noise (close to 0), but does less well when the amount is high (close to 1).
Sampling
Because denoising an image with a high amount of noise in one go is a hard task, we can simplify it by breaking the process into a few steps (e.g. 10) and moving a small step (e.g. 1/10th) each time. Concretely, we start from random noise; then at each step we feed x to our model and move x a small amount toward that step's prediction (e.g. 10%). This way we can capture some hints about the structure of the image and improve the output at each step.
n_steps = 5
# Start from random noise
x = tf.random.uniform((8, 28, 28, 1))
step_history = [x]
pred_output_history = []
for i in range(n_steps):
    # Predict the denoised x0
    pred = net(x)
    # Store model output for plotting
    pred_output_history.append(pred)
    # How much we move towards the prediction: move part of the way there
    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor
    x = tf.clip_by_value(x, 0, 1)
    # Store step for plotting
    step_history.append(x)
fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(make_grid(step_history[i]), cmap='Greys')
    axs[i, 1].imshow(make_grid(pred_output_history[i]), cmap='Greys')
In the above visualization, you can see on the left the input at each of the 5 steps, and on the right the actual model prediction. It seems that the model's predictions gain structure as we run more steps. Let's try splitting the process into more steps, as follows:
n_steps = 50
x = tf.random.uniform((64, 28, 28, 1))
for i in range(n_steps):
    # (our simple model is not conditioned on the noise amount, so there is no schedule input to pass)
    pred = net(x)
    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor
    x = tf.clip_by_value(x, 0, 1)
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(make_grid(x, rows=8), cmap='Greys');
This is not a great result, but there are some recognizable digits (e.g. 5). A few things we could try to improve the model's performance: training longer, scheduling the learning rate instead of using a fixed value, using a more complex architecture, etc.
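For instance, here is a minimal sketch of the learning-rate idea (the decay_steps value below assumes roughly 274 batches per epoch, i.e. 70,000 images at a batch size of 256):
# One possibility: cosine-decay the learning rate instead of a fixed 1e-3
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-3, decay_steps=epochs * 274)
opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)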
That's all folks
Most often, to better understand how something complex works, it is helpful to abstract the complex components and consider a simpler version. In this article, we implemented the different components of a simplified diffusion model in TensorFlow and trained it on MNIST.
I hope you enjoyed this article; feel free to leave a comment or reach out on Twitter @bachiirc.