PyTorch training loop and callbacks

A basic training loop in PyTorch for any deep learning model consists of:

• looping over the dataset many times (each full pass is an epoch)
• in each epoch, loading mini-batches of data from the dataset (possibly applying a set of transformations for data augmentation)
• zeroing the grads in the optimizer
• performing a forward pass on the given mini-batch of data
• calculating the losses between the result of the forward pass and the actual targets
• using this loss, performing a backward pass to compute the gradients, then letting the optimizer update the weights of the model

Figure: the 5 steps of a gradient descent optimization algorithm (source)

In 5 lines this training loop in PyTorch looks like this:
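A minimal sketch of such a loop, assuming a toy linear-regression dataset, an nn.Linear model, an MSE loss and plain SGD (all chosen only for illustration). The five lines inside the inner loop are the steps listed above:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy data and model (illustrative assumptions, not part of the original text)
torch.manual_seed(0)
X = torch.randn(64, 3)
y = X @ torch.tensor([[1.0], [-2.0], [0.5]])
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(20):              # loop over the dataset several times
    for xb, yb in loader:            # load one mini-batch at a time
        optimizer.zero_grad()        # reset the gradients from the previous step
        preds = model(xb)            # forward pass
        loss = loss_fn(preds, yb)    # loss between predictions and targets
        loss.backward()              # backward pass: compute the gradients
        optimizer.step()             # update the weights
```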

Note that if we don't zero the gradients, the next backward pass will add the new gradients to the ones already stored. This is because PyTorch accumulates gradients: when a tensor contributes to the result through several paths (or several backward passes), autograd combines the contributions by summing them.
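A tiny autograd experiment (a sketch, using an arbitrary scalar w) makes this accumulation visible: calling backward() twice without resetting doubles the stored gradient, and a tensor used on two paths receives the sum of both contributions.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()
print(w.grad)   # tensor(3.)  d(3w)/dw = 3

(3 * w).backward()
print(w.grad)   # tensor(6.)  not reset, so 3 + 3

# Reset, like optimizer.zero_grad() (which sets .grad to None by default since PyTorch 2.0)
w.grad = None
(w * w + 3 * w).backward()
print(w.grad)   # tensor(7.)  w is used on two paths: 2w + 3 are summed
```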

In some cases, one may want more control over the training loop, for instance to try different:

• regularization techniques
• hyperparameter schedules
• mixed precision training
• tracking metrics

For each case, you end up rewriting the basic loop and adding logic to accommodate these requirements. One way to open up endless possibilities for customizing the training loop is to use callbacks. Callbacks are a very common design pattern in many programming languages: the basic idea is to register a handler that will be invoked when a specific condition occurs. A typical case is a handler for specific errors that may be raised when calling a remote service.
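As a generic illustration of the pattern, here is a minimal sketch of a handler registry. EventBus, register, fire and the simulated TimeoutError are hypothetical names for this example, not a library API.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List


class EventBus:
    """Hypothetical handler registry: maps an event name to the handlers to invoke."""

    def __init__(self) -> None:
        self._handlers: Dict[str, List[Callable[..., Any]]] = defaultdict(list)

    def register(self, event: str, handler: Callable[..., Any]) -> None:
        self._handlers[event].append(handler)

    def fire(self, event: str, *args: Any) -> List[Any]:
        # Invoke every handler registered for `event`, in registration order,
        # and collect whatever they return.
        return [handler(*args) for handler in self._handlers[event]]


def call_remote_service(bus: EventBus) -> None:
    try:
        raise TimeoutError("service did not answer")  # stand-in for a real network call
    except TimeoutError as err:
        bus.fire("timeout", err)  # let the registered handlers decide what to do


bus = EventBus()
log: List[str] = []
bus.register("timeout", lambda err: log.append(f"retrying after: {err}"))
call_remote_service(bus)
print(log)  # ['retrying after: service did not answer']
```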

In a training loop, the events one may want handlers for include the beginning and end of training, the beginning and end of each epoch, and so on down to individual steps such as the start of a mini-batch or the moment the loss has been calculated. These handlers can return useful information, or flags that skip steps or stop the training.

The Callback interface may look like this:
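A minimal sketch of such an interface, assuming one no-op method per stage of the loop. The hook names (on_train_begin, on_loss_end, ...) and the convention that a hook returns True to raise a flag are assumptions made for this example, not a PyTorch API.

```python
from types import SimpleNamespace


class Callback:
    """Hypothetical callback interface: one no-op hook per stage of the training loop.

    Every hook receives the trainer (holding the model, optimizer, current
    mini-batch, loss, epoch, ...). A hook may return True as a flag, e.g.
    on_loss_end -> skip the optimizer update, on_epoch_end -> stop training.
    Returning None (the default) means "carry on".
    """

    def on_train_begin(self, trainer): pass
    def on_train_end(self, trainer): pass
    def on_epoch_begin(self, trainer): pass
    def on_epoch_end(self, trainer): pass
    def on_batch_begin(self, trainer): pass
    def on_loss_end(self, trainer): pass
    def on_backward_end(self, trainer): pass
    def on_batch_end(self, trainer): pass


class LossHistory(Callback):
    """Example subclass: override only the hooks you need, inherit the rest as no-ops."""

    def __init__(self):
        self.losses = []

    def on_batch_end(self, trainer):
        self.losses.append(float(trainer.loss))


history = LossHistory()
history.on_batch_end(SimpleNamespace(loss=0.25))  # stand-in for a real trainer
print(history.losses)                              # [0.25]
print(history.on_epoch_end(None))                  # None: inherited no-op
```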

After adding a callback hook at each stage of the training loop's life cycle, the earlier training loop becomes:
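A sketch of that loop, assuming a hypothetical Trainer class that stores the current state (model, mini-batch, loss, epoch, iteration) on itself and passes itself to every hook. Callbacks are duck-typed here: only the hooks a callback defines are called. The flag conventions are assumptions: a truthy on_loss_end result skips the optimizer update, and a truthy on_epoch_end result stops training.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class Trainer:
    """Minimal training loop that fires hypothetical callback hooks at each stage."""

    def __init__(self, model, loss_fn, optimizer, loader, callbacks=()):
        self.model, self.loss_fn, self.optimizer = model, loss_fn, optimizer
        self.loader, self.callbacks = loader, list(callbacks)
        self.epoch, self.iteration = 0, 0
        self.xb = self.yb = self.loss = None

    def fire(self, hook):
        # Call `hook` on every callback that defines it (no short-circuit),
        # and report whether any of them returned a truthy flag.
        flags = [getattr(cb, hook)(self) for cb in self.callbacks if hasattr(cb, hook)]
        return any(flags)

    def fit(self, epochs):
        self.fire("on_train_begin")
        self.optimizer.zero_grad()                    # start from clean gradients
        for epoch in range(epochs):
            self.epoch = epoch
            self.fire("on_epoch_begin")
            for xb, yb in self.loader:
                self.xb, self.yb = xb, yb
                self.fire("on_batch_begin")           # e.g. set the learning rate
                self.loss = self.loss_fn(self.model(self.xb), self.yb)
                skip_step = self.fire("on_loss_end")  # may rescale the loss or veto the update
                self.loss.backward()
                self.fire("on_backward_end")          # e.g. clip gradients in place
                if not skip_step:
                    self.optimizer.step()
                    self.optimizer.zero_grad()        # reset right after the update
                self.fire("on_batch_end")
                self.iteration += 1
            if self.fire("on_epoch_end"):             # e.g. early stopping
                break
        self.fire("on_train_end")


class BatchCounter:
    """Tiny example callback that only counts processed mini-batches."""

    def __init__(self):
        self.batches = 0

    def on_batch_end(self, trainer):
        self.batches += 1


torch.manual_seed(0)
X = torch.randn(64, 3)
y = X @ torch.tensor([[1.0], [-2.0], [0.5]])
model = nn.Linear(3, 1)
counter = BatchCounter()
trainer = Trainer(model, nn.MSELoss(), torch.optim.SGD(model.parameters(), lr=0.1),
                  DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True), [counter])
trainer.fit(epochs=5)
print(counter.batches)  # 20 (5 epochs x 4 mini-batches)
```

Zeroing the gradients right after optimizer.step(), instead of before the forward pass, makes no difference for plain training, but it lets a single flag skip both the update and the reset, which is what gradient accumulation needs.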

A basic use of callbacks is to log losses and metrics (e.g. accuracy) on the training/validation datasets after each epoch. More advanced callbacks actively act on the training by tweaking hyperparameters of the training loop (e.g. learning rates). Furthermore, every tweak can be written as its own callback. For instance:

learning rate scheduler

Over the course of training, adjusting the learning rate is a practical way to speed up the convergence of the weights towards their optimal values, thus requiring fewer epochs (which also helps to avoid overfitting). There are different ways to schedule learning rate adjustments, such as time-based decay, step decay and exponential decay. All of them can be implemented with a callback, for instance one invoked before each mini-batch:
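A sketch of such a callback, assuming a trainer object that exposes optimizer and a global iteration counter. The three schedule functions implement the standard formulas (time-based: lr0 / (1 + decay * t); step: lr0 * drop^floor(t / every); exponential: lr0 * exp(-k * t)); LRScheduler and the parameter names are hypothetical. Writing into optimizer.param_groups is the usual way to change the learning rate of an existing PyTorch optimizer.

```python
import math
from types import SimpleNamespace

import torch


def time_based(lr0, decay):
    return lambda t: lr0 / (1 + decay * t)


def step_decay(lr0, drop, every):
    return lambda t: lr0 * drop ** (t // every)


def exponential(lr0, k):
    return lambda t: lr0 * math.exp(-k * t)


class LRScheduler:
    """Hypothetical callback: set the learning rate from a schedule before each mini-batch."""

    def __init__(self, schedule):
        self.schedule = schedule  # maps the iteration number t to a learning rate

    def on_batch_begin(self, trainer):
        lr = self.schedule(trainer.iteration)
        for group in trainer.optimizer.param_groups:
            group["lr"] = lr


# Usage with a stand-in trainer that only exposes what the callback reads
params = [torch.nn.Parameter(torch.zeros(1))]
trainer = SimpleNamespace(optimizer=torch.optim.SGD(params, lr=0.1), iteration=25)
LRScheduler(step_decay(lr0=0.1, drop=0.5, every=10)).on_batch_begin(trainer)
print(trainer.optimizer.param_groups[0]["lr"])  # 0.025 (halved at iterations 10 and 20)
```

PyTorch also ships ready-made schedules in torch.optim.lr_scheduler (e.g. StepLR, ExponentialLR); a callback could call their step() method from the same hook instead.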

early stopping

Early stopping aims to keep training the model only as long as a target metric is improving (e.g. accuracy on the validation set), and to stop otherwise in order to avoid overfitting the training dataset. Using a callback, we can decide after each epoch whether to continue training or not, as follows:
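A minimal sketch of an early-stopping callback, assuming the trainer stores the latest validation metric in a hypothetical trainer.metric attribute (higher is better) and stops when on_epoch_end returns True. The patience and min_delta parameters are illustrative: training stops once the metric has failed to improve by more than min_delta for patience consecutive epochs.

```python
from types import SimpleNamespace


class EarlyStopping:
    """Hypothetical callback: stop when `trainer.metric` stops improving."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience, self.min_delta = patience, min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def on_epoch_end(self, trainer):
        if trainer.metric > self.best + self.min_delta:
            self.best, self.bad_epochs = trainer.metric, 0  # improvement: reset the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience              # True asks the loop to stop


# Usage with a stand-in trainer and made-up validation accuracies
trainer = SimpleNamespace(metric=None)
stopper = EarlyStopping(patience=2)
for epoch, acc in enumerate([0.70, 0.75, 0.74, 0.76, 0.76, 0.75, 0.74]):
    trainer.metric = acc
    if stopper.on_epoch_end(trainer):
        print(f"stopping after epoch {epoch}")  # stopping after epoch 5
        break
```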

parallel training

Use PyTorch's built-in support for training on multiple GPUs, see example.
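One possible sketch, assuming torch.nn.DataParallel (a single process driving several GPUs) and a trainer exposing model, xb and yb. It also assumes the model was moved to the first GPU before its optimizer was created, as the torch.optim documentation recommends. DataParallelCallback is a hypothetical name. Note that PyTorch's documentation recommends DistributedDataParallel over DataParallel for multi-GPU training; it runs one process per GPU and is harder to fit into a single callback.

```python
from types import SimpleNamespace

import torch
from torch import nn


class DataParallelCallback:
    """Hypothetical callback: split each mini-batch across all visible GPUs."""

    def __init__(self):
        self.enabled = False

    def on_train_begin(self, trainer):
        self.enabled = torch.cuda.device_count() > 1
        if self.enabled and not isinstance(trainer.model, nn.DataParallel):
            # Replicates the module (assumed already on cuda:0) and scatters the batch dimension
            trainer.model = nn.DataParallel(trainer.model)

    def on_batch_begin(self, trainer):
        if self.enabled:
            # Outputs are gathered on the first GPU, so inputs and targets go there too
            trainer.xb, trainer.yb = trainer.xb.to("cuda"), trainer.yb.to("cuda")


model = nn.Linear(3, 1)
trainer = SimpleNamespace(model=model, xb=torch.randn(8, 3), yb=torch.randn(8, 1))
cb = DataParallelCallback()
cb.on_train_begin(trainer)
cb.on_batch_begin(trainer)
print(type(trainer.model).__name__)  # DataParallel with 2+ GPUs, Linear otherwise
```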

gradient clipping

Gradient clipping makes it possible to use a large learning rate (e.g. lr = 1) without the updates blowing up, see discussion. It can be done by safely modifying the gradients in place (Variable.grad.data in older PyTorch versions) after the backward pass has finished, see example.
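A sketch of gradient clipping as a callback that runs right after the backward pass (before the optimizer step), assuming a trainer exposing model. Instead of editing .grad.data by hand, it uses torch.nn.utils.clip_grad_norm_, which rescales all gradients in place so that their combined L2 norm does not exceed max_norm. GradientClipping and max_norm = 1.0 are illustrative choices.

```python
from types import SimpleNamespace

import torch
from torch import nn


class GradientClipping:
    """Hypothetical callback: clip gradients in place after backward(), before optimizer.step()."""

    def __init__(self, max_norm=1.0):
        self.max_norm = max_norm

    def on_backward_end(self, trainer):
        # Rescales every .grad in place so that their combined L2 norm is at most max_norm.
        # (nn.utils.clip_grad_value_ would instead clamp each gradient element.)
        nn.utils.clip_grad_norm_(trainer.model.parameters(), self.max_norm)


model = nn.Linear(4, 1)
model(torch.full((8, 4), 100.0)).sum().backward()  # deliberately huge gradients
GradientClipping(max_norm=1.0).on_backward_end(SimpleNamespace(model=model))
total_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print(total_norm)  # about 1.0 (about 1600 before clipping)
```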

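gradient accumulation

The basic idea behind gradient accumulation is to sum (or average) the gradients of several consecutive backward passes, by not resetting them with model.zero_grad() or optimizer.zero_grad() in between, and only then update the weights. This simulates a larger batch size without the memory cost of actually loading it. It can be implemented straightforwardly in a handler for the loss-calculated event: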
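A sketch of such a handler, assuming a loop that calls on_loss_end right after computing the loss, then runs backward() on trainer.loss, and skips both optimizer.step() and optimizer.zero_grad() whenever the handler returns True. Dividing each loss by n_steps turns the summed gradients into an average. GradientAccumulation and n_steps are hypothetical names; the usage part checks that four mini-batches of 4 give the same gradient as one batch of 16.

```python
from types import SimpleNamespace

import torch
from torch import nn


class GradientAccumulation:
    """Hypothetical callback: update the weights only every `n_steps` mini-batches."""

    def __init__(self, n_steps=4):
        self.n_steps = n_steps
        self.count = 0

    def on_loss_end(self, trainer):
        trainer.loss = trainer.loss / self.n_steps  # summed gradients become an average
        self.count += 1
        return self.count % self.n_steps != 0       # True: skip the update for now


torch.manual_seed(0)
model = nn.Linear(3, 1)
X, y = torch.randn(16, 3), torch.randn(16, 1)
loss_fn = nn.MSELoss()

# Reference: gradient of the loss on one batch of 16
loss_fn(model(X), y).backward()
full_grad = model.weight.grad.clone()
model.zero_grad()

# Accumulate over 4 mini-batches of 4 through the callback
acc = GradientAccumulation(n_steps=4)
trainer = SimpleNamespace(loss=None)
for xb, yb in zip(X.split(4), y.split(4)):
    trainer.loss = loss_fn(model(xb), yb)
    skip_step = acc.on_loss_end(trainer)
    trainer.loss.backward()
print(skip_step)                                                # False: the 4th batch allows the update
print(torch.allclose(model.weight.grad, full_grad, atol=1e-6))  # True
```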

Conclusion

Callbacks are a very handy way to experiment with techniques for training larger models (with 100 million parameters), with larger batch sizes and bigger learning rates, but also to fight overfitting and help the model generalize. A well-designed callback system is crucial and has many benefits:

• keeping the training loop as simple as possible
• keeping each tweak independent
• making it easy to mix and match tweaks, or to perform ablation studies