Understanding TFRecords in TensorFlow 2.x

I used to feed my TensorFlow (tf) models with NumPy (np) ndarrays. Although quite convenient, this approach does not scale to large datasets, because the entire dataset (or large parts of it) has to be loaded into system memory. To address this shortcoming, TensorFlow offers the TFRecord format, a disk-based streaming solution built on protocol buffers (Protobufs). In short, to use TFRecords in a project, we first serialize the dataset into TFRecord files and store them on disk. Then, before launching the training process, we load them back from disk and configure options such as the batch size, number of epochs, and prefetching. To give a clear overview, these are the steps required to use TFRecords:

  1. Serialize the dataset into one or more TFRecord files and write them to disk.
  2. Load the files back with tf.data.TFRecordDataset and parse the serialized entries.
  3. Configure the resulting dataset (batch size, epochs, prefetching) and feed it to the model.
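The sketch below illustrates this workflow end to end; the feature layout (a float image plus an integer label), the shapes, and the file name train.tfrecord are made-up assumptions for the sake of the example:

    import numpy as np
    import tensorflow as tf

    filename = 'train.tfrecord'  # hypothetical output path

    # Step 1: serialize each (image, label) pair into a tf.train.Example.
    def serialize_example(image, label):
        feature = {
            'image': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[image.tobytes()])),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[int(label)])),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        return example.SerializeToString()

    images = np.random.rand(100, 28, 28).astype(np.float32)  # dummy data
    labels = np.random.randint(0, 10, size=100)

    with tf.io.TFRecordWriter(filename) as writer:
        for image, label in zip(images, labels):
            writer.write(serialize_example(image, label))

    # Step 2: load the file back and configure the input pipeline.
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def parse_example(serialized):
        parsed = tf.io.parse_single_example(serialized, feature_description)
        image = tf.reshape(tf.io.decode_raw(parsed['image'], tf.float32), (28, 28))
        return image, parsed['label']

    dataset = (tf.data.TFRecordDataset(filename)
               .map(parse_example)
               .batch(32)                    # batch size
               .repeat(5)                    # number of epochs
               .prefetch(tf.data.AUTOTUNE))  # overlap I/O with training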

Notes

  1. The tf.data pipeline that reads TFRecords is graph-based, and the entries it returns are tensors rather than NumPy arrays; to access the underlying values, they must be converted, e.g. via .numpy() or as_numpy_iterator(). See the example below:
    tmp = dataset.take(1)  # a new dataset holding only the first element (one batch, if the dataset is batched)
    np_val = list(tmp.as_numpy_iterator())  # the values of that element as NumPy arrays
    
  2. TFRecord is a streaming solution, meaning that random access to the entries is not possible.
  3. A dataset configured with a batch size B yields, on each iteration, B data entries stacked into a single tensor along the first dimension (a small illustration follows after this list).
  4. TFRecord compression is supported (GZIP or ZLIB); the same compression type must also be passed to tf.data.TFRecordDataset via its compression_type argument when reading the file back:
    options = tf.io.TFRecordOptions(compression_type='ZLIB')  # or 'GZIP'
    with tf.io.TFRecordWriter(filename, options=options) as tfr_train:
      pass  # write the serialized examples here
    
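As a small in-memory illustration of note 3 (using a toy range dataset rather than a TFRecord file), batching stacks consecutive entries into a single tensor along a new leading dimension:

    import tensorflow as tf

    batched = tf.data.Dataset.range(10).batch(4)
    for entry in batched:
        print(entry.numpy())
    # [0 1 2 3]
    # [4 5 6 7]
    # [8 9]  <- the last batch may be smaller unless drop_remainder=True is set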

Examples

TODO.

  1. TF Documentation
  2. Feeding-TensorFlow-from-drive-MNIST-Example
  3. Mnist-Tfrecord