
This topic describes how to integrate a TensorBay dataset with a TensorFlow pipeline, using the MNIST dataset as an example.

The typical way to integrate a TensorBay dataset with TensorFlow is to build a callable “Segment” class, an instance of which can be passed to `tf.data.Dataset.from_generator` as the generator.

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.data import Dataset

from tensorbay import GAS
from tensorbay.dataset import Dataset as TensorBayDataset

class MNISTSegment:
    """Callable wrapper that streams one segment of the TensorBay MNIST dataset.

    An instance can be passed directly to ``tf.data.Dataset.from_generator``:
    every call returns a fresh generator over the samples of the segment.

    Arguments:
        gas: The TensorBay client used to access the remote "MNIST" dataset.
        segment_name: Name of the segment to read, e.g. ``"train"``.

    """

    def __init__(self, gas, segment_name):
        self.dataset = TensorBayDataset("MNIST", gas)
        self.segment = self.dataset[segment_name]
        # Lookup table from classification category name to its integer index,
        # so labels can be emitted as int32 tensors.
        self.category_to_index = self.dataset.catalog.classification.get_category_to_index()

    def __call__(self):
        """Yield an image and its corresponding label.

        Yields:
            image_tensor: The tensorflow tensor of the image, with pixel values
                divided by 255 (assumes 8-bit pixel data).
            category_tensor: The tensorflow tensor of the category index.

        """
        for data in self.segment:
            # Close both the remote file handle and the PIL image once the
            # pixels have been copied into the numpy array.
            with data.open() as fp, Image.open(fp) as image:
                image_tensor = tf.convert_to_tensor(
                    np.array(image) / 255, dtype=tf.float32
                )
            category = self.category_to_index[data.label.classification.category]
            category_tensor = tf.convert_to_tensor(category, dtype=tf.int32)
            yield image_tensor, category_tensor

Use the following code to create a TensorFlow dataset and iterate over it:

# Please visit `https://gas.graviti.com/tensorbay/developer` to get the AccessKey.
ACCESS_KEY = "<YOUR_ACCESSKEY>"

# The MNISTSegment instance is the generator callable; output_signature tells
# TensorFlow the shape and dtype of each (image, label) pair it yields.
dataset = Dataset.from_generator(
    MNISTSegment(GAS(ACCESS_KEY), "train"),
    output_signature=(
        tf.TensorSpec(shape=(28, 28), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)

for index, (image, label) in enumerate(dataset):
    print(f"{index}: {label}")