TFRecord is a lightweight format optimized for streaming large datasets. A TFRecord file can hold any sequence of binary records; here is a basic example:

import tensorflow as tf

with tf.io.TFRecordWriter("sample.tfrecord") as w:
    w.write(b"Record A")
    w.write(b"Record B")

for record in tf.data.TFRecordDataset("sample.tfrecord"):
    print(record)

The output looks like this:

tf.Tensor(b'Record A', shape=(), dtype=string)
tf.Tensor(b'Record B', shape=(), dtype=string)
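
Each element is a scalar string tensor holding the raw bytes of one record; calling .numpy() on it recovers the original bytes:

for record in tf.data.TFRecordDataset("sample.tfrecord"):
  print(record.numpy())  # b'Record A', then b'Record B'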

TFRecord files can contain records of type tf.Example, where each column of the original data is stored as a feature.

Storing data as TFRecord and tf.Example has the following advantages:

  • TFRecord relies on Protocol Buffers, a cross-platform serialization format supported by libraries for many popular programming languages (a minimal round trip is sketched after this list).
  • TFRecord is optimized for ingesting large amounts of data.
  • tf.Example is also the default data structure in the TensorFlow ecosystem.
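
Because tf.Example is a protocol buffer message, it can be built, serialized, and parsed back entirely in memory, without touching the file system. Here is a minimal round-trip sketch (the feature name int_col is made up for illustration):

# build an example with a single integer feature
example = tf.train.Example(features=tf.train.Features(feature={
  "int_col": tf.train.Feature(int64_list=tf.train.Int64List(value=[42])),
}))

# serialize to the protobuf wire format (the same bytes a TFRecord file stores)
serialized = example.SerializeToString()

# parse the bytes back into a dictionary of tensors
parsed = tf.io.parse_single_example(serialized, {
  "int_col": tf.io.FixedLenFeature([], dtype=tf.int64),
})
print(parsed["int_col"])  # tf.Tensor(42, shape=(), dtype=int64)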

Write TFRecords

Create a TFRecord from a Structured Dataset

The following example creates a TFRecord for structured data, where each feature corresponds to a column in the original dataset:

# create a writer
tfrecord_writer = tf.io.TFRecordWriter("data.tfrecord")

# iterate over the data and create a tf.Example for each row
for row in data:
  # create a feature for each column in the row
  example = tf.train.Example(features=tf.train.Features(feature={
    "int_col": tf.train.Feature(int64_list=tf.train.Int64List(value=[row['int_col']])),
    "byte_col": tf.train.Feature(bytes_list=tf.train.BytesList(value=[row['byte_col']])),
    "float_col": tf.train.Feature(float_list=tf.train.FloatList(value=[row['float_col']])),
    # ... one tf.train.Feature per remaining column
  }))
  # serialize example and write it
  tfrecord_writer.write(example.SerializeToString())

# close writer
tfrecord_writer.close()
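
The writer above produces uncompressed files. tf.io.TFRecordWriter also accepts a tf.io.TFRecordOptions object, so if you want GZIP-compressed files (which the reading example below expects), a sketch along these lines should work; the variable examples stands for any iterable of tf.train.Example:

# write GZIP-compressed records
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter("data.tfrecord", options) as w:
  for example in examples:  # hypothetical iterable of tf.train.Example
    w.write(example.SerializeToString())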

Create a TFRecord from an Image Dataset

The following example creates a TFRecord for image data, where each record stores the raw image bytes together with an integer label:

import os

# create a writer
tfrecord_writer = tf.io.TFRecordWriter("data.tfrecord")

# iterate over image files and their labels
for name, label in zip(filenames, labels):
  img_path = os.path.join(base_path, name)
  # try to read the image file; a missing file raises tf.errors.NotFoundError
  try:
    raw_file = tf.io.read_file(img_path)
  except tf.errors.NotFoundError:
    print("Couldn't read file {}".format(img_path))
    continue
  # create an example with the image and label
  example = tf.train.Example(features=tf.train.Features(feature={
    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw_file.numpy()])),
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
  }))
  # write example
  tfrecord_writer.write(example.SerializeToString())

# close writer
tfrecord_writer.close()
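
For large datasets it is common to split the records across several shard files, which is also what the file list in the reading example below assumes. A minimal round-robin sharding sketch (the shard count and naming scheme are arbitrary choices):

num_shards = 4
writers = [tf.io.TFRecordWriter("{:03d}.tfrecord".format(i + 1))  # 001.tfrecord, 002.tfrecord, ...
           for i in range(num_shards)]

# distribute examples round-robin across the shards
for i, example in enumerate(examples):  # hypothetical iterable of tf.train.Example
  writers[i % num_shards].write(example.SerializeToString())

for w in writers:
  w.close()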

Read TFRecords

To use data stored in TFRecord files, we can use TensorFlow's make_batched_features_dataset function from the tf.data.experimental module to load the examples in batches as follows:

# helper function to read TFRecord files (assumes they were written with GZIP compression)
def tfrecord_reader_fn(filenames):
  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')

# definition of each feature
features_spec = {
  "int_col": tf.io.FixedLenFeature([], dtype=tf.int64, default_value=-1),
  "str_col": tf.io.FixedLenFeature([], dtype=tf.string),
  "label": tf.io.VarLenFeature(dtype=tf.string),
}

# name of the label column in the features_spec
label_key = "label"

# list of files to read, or a glob pattern
file_pattern = ["001.tfrecord", "002.tfrecord", ...]

# number of examples in each batch
batch_size = 64

# dataset loaded from TFRecord files
train_ds = tf.data.experimental.make_batched_features_dataset(
  file_pattern=file_pattern,
  batch_size=batch_size,
  features=features_spec,
  reader=tfrecord_reader_fn,
  label_key=label_key
)
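
Note that make_batched_features_dataset lives under tf.data.experimental. If you prefer stable APIs, roughly the same pipeline can be assembled by hand. This sketch reuses features_spec from above and assumes file_pattern is a plain list of file names rather than a glob pattern; unlike the helper, it does not split out the label column:

raw_ds = tf.data.TFRecordDataset(file_pattern, compression_type='GZIP')
train_ds = (raw_ds
  .batch(batch_size)
  # parse a whole batch of serialized tf.Example records at once
  .map(lambda batch: tf.io.parse_example(batch, features_spec)))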

Now we can use the resulting dataset to train a model as follows:

model.fit(train_ds, ...)
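
One caveat: by default make_batched_features_dataset cycles through the data indefinitely (num_epochs=None), so you typically either pass num_epochs when building the dataset or give model.fit an explicit steps_per_epoch:

model.fit(train_ds, steps_per_epoch=1000, epochs=10)  # placeholder values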