Custom Training Loop
In most cases you will want to train your TensorFlow model using the Keras API, i.e. model.compile
and model.fit
and their variations. This is basic and good enough for many tasks, as you can specify the loss function and the optimization algorithm, provide training/test data, and optionally pass callbacks (a minimal compile/fit sketch follows the list below). However, there are cases where you may want more control over the training process, for instance:
- trying a new optimization algorithm
- easily modifying the gradients or how the loss is calculated
- speeding up training with all sorts of tricks (e.g. teacher forcing)
- better hyperparameter tuning (e.g. using a cyclic learning rate).
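For reference, the standard Keras path looks roughly like the sketch below. The two-layer model, the random toy data, and the EarlyStopping callback are placeholder assumptions, not part of the original example; any tf.keras model, loss, and optimizer slot in the same way:

    import numpy as np
    import tensorflow as tf

    # placeholder model; any tf.keras.Model works the same way
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10),
    ])

    # specify the loss function, optimization algorithm, and metrics
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    # toy data; in practice this comes from your data pipeline
    x_train = np.random.rand(256, 20).astype("float32")
    y_train = np.random.randint(0, 10, size=(256,))

    # provide training data, hold out a validation split, and pass a callback
    model.fit(x_train, y_train, validation_split=0.2, epochs=10,
              callbacks=[tf.keras.callbacks.EarlyStopping(patience=2)])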
TensorFlow allows such customization through the GradientTape
API. Here is a typical example:
@tf.function
def training_loop(epochs, train_dataset, valid_dataset):
    # on every epoch, run over the entire training and validation sets
    for epoch in range(epochs):
        train_loss = tf.constant(0.0)
        # iterate over the training set in batches
        for features, labels in train_dataset:
            with tf.GradientTape() as tape:
                # forward pass in training mode
                logits = model(features, training=True)
                # calculate the batch loss
                loss = loss_func(labels, logits)
            # backpropagation: calculate the gradients and apply them to each layer
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            # accumulate the training loss
            train_loss += loss
        # calculate the loss on the validation set
        valid_loss = tf.constant(0.0)
        for features, labels in valid_dataset:
            # forward pass in inference mode
            logits = model(features, training=False)
            # accumulate the validation loss
            valid_loss += loss_func(labels, logits)
What the code above does: for every epoch, it enumerates the entire training set in batches. For every batch, it:
- performs a forward pass while recording every operation on a tape
- calculates the loss with respect to the actual labels
- uses the recorded operations to perform backpropagation and calculate the gradients (see the standalone sketch after this list)
- uses the optimizer to adjust the layer weights by applying the gradients
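To see the tape mechanics of the recording and backpropagation steps in isolation, here is a minimal sketch differentiating an arbitrary function y = x²:

    import tensorflow as tf

    x = tf.Variable(3.0)
    with tf.GradientTape() as tape:
        # every operation involving a watched (trainable) variable is recorded
        y = x * x
    # replay the tape backwards to get dy/dx = 2x
    dy_dx = tape.gradient(y, x)
    print(dy_dx.numpy())  # 6.0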
Once the pass over the entire training set finishes, the training loop runs over the validation set in batches. For every batch, it performs a forward pass with the model in inference mode (training=False) and accumulates the per-batch losses to obtain the validation loss of the current epoch.
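Note that training_loop references model, loss_func, and optimizer from the enclosing scope. As one way to wire it up, here is a minimal sketch with placeholder choices for each (the two-layer model and toy datasets are illustrative assumptions, not part of the original):

    import tensorflow as tf

    # placeholder model, loss, and optimizer referenced by training_loop
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()

    # toy tf.data pipelines; in practice these come from your input pipeline
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (tf.random.normal((256, 20)),
         tf.random.uniform((256,), maxval=10, dtype=tf.int32))
    ).batch(32)
    valid_dataset = tf.data.Dataset.from_tensor_slices(
        (tf.random.normal((64, 20)),
         tf.random.uniform((64,), maxval=10, dtype=tf.int32))
    ).batch(32)

    training_loop(epochs=10, train_dataset=train_dataset,
                  valid_dataset=valid_dataset)

In practice you would also report train_loss and valid_loss at the end of each epoch (e.g. with tf.print, which works inside a tf.function), or average the per-batch losses with tf.keras.metrics.Mean rather than summing them.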