Understanding TFRecords in TensorFlow 2.x
I used to feed my TensorFlow (tf) models with NumPy (np) ndarrays. Although quite convenient, this approach does not scale to large datasets, because the entire dataset (or large parts of it) must be loaded into system memory. To address this shortcoming, TensorFlow offers the TFRecord format, a disk-based streaming solution built on Protocol Buffers (Protobufs). Long story short, to use TFRecords in our projects, we first convert the dataset into TFRecords and store them on disk. Then, before launching the training process, we load them back from disk and set configurations such as batch size, number of epochs, and prefetching. In overview, these are the steps required to use TFRecords:
- Convert the dataset to TFRecords:
  - Iterate over the raw dataset and convert each pair of an example (an image or anything else) and its label to a `tf.train.Example`.
  - Serialize each `tf.train.Example` and write it into the opened TFRecord file with `TFRecordWriter.write()`.
- Use the TFRecord that was created:
  - Open the file with `tf.data.TFRecordDataset()`.
  - Parse the serialized entries of the dataset in a single pass with `dataset.map()`.
  - Set the batch size and the number of epochs with `dataset.batch()` and `dataset.repeat()`.
  - Feed the dataset (the configured and parsed TFRecord instance) directly to the model.
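The conversion steps above can be sketched as follows. The feature names `image` and `label`, the file name, and the toy data are illustrative choices, not fixed by the API:

```python
import numpy as np
import tensorflow as tf

def make_example(image, label):
    """Wrap one (image, label) pair in a tf.train.Example."""
    features = tf.train.Features(feature={
        # Store the raw image bytes; tobytes() flattens the ndarray.
        'image': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image.tobytes()])),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[int(label)])),
    })
    return tf.train.Example(features=features)

# Toy dataset: four 2x2 grayscale "images" with binary labels.
images = np.arange(16, dtype=np.uint8).reshape(4, 2, 2)
labels = [0, 1, 0, 1]

# Serialize each Example and write it to the TFRecord file.
with tf.io.TFRecordWriter('train.tfrecord') as writer:
    for img, lbl in zip(images, labels):
        writer.write(make_example(img, lbl).SerializeToString())
```

On the read side, the serialized records are parsed back with a matching feature description, as outlined in the "Use" steps above.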
Notes
- TFRecord pipelines are graph-based, meaning that to access the value of an entry, the returned tensors must be materialized as NumPy values. See the example below:

  ```python
  tmp = dataset.take(1)  # a new dataset that holds only one batch
  np_val = list(tmp.as_numpy_iterator())  # entries materialized as NumPy arrays
  ```
- TFRecord is a streaming solution, meaning that random access to the entries is not possible.
- A TFRecord dataset configured with a batch size of `B` holds `B` data entries concatenated into a single tensor.
- TFRecord compression is supported (using `GZIP` or `ZLIB`):

  ```python
  options = tf.io.TFRecordOptions(compression_type='ZLIB')
  with tf.io.TFRecordWriter(filename, options=options) as tfr_train:
      pass  # write serialized examples here
  ```
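Reading a compressed TFRecord back requires passing the matching compression type to `tf.data.TFRecordDataset`. A self-contained sketch (the file name and payloads are arbitrary):

```python
import tensorflow as tf

# Write two ZLIB-compressed records.
options = tf.io.TFRecordOptions(compression_type='ZLIB')
with tf.io.TFRecordWriter('compressed.tfrecord', options=options) as writer:
    for payload in (b'first', b'second'):
        writer.write(payload)

# Read them back; compression_type must match the writer's setting.
dataset = tf.data.TFRecordDataset('compressed.tfrecord',
                                  compression_type='ZLIB')
records = [record.numpy() for record in dataset]
```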
Examples
TODO.
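Until fuller examples are added, here is a minimal end-to-end sketch of the read side; the `image`/`label` schema, the 2x2 shape, and the file name are illustrative assumptions:

```python
import numpy as np
import tensorflow as tf

# Step 0: write a tiny TFRecord so the sketch is self-contained
# (in practice this file comes from the conversion step).
with tf.io.TFRecordWriter('toy.tfrecord') as writer:
    for lbl in (0, 1, 0, 1):
        img = np.full((2, 2), lbl, dtype=np.uint8)
        example = tf.train.Example(features=tf.train.Features(feature={
            'image': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[img.tobytes()])),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[lbl])),
        }))
        writer.write(example.SerializeToString())

# Step 1: open the file.
raw_dataset = tf.data.TFRecordDataset('toy.tfrecord')

# Step 2: parse the serialized entries in a single pass.
feature_spec = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

def parse(serialized):
    parsed = tf.io.parse_single_example(serialized, feature_spec)
    # Recover the flattened uint8 bytes and restore the 2x2 image shape.
    image = tf.reshape(tf.io.decode_raw(parsed['image'], tf.uint8), (2, 2))
    return image, parsed['label']

# Step 3: set batch size and epochs; prefetch overlaps I/O with training.
dataset = (raw_dataset
           .map(parse)
           .batch(2)
           .repeat(1)
           .prefetch(tf.data.AUTOTUNE))

# Step 4: the dataset can be fed directly to a model, e.g. model.fit(dataset).
images_batch, labels_batch = next(iter(dataset))
```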