Clustering techniques are unsupervised learning algorithms that group unlabelled observations that look alike into clusters. In this article, we will take a detailed look at one of those techniques and then implement it in JAX.
First, let's import JAX and the other libraries we need, then initialize JAX's random number generator.
from functools import partial
import jax
import jax.numpy as jnp
import math
import matplotlib.pyplot as plt
seed = 123
key = jax.random.PRNGKey(seed)
jnp.set_printoptions(precision=3, threshold=5, linewidth=200)
We need some data to illustrate how the clustering algorithm works. We could download a dataset from the internet, but here we will simply generate random observations.
n_clusters = 7
n_samples = 250
To generate our data, we pick 7 random points that represent the actual cluster centroids, then for each of those centroids we generate a few random points around it.
centroids = jax.random.uniform(key, shape=(n_clusters, 2))*100 - 50
The observations around a centroid will be randomly sampled from a multivariate normal distribution which, as the name suggests, lets us generate observation vectors whose elements each follow a normal distribution.
JAX lets us sample from a multivariate normal distribution thanks to jax.random.multivariate_normal().
def sample(mean):
    cov = jnp.diag(jnp.array([5.0, 5.0]))   # diagonal covariance: x and y are sampled independently
    shape = (n_samples,)
    return jax.random.multivariate_normal(key, mean, cov, shape)
sample(centroids[0]).shape, centroids[0].shape
slices = [sample(mean) for mean in centroids]
data = jnp.concatenate(slices)
data.shape
To get a better sense of the generated data, we plot each cluster and its centroid as follows.
def plot_data(centroids, data, n_samples, ax=None):
    if not ax: _, ax = plt.subplots()
    for i, centroid in enumerate(centroids):
        samples = data[n_samples*i: n_samples*(i+1)]
        ax.scatter(samples[:, 0], samples[:, 1], s=1)                     # plot samples
        ax.plot(*centroid, markersize=10, marker='x', color='k', mew=5)   # plot centroid (thick black cross)
        ax.plot(*centroid, markersize=10, marker='x', color='m', mew=2)   # plot centroid (magenta cross on top)
plot_data(centroids, data, n_samples)
We can also plot the data along with its central point, i.e. the mean of all the observations.
midp = data.mean(axis=0)
midp
plot_data([midp]*n_clusters, data, n_samples)
Mean Shift is a lesser-known clustering algorithm that has some interesting advantages over the more popular k-means algorithm:
- Instead of requiring the exact number of clusters ahead of time, it only requires a bandwidth, which can often be chosen automatically
- Out of the box it can handle clusters of any shape (e.g. circles or moons like below), whereas k-means (without special extensions) can properly handle only ball-shaped clusters.
from sklearn.datasets import make_circles, make_moons
# Circles
X1 = make_circles(factor=0.5, noise=0.05, n_samples=1500)
# Moons
X2 = make_moons(n_samples=1500, noise=0.05)
fig, ax = plt.subplots(1, 2)
fig.set_size_inches(11, 5)
for i, X in enumerate([X1, X2]):
    ax[i].scatter(X[0][:, 0], X[0][:, 1])
plt.tight_layout();
The MeanShift clustering algorithm works as follows:
- For each data point $x_i$ in the sample $X$, find the distance $d_{ij}$ between $x_i$ and every other point $x_j$ in $X$, i.e. $d_{ij} = \| x_i - x_j \|$
- Calculate the weights $w_{ij}$ by applying the Gaussian kernel (with its standard deviation set to the bandwidth) to each distance $d_{ij}$
- Update each $x_i$ as the weighted average of all the points in $X$, using the weights from the previous step
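In equation form, one update of a point $x_i$ is a weighted average of all the points, which is exactly what the code below will compute:
$$x_i \leftarrow \frac{\sum_{j} w_{ij}\, x_j}{\sum_{j} w_{ij}}, \qquad w_{ij} = \varphi\left(\| x_i - x_j \|\right)$$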
The algorithm converges iteratively by pushing nearby points ever closer until they end up next to each other at their cluster's center.
Note:
- This weighting approach penalizes points that are far away from each other
- The rate at which the weights fall to zero is determined by the bandwidth.
- The value of the bandwidth should be chosen so that it covers roughly one third of the data.
Distance
The first component of Mean Shift is the distance function, which is simply the Euclidean distance (also known as the L2 norm) defined as follows: $ d\left( x_i, x_j \right) = \sqrt {\sum _{k=1}^{K} \left( x_{ik}-x_{jk}\right)^2 } $ where $x_i$ and $x_j$ are two observation vectors of dimension $K$.
The rest of this section implements this distance in JAX.
X = data.copy()
x0 = data[0]
x0.shape, X.shape
X[None].shape, X[:, None].shape, (X[None]-X[:, None]).shape
dist = jnp.sqrt(((X[None]-X[:, None])**2).sum(axis=-1))   # sum over the coordinate axis to get an (N, N) distance matrix
dist.shape
def distance(X, x):
    # x can be a single observation of shape (2,) or a batch of observations of shape (b, 2)
    diff = (X - x) if len(x.shape) == 1 else (X[None] - x[:, None])
    return jnp.sqrt((diff**2).sum(axis=-1))
X[2].shape, X[:2].shape, X[:2][:, None].shape
distance(X, X).shape, distance(X, X[:10]).shape, distance(X, X[0]).shape
Mean Shift uses the Gaussian kernel to calculate the weights, applying it to the distance between $x_i$ and $x_j$: $w_{ij} = \varphi(\| x_i - x_j \|)$. The kernel is defined by the following equation, where the standard deviation $\sigma$ plays the role of the bandwidth:
$$\varphi(z) = \frac{1}{\sigma\sqrt{2\pi}}\, e^{-\frac{z^2}{2\sigma^2}}$$
In JAX, it is implemented as follows:
def gaussian(x, bandwidth, mean=0):
    return jnp.exp(-0.5 * ((x-mean)/bandwidth)**2) / (bandwidth*jnp.sqrt(2*math.pi))
Let's plot the above function to get a better sense of what its output looks like.
def plot_func(f):
    x = jnp.linspace(0, 10, 100)
    plt.plot(x, f(x))
With a bandwidth of 2.5 we get the following plot.
plot_func(partial(gaussian, bandwidth=2.5))
Notice how the output of the Gaussian decreases steadily and is practically zero for inputs greater than or equal to 8. In fact, we can approximate this Gaussian with a much cheaper-to-compute function defined as follows:
def tri(x, i):
    # triangular kernel: falls linearly from 1 at x=0 to 0 at x=i, and stays at 0 beyond
    return (-x + i).clip(0)/i
You can see from the plot that its shape is very similar to the Gaussian's (the two curves are overlaid right after this plot).
plot_func(partial(tri, i=8))
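To make the comparison concrete, we can overlay the two curves. Rescaling the triangular function to the Gaussian's peak is purely for illustration; the absolute scale does not matter for Mean Shift anyway, since the weights end up divided by their sum.
x = jnp.linspace(0, 10, 100)
plt.plot(x, gaussian(x, 2.5), label='gaussian (bandwidth=2.5)')
plt.plot(x, tri(x, 8) * gaussian(0.0, 2.5), label='tri (i=8), rescaled')   # rescale tri to the Gaussian's peak value
plt.legend();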
Before going further, let's try the Gaussian on some inputs for validation.
dist_0 = jnp.sqrt(((x0-X)**2).sum(axis=1))
dist_0.shape, dist_0
weight_0 = gaussian(dist_0, 2.5)
weight_0.shape, weight_0
weight_0[:,None] * X
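A quick aside (my own check, not part of the algorithm): with a bandwidth of 2.5, only the points close to x0 receive a non-negligible weight, while the weight of faraway points is practically zero. This is what will pull each point towards its local cluster rather than towards the global mean.
(weight_0 > 1e-3).mean()   # fraction of points whose weight with respect to x0 is not practically zero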
Now we can finally calculate the weights.
weight = gaussian(distance(X, X[:10]), 2)
weight_tri = tri(distance(X, X[:10]), 8)
weight.shape, weight_tri.shape, X.shape
The weight matrix is used in the Mean Shift algorithm to shift each point towards the weighted average of all the points in X, as follows:
num = jnp.dot(weight, X)
div = weight.sum(-1, keepdims=True)
X_out = num/div
num.shape, div.shape, X_out.shape, X_out
weight.shape, weight.sum(axis=1).shape
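As a quick sanity check (an extra verification, not part of the original code), the first row of X_out should coincide with the explicit weighted average of X computed from x0's weights, as long as we use the same bandwidth of 2:
w0 = gaussian(distance(X, X[0]), 2)                   # x0's weights with a bandwidth of 2
(w0[:, None] * X).sum(axis=0) / w0.sum(), X_out[0]    # the two results should match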
After defining all the components, the following function groups them together to apply one step of the Mean Shift algorithm to a batch of observations.
def batched_meanshift_fn(X, bw=2):
    @jax.jit
    def apply(Xb):
        wb = gaussian(distance(X, Xb), bw)                    # weights between the batch and every point in X
        Xb_out = jnp.dot(wb, X) / wb.sum(-1, keepdims=True)   # weighted average of all the points in X
        return Xb_out
    return apply
For reference, these are the expected shapes of the arrays that the above function manipulates:
array | shape |
---|---|
X | (N, 2) |
Xb | (batch_size, 2) |
wb | (batch_size, N) |
Xb_out | (batch_size, 2) |
func = batched_meanshift_fn(X, 2)
func(X[:10])
Thanks to the way distance() handles both cases, the same function also works on a single observation:
func(X[0])
Even though it is slower, we should first try the algorithm on manually batched data to check that the final result matches our expectations.
def meanshift_step_1(step, args):
    X, bs, bw = args
    n = X.shape[0]
    batches = []
    batch_apply = batched_meanshift_fn(X, bw)
    for i in range(0, n, bs):                     # process the observations in batches of size bs
        s = slice(i, min(i+bs, n))
        Xb = batch_apply(X[s])
        batches.append(Xb)
    X_out = jnp.concatenate(batches, axis=0)
    return (X_out, bs, bw)
def meanshift_1(data, bs=500, bw=2, steps=5):
    X = data.copy()
    Xs = [X]                                       # keep the intermediate results for later visualization
    for i in range(steps):
        X, _, _ = meanshift_step_1(i, (X, bs, bw))
        Xs.append(X)
    return X, Xs
Check how long it takes to run this implementation using the default batch size of 500.
%%time
X_out, _ = meanshift_1(data)
plot_data(centroids+3, X_out, n_samples)
The implementation is slower when using fewer observations per batch. It is at its worst with the lowest batch size of 1.
%%time
X_out, _ = meanshift_1(data, 1)
plot_data(centroids+3, X_out, n_samples)
When we increase the batch size, the algorithm finishes faster.
%%time
X_out, _ = meanshift_1(data, 1000)
plot_data(centroids+3, X_out, n_samples)
We can go much faster using JAX's vectorizing map, vmap, which lets us run a function on every element of the array in parallel.
def meanshift_step_2(step, args):
    X, bw = args
    func = batched_meanshift_fn(X, bw)
    X_out = jax.vmap(func)(X)                      # apply the update to every point in parallel
    return (X_out, bw)

def meanshift_2(data, bw=2, steps=5):
    X = data.copy()
    X, _ = jax.lax.fori_loop(0, steps, meanshift_step_2, (X, bw))
    return X
%%time
X_out = meanshift_2(data)
plot_data(centroids+3, X_out, n_samples)
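As an extra sanity check (my own addition), the vmap-based implementation should produce the same result as the manually batched one, up to floating point error:
jnp.allclose(meanshift_2(data), meanshift_1(data)[0], atol=1e-4)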
Because Mean Shift is an iterative algorithm, we can visualize how the clusters change on every step.
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
_, Xs = meanshift_1(data, 500)
def do_one(d):
    X = Xs[d]
    ax.clear()
    plot_data(centroids+3, X, n_samples, ax=ax)
fig,ax = plt.subplots()
ani = FuncAnimation(fig, do_one, frames=len(Xs), interval=500, repeat=False)
plt.close()
HTML(ani.to_jshtml())
That's all folks
We have seen that Mean Shift can be easily implemented in JAX. Similarly, we could implement other clustering algorithms in JAX, such as k-means, DBSCAN, or locality-sensitive hashing.
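To give a taste of how little would change, here is a minimal, illustrative sketch of a single k-means iteration reusing the distance() helper defined earlier (the function name and the naive centroid initialization below are my own assumptions, not a reference implementation):
def kmeans_step(X, centroids):
    # One k-means iteration: assign each point to its nearest centroid, then recompute the centroids.
    d = distance(X, centroids)                            # (k, N) distances to each centroid
    assign = d.argmin(axis=0)                             # index of the nearest centroid for each point
    one_hot = jax.nn.one_hot(assign, centroids.shape[0])  # (N, k) cluster memberships
    counts = one_hot.sum(axis=0)[:, None]                 # number of points per cluster (assumed non-empty)
    return jnp.dot(one_hot.T, X) / counts                 # new centroids as per-cluster means
kmeans_step(data, data[:n_clusters])                      # e.g. naively initialize the centroids with the first few points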
I hope you enjoyed this article, feel free to leave a comment or reach out on twitter @bachiirc.