TensorFlow provides an experimental API, tf.config.experimental.get_memory_info, that returns both the current and the peak memory consumption of a given device.
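
For example (assuming TensorFlow can see at least one GPU), a quick call returns a dict with "current" and "peak" byte counts:

import tensorflow as tf

# Returns a dict with "current" and "peak" memory usage in bytes,
# e.g. {'current': 1048576, 'peak': 4194304}.
memory_info = tf.config.experimental.get_memory_info("GPU:0")
print(memory_info["current"], memory_info["peak"])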

We can use this API in a custom Keras Callback to track peak GPU memory usage during training, as follows:

import tensorflow as tf


class GPUMemoryCallback(tf.keras.callbacks.Callback):
    def __init__(self, target_batches, print_stats=False, **kwargs):
        """
        target_batches: A list of batch indices at which to record memory usage.
        print_stats: A boolean flag indicating whether to print memory usage statistics.
        """
        super().__init__(**kwargs)
        self.target_batches = target_batches
        self.print_stats = print_stats

        self.memory_usage = []
        self.labels = []

    def _compute_memory_usage(self):
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        # Convert bytes to GB and store in list.
        peak_usage = round(memory_stats["peak"] / (2**30), 3)
        self.memory_usage.append(peak_usage)
        # Optionally report each measurement as it is taken.
        if self.print_stats:
            print(f"Peak GPU memory usage so far: {peak_usage} GB")

    def on_epoch_begin(self, epoch, logs=None):
        self._compute_memory_usage()
        self.labels.append(f"epoch {epoch} start")

    def on_train_batch_begin(self, batch, logs=None):
        if batch in self.target_batches:
            self._compute_memory_usage()
            self.labels.append(f"batch {batch}")

    def on_epoch_end(self, epoch, logs=None):
        self._compute_memory_usage()
        self.labels.append(f"epoch {epoch} end")

This callback uses the TensorFlow function tf.config.experimental.get_memory_info("GPU:0") to retrieve memory usage statistics for the GPU. It records the peak usage at the start and end of each epoch, and at each batch index specified in target_batches. The recorded memory usage values, along with their corresponding labels, are stored as attributes on the callback.

Note: For simplicity, we are assuming there is a single GPU, GPU:0.
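
If your machine has multiple GPUs, you can check which devices TensorFlow sees and adjust the device string accordingly, for example:

gpus = tf.config.list_physical_devices("GPU")
# e.g. [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
print(gpus)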

Here is an example showing how to create an instance of such a callback to track consumption at various batches:

gpu_memory_callback = GPUMemoryCallback(
    target_batches=[5, 10, 25, 50, 100, 150, 200, 300, 400, 500],
    print_stats=True,
)
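
Since we passed print_stats=True, each measurement will also be printed to the console as it is taken.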

Once the callback instance is created, we can simply pass it to model.fit so it gets called during training to track GPU consumption (assuming model, optimizer, loss, train_ds, and EPOCHS are defined elsewhere):

model.compile(optimizer=optimizer, loss=loss, weighted_metrics=["accuracy"])

model.fit(train_ds, epochs=EPOCHS, callbacks=[gpu_memory_callback])

Once the model training finishes, we can access the consumption history as follows:

memory_usage = gpu_memory_callback.memory_usage
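
Each measurement has a matching label (e.g. "epoch 0 start" or "batch 5"), so the two lists can be inspected side by side:

for label, usage in zip(gpu_memory_callback.labels, memory_usage):
    print(f"{label}: {usage} GB")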

Then we can plot it with matplotlib, using the recorded labels as the x axis:

import matplotlib.pyplot as plt

plt.bar(gpu_memory_callback.labels, memory_usage)
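
The batch labels can get crowded on the x axis; one way to keep the plot readable (plain matplotlib, nothing TensorFlow-specific) is to rotate the tick labels and add axis titles:

plt.xticks(rotation=45, ha="right")
plt.xlabel("Training step")
plt.ylabel("Peak GPU Memory Usage (GB)")
plt.title("GPU Memory Usage During Training")
plt.tight_layout()
plt.show()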

It is important to reset the peak memory statistic before starting training. This does not free any memory; it resets the tracked peak to the current usage, so allocations made earlier (for example, during model construction) are not counted towards the peaks our callback records.

tf.config.experimental.reset_memory_stats("GPU:0")

One good use case for tracking GPU consumption is comparing two (or more) models by their GPU memory usage during training, for instance a distilled model against the bigger model it was derived from.

The workflow could look like this:

# Reset the peak statistic so the first run starts from a clean baseline.
tf.config.experimental.reset_memory_stats("GPU:0")

gpu_memory_callback_1 = GPUMemoryCallback(...)
model_1.fit(train_ds, epochs=EPOCHS, callbacks=[gpu_memory_callback_1])

# Reset again so model 2's peak is not inflated by model 1's run.
tf.config.experimental.reset_memory_stats("GPU:0")

gpu_memory_callback_2 = GPUMemoryCallback(...)
model_2.fit(train_ds, epochs=EPOCHS, callbacks=[gpu_memory_callback_2])

model_memory_usage_1 = gpu_memory_callback_1.memory_usage
model_memory_usage_2 = gpu_memory_callback_2.memory_usage

Then, after training is done, we plot the peak consumption of each model to compare them visually:

plt.bar(
    ["Model 1", "Model 2"],
    [max(model_memory_usage_1), max(model_memory_usage_2)],
    color=["red", "blue"],
)

plt.xlabel("Model")
plt.ylabel("Peak GPU Memory Usage (GB)")

plt.title("GPU Memory Usage Comparison")
plt.show()