Install the JAX, Flax, and Optax modules

$ pip install jax
$ pip install flax
$ pip install optax

Import the modules

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state
import optax
import tensorflow_datasets as tfds

Data pipeline

Flax does not provide an API for loading data, but we can build a data pipeline with the TensorFlow Datasets (TFDS) API and JAX. For instance, the following snippet loads the MNIST data into JAX numpy arrays:

ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))

After that, we can use JAX numpy to perform any required data processing: resizing, normalizing, cropping, etc.
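
For example, here is a minimal sketch that normalizes the images to the [0, 1] range and flattens them for the dense model used below (train_ds and test_ds are the dictionaries loaded above):

# uint8 images in [0, 255] -> float32 in [0, 1], flattened to 784 features
train_images = jnp.asarray(train_ds['image'], dtype=jnp.float32).reshape(-1, 784) / 255.0
train_labels = jnp.asarray(train_ds['label'])
test_images = jnp.asarray(test_ds['image'], dtype=jnp.float32).reshape(-1, 784) / 255.0
test_labels = jnp.asarray(test_ds['label'])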

Model definition

Models in Flax can be defined using the setup method, where we initialize all the layers:

class MyModel(nn.Module):
    def setup(self):
        self.lin = nn.Dense(10)

    def __call__(self, x):
        x = self.lin(x)
        return x

Alternatively, we can define models in Flax using the @nn.compact decorator:

class MyModel(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(10)(x)
        return x
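
In both cases the module itself is stateless: parameters are created explicitly with model.init and passed back in through model.apply. A minimal sketch, assuming flattened 784-dimensional MNIST inputs:

model = MyModel()
rng = jax.random.PRNGKey(0)
# init traces __call__ with a dummy batch to build the parameter pytree
params = model.init(rng, jnp.ones((1, 784)))['params']
# apply runs the forward pass with the given parameters
logits = model.apply({'params': params}, jnp.ones((8, 784)))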

Loss and metrics

Loss functions are available through the optax module, but they can also be calculated manually using the JAX numpy API:

def cross_entropy_loss(*, logits, y_true):
    y_true_onehot = jax.nn.one_hot(y_true, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=y_true_onehot).mean()
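
The same loss can also be written by hand with the JAX numpy API; a sketch that should match the optax version up to numerical precision:

def manual_cross_entropy_loss(*, logits, y_true):
    y_true_onehot = jax.nn.one_hot(y_true, num_classes=10)
    # log_softmax is numerically stabler than taking log(softmax(x))
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(y_true_onehot * log_probs, axis=-1))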

Metrics need to be defined manually; for instance, the accuracy can be calculated like this using the JAX numpy API:

def compute_metrics(*, logits, y_true):
    accuracy = jnp.mean(jnp.argmax(logits, -1) == y_true)
    metrics = {
        'accuracy': accuracy,
    }
    return metrics
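
For instance, a quick sketch evaluating the (still untrained) model from the previous section on the processed test set, reusing the names defined in the earlier sketches:

logits = model.apply({'params': params}, test_images)
metrics = compute_metrics(logits=logits, y_true=test_labels)
print(metrics['accuracy'])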

Training with Flax

Training models in Flax first requires creating a TrainState to hold the information that will be passed to the model during training.

def create_train_state(rng):
    model = MyModel()
    # Initialize the parameters with a dummy batch of the expected input
    # shape (we assume flattened 784-dimensional MNIST images here)
    params = model.init(rng, jnp.ones((1, 784)))['params']
    opt = optax.adam(learning_rate=0.01, b1=0.99, b2=0.999, eps=2e-05)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=opt)
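
Creating the state is then just a matter of supplying a PRNG key:

rng = jax.random.PRNGKey(0)
state = create_train_state(rng)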

Then we define a training step that runs the forward pass, computes the loss and the corresponding gradients, and uses those gradients to update the model parameters.

@jax.jit
def train_step(state, xb, y_true):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, xb)
        loss = jnp.mean(optax.softmax_cross_entropy(
            logits=logits, labels=jax.nn.one_hot(y_true, num_classes=10)))
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, y_true=y_true)
    return state, metrics
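
Putting it together, a minimal sketch of a loop that drives train_step over minibatches, reusing the state and rng from the previous sketch and the train_images/train_labels arrays from the data pipeline (the batch size and epoch count are arbitrary assumptions):

batch_size = 32
num_train = len(train_images)
for epoch in range(5):
    # Reshuffle the training set each epoch
    rng, shuffle_rng = jax.random.split(rng)
    perm = jax.random.permutation(shuffle_rng, num_train)
    # Drop the last partial batch so every step sees the same shape
    for i in range(0, num_train - batch_size + 1, batch_size):
        idx = perm[i:i + batch_size]
        state, metrics = train_step(state, train_images[idx], train_labels[idx])
    # Accuracy of the last minibatch, as a rough progress signal
    print(f"epoch {epoch}: accuracy {float(metrics['accuracy']):.4f}")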

Alternatively, we could use Elegy, which is a high-level API similar to Keras.