Create TFRecords for your data
TFRecord is a lightweight format optimized for streaming large datasets. It can store any binary data; here is a basic example:
import tensorflow as tf

with tf.io.TFRecordWriter("sample.tfrecord") as w:
    w.write(b"Record A")
    w.write(b"Record B")

for record in tf.data.TFRecordDataset("sample.tfrecord"):
    print(record)
The output looks like this:
tf.Tensor(b'Record A', shape=(), dtype=string)
tf.Tensor(b'Record B', shape=(), dtype=string)
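TFRecord files can also be written with compression, which is common for large datasets. A minimal sketch, assuming GZIP compression (the reader must then be given the matching compression_type):

# write a GZIP-compressed TFRecord file
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter("sample.tfrecord.gz", options) as w:
    w.write(b"Record A")

# read it back with the matching compression type
for record in tf.data.TFRecordDataset("sample.tfrecord.gz", compression_type="GZIP"):
    print(record)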
TFRecord files can contain records of type tf.Example, where each column of the original data is stored as a feature. Storing data as TFRecords and tf.Examples has the following advantages (a minimal sketch of a tf.Example record follows the list):
- TFRecord relies on Protocol Buffers, a cross-platform serialization format supported by libraries for many popular programming languages.
- TFRecord is optimized for ingesting large amounts of data.
- tf.Example is also the default data structure in the TensorFlow ecosystem.
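Here is that sketch: a single tf.Example with one feature per column. The feature names and values are purely illustrative:

# a hypothetical record with one feature per column
example = tf.train.Example(features=tf.train.Features(feature={
    "age": tf.train.Feature(int64_list=tf.train.Int64List(value=[42])),
    "name": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"Alice"])),
}))
print(example)  # prints the protocol buffer in human-readable text format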
Write TFRecords
Create TFRecords from a Structured Dataset
The following example creates a TFRecord for structured data, where a feature corresponds to a column in the original dataset:
# create a writer
tfrecord_writer = tf.io.TFRecordWriter("data.tfrecord")

# iterate over the data and create a tf.Example for each row
for row in data:
    # create a feature for each column in the row
    example = tf.train.Example(features=tf.train.Features(feature={
        "int_col": tf.train.Feature(int64_list=tf.train.Int64List(value=[row['int_col']])),
        "byte_col": tf.train.Feature(bytes_list=tf.train.BytesList(value=[row['byte_col']])),
        "float_col": tf.train.Feature(float_list=tf.train.FloatList(value=[row['float_col']])),
        # ... one feature per remaining column ...
    }))
    # serialize the example and write it
    tfrecord_writer.write(example.SerializeToString())

# close the writer
tfrecord_writer.close()
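To verify what was written, a serialized example can be parsed back with tf.io.parse_single_example. A minimal sketch, mirroring the hypothetical columns used above:

# parse the first record back into a feature dictionary
spec = {
    "int_col": tf.io.FixedLenFeature([], tf.int64),
    "byte_col": tf.io.FixedLenFeature([], tf.string),
    "float_col": tf.io.FixedLenFeature([], tf.float32),
}
for serialized in tf.data.TFRecordDataset("data.tfrecord").take(1):
    parsed = tf.io.parse_single_example(serialized, spec)
    print(parsed["int_col"].numpy(), parsed["float_col"].numpy())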
Create TFRecords from an Image Dataset
The following example creates a TFRecord for image data, storing each image file together with its label:
import os

# create a writer
tfrecord_writer = tf.io.TFRecordWriter("data.tfrecord")

# iterate over the images in the directory
for name, label in zip(filenames, labels):
    img_path = os.path.join(base_path, name)
    # try to read the image file
    try:
        raw_file = tf.io.read_file(img_path)
    except tf.errors.NotFoundError:
        print("Couldn't read file {}".format(img_path))
        continue
    # create an example with the image and label
    example = tf.train.Example(features=tf.train.Features(feature={
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw_file.numpy()])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
    }))
    # write the example
    tfrecord_writer.write(example.SerializeToString())

# close the writer
tfrecord_writer.close()
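Reading an image record back works the same way; the raw bytes just need to be decoded afterwards. A minimal sketch, assuming the files written above were JPEGs:

# parse the first record and decode the image bytes
spec = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}
for serialized in tf.data.TFRecordDataset("data.tfrecord").take(1):
    parsed = tf.io.parse_single_example(serialized, spec)
    img = tf.io.decode_jpeg(parsed["image"])  # tensor of shape (height, width, channels)
    print(img.shape, parsed["label"].numpy())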
Read TFRecords
To use data stored in TFRecord files, we can use TensorFlow's make_batched_features_dataset function from the tf.data.experimental API to load the examples in batches as follows:
# helper function to read TFRecord files (assumes GZIP-compressed files)
def tfrecord_reader_fn(filenames):
    return tf.data.TFRecordDataset(filenames, compression_type='GZIP')

# definition of each feature
features_spec = {
    "int_col": tf.io.FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "str_col": tf.io.FixedLenFeature([], dtype=tf.string),
    "label": tf.io.VarLenFeature(dtype=tf.string),
}

# name of the label column in the features_spec
label_key = "label"

# list of files to read or a glob pattern
file_pattern = ["001.tfrecord", "002.tfrecord", ...]

# number of examples in each batch
batch_size = 64

# dataset loaded from the TFRecord files
train_ds = tf.data.experimental.make_batched_features_dataset(
    file_pattern=file_pattern,
    batch_size=batch_size,
    features=features_spec,
    reader=tfrecord_reader_fn,
    label_key=label_key,
)
Now we can use the dataset we just created to train a model as follows:
model.fit(train_ds, ...)
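Before starting a long training run, it can be useful to pull a single batch from the dataset and check its shapes. A short sketch (note that the label comes back as a SparseTensor because it was declared as a VarLenFeature):

# inspect one batch of features and labels
for batch_features, batch_labels in train_ds.take(1):
    print(batch_features["int_col"].shape)  # (64,) -- one value per example
    print(batch_labels)                     # SparseTensor of string labels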