Flax provides a variety of activation functions that can be used when building a neural network. In this article of the Flax Basics series, we will explore some of the most commonly used activation functions.
!pip install -q flax
import numpy as np
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import matplotlib.pyplot as plt
Set seed for reproducibility
seed = 123
key = jax.random.PRNGKey(seed)
x = np.linspace(-10, 10, 100, dtype=np.float32)
Let's apply the relu function on our test array and visualize the output.
y = np.asarray(nn.relu(x))
plt.plot(x, y)
plt.legend(['relu'], loc='upper left')
plt.show();
def prelu(alpha=0.01):
    # PReLU has a learnable slope, so it is a Module that needs to be initialized and applied
    prelu = nn.PReLU(param_dtype=jnp.float32, negative_slope_init=alpha)
    def call(x):
        variables = prelu.init(key, x)
        return prelu.apply(variables, x)
    return call
Let's apply the prelu function on our test array using different values for the slope parameter and visualize the outputs.
alphas = [-1.0, -0.1, -0.01, 0.01, 0.1, 1.0]
legends = [str(alpha) for alpha in alphas]
for alpha in alphas:
    y = prelu(alpha)(x)
    plt.plot(x, y)
plt.legend(legends, loc='lower right')
plt.show();
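Since PReLU carries a learnable slope, it can be instructive to peek at what init actually creates. Below is a minimal sketch; the exact parameter name inside the variables dict may differ between Flax versions.
prelu_module = nn.PReLU(negative_slope_init=0.1)
variables = prelu_module.init(key, x)
# the learnable slope is stored under 'params' as a scalar
print(jax.tree_util.tree_map(lambda p: (p.shape, p.dtype), variables))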
Let's apply the elu function on our test array using different values for the $\alpha$ parameter and visualize the outputs.
alphas = [-1.0, -0.1, -0.01, 0.01, 0.1, 1.0]
legends = [str(alpha) for alpha in alphas]
for alpha in alphas:
    y = nn.elu(x, alpha)
    plt.plot(x, y)
plt.legend(legends, loc='lower right')
plt.show();
The Continuously Differentiable Exponential Linear Unit (CELU) was proposed in this paper. It is a continuously differentiable version of ELU. It is defined by the following formula:
$$ \mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases} $$
Let's apply the celu function on our test array using different values for the $\alpha$ parameter and visualize the outputs.
alphas = [-1.0, -0.1, -0.01, 0.01, 0.1, 1.0]
legends = [str(alpha) for alpha in alphas]
for alpha in alphas:
    y = nn.celu(x, alpha)
    plt.plot(x, y)
plt.legend(legends, loc='lower right')
plt.show();
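As a quick sanity check, we can compare nn.celu against a direct implementation of the formula above. This is only a sketch; any tiny differences come from float32 rounding.
alpha = 0.5
manual = jnp.where(x > 0, x, alpha * (jnp.exp(x / alpha) - 1))
print(jnp.allclose(nn.celu(x, alpha), manual, atol=1e-6))  # should print True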
GELU
The Gaussian Error Linear Unit (GELU) is another variation of ReLU. It was first proposed in this paper. It is defined by the following formula:
$$ \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right) $$
The calculation of GELU can be approximated with the following formula:
$$ \mathrm{gelu}(x) \approx \frac{x}{2} \left(1 + \tanh \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right) $$
Let's apply the gelu function on our test array and visualize the output.
y = np.asarray(nn.gelu(x))
plt.plot(x, y)
plt.legend(['gelu'], loc='upper left')
plt.show();
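Both forms are available through the approximate flag of nn.gelu (in current JAX/Flax versions it defaults to the tanh approximation). A quick comparison sketch:
exact = nn.gelu(x, approximate=False)    # erf-based formula
approx = nn.gelu(x, approximate=True)    # tanh approximation
print(jnp.max(jnp.abs(exact - approx)))  # the two curves are very close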
GLU
The Gated Linear Unit (GLU) is used mostly in gated CNNs for natural language processing applications. It is based on the sigmoid function and is defined by the following formula:
$$ \mathrm{glu}(a, b) = a \otimes \sigma(b) $$
y = np.asarray(nn.glu(x))
plt.plot(y)
plt.legend(['glu'], loc='upper left')
plt.show();
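Note that glu halves its input: it splits the array along the last axis into two parts a and b and returns a * sigmoid(b), which is why the output above has only 50 points and is plotted without x. A quick check (sketch):
a, b = jnp.split(x, 2, axis=-1)  # split the 100-element test array into two halves
print(jnp.allclose(nn.glu(x), a * nn.sigmoid(b)))  # should print True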
Let's apply the sigmoid function on our test array and visualize the output.
y = np.asarray(nn.sigmoid(x))
plt.plot(x, y)
plt.legend(['sigmoid'], loc='upper left')
plt.show();
y = np.asarray(nn.log_sigmoid(x))
plt.plot(x, y)
plt.legend(['log sigmoid'], loc='upper left')
plt.show();
Softmax
The softmax function is usually used as the last activation layer of a multi-class classifier. This is because the output of softmax is a probability distribution over the different classes. It is defined by the following formula:
$$ \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)} $$
Let's apply the softmax function on our test array and visualize the output.
y = np.asarray(nn.softmax(x))
plt.plot(x, y)
plt.legend(['softmax'], loc='upper left')
plt.show();
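We can quickly verify that the output is indeed a probability distribution: every entry is non-negative and the entries sum to one (sketch):
probs = nn.softmax(x)
print(probs.sum())         # ~1.0
print((probs >= 0).all())  # True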
Log Softmax
Log Softmax, as the name suggests, applies the logarithm to the output of softmax. It is defined by the following formula:
$$ \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right) $$
Let's apply the log_softmax function on our test array and visualize the output.
y = np.asarray(nn.log_softmax(x))
plt.plot(x, y)
plt.legend(['log softmax'], loc='upper left')
plt.ylim((-25, 10))
plt.show();
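As a quick check, log_softmax matches taking the logarithm of the softmax output, but it is computed in a numerically more stable way (sketch):
print(jnp.allclose(nn.log_softmax(x), jnp.log(nn.softmax(x)), atol=1e-5))  # should print True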
y = nn.soft_sign(x)
plt.plot(x, y)
plt.legend(['soft sign'], loc='upper left')
plt.show();
y = nn.softplus(x)
plt.plot(x, y)
plt.legend(['soft plus'], loc='upper left')
plt.show();
Swish (SiLU)
The Sigmoid Linear Unit (SiLU), also known as Swish, was first proposed in this paper. SiLU is based on the sigmoid function and is defined by the following formula:
$$ \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}} $$
Let's apply the swish function on our test array and visualize the output.
y = nn.swish(x)
plt.plot(x, y)
plt.legend(['swish'], loc='upper left')
plt.show();
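In Flax, swish is the same function as silu; a quick sketch checking that it matches x * sigmoid(x):
print(jnp.allclose(nn.swish(x), x * nn.sigmoid(x)))  # should print True
print(jnp.allclose(nn.swish(x), nn.silu(x)))         # swish and silu agree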
Custom activation
If none of the available activation functions works for you, Flax makes it easy to define custom ones.
For example, let's implement the Leaky ReLU activation function, which is defined by the following formula:
$$ \mathrm{leakyrelu}(x) = \begin{cases} x, & x > 0\\ \alpha \cdot x, & x \le 0 \end{cases} $$
class LeakyReLU(nn.Module):
    alpha: float = 0.1

    def __call__(self, x):
        return jnp.where(x > 0, x, self.alpha * x)
Let's apply the Leaky ReLU activation function that we just defined on our test array and visualize the output.
y = LeakyReLU()(x)
plt.plot(x, y)
plt.legend(['leaky relu'], loc='upper left')
plt.show();
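To use the custom activation inside a model, call it like any other submodule. Below is a minimal sketch with a hypothetical two-layer MLP (the TinyMLP name and its layer sizes are made up for illustration):
class TinyMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(16)(x)          # hypothetical hidden width
        x = LeakyReLU(alpha=0.1)(x)  # our custom activation
        x = nn.Dense(1)(x)
        return x

model = TinyMLP()
dummy = jnp.ones((4, 8))  # hypothetical batch of 4 examples with 8 features
params = model.init(key, dummy)
print(model.apply(params, dummy).shape)  # (4, 1)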
That's all folks!
I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.