Source code for tensorbay.opendataset.DownsampledImagenet.loader
#!/usr/bin/env python3
#
# Copyright 2021 Graviti. Licensed under MIT License.
#
# pylint: disable=invalid-name
"""Dataloader of DownsampledImagenet dataset."""
import os
from tensorbay.dataset import Data, Dataset, Segment
from tensorbay.opendataset._utility import glob
# Name passed to ``Dataset(...)`` when the loader builds the dataset object.
DATASET_NAME = "DownsampledImagenet"
# Names of the segments the loader creates, in creation order. Per the loader
# docstring, each name is also the sub-directory of the dataset root that holds
# that segment's PNG images.
SEGMENT_NAMES = ["train_32x32", "train_64x64", "valid_32x32", "valid_64x64"]
def DownsampledImagenet(path: str) -> Dataset:
    """`Downsampled Imagenet <https://www.tensorflow.org/datasets\
/catalog/downsampled_imagenet>`_ dataset.

    The file structure should be like::

        <path>
            valid_32x32/
                <imagename>.png
                ...
            valid_64x64/
                <imagename>.png
                ...
            train_32x32/
                <imagename>.png
                ...
            train_64x64/
                <imagename>.png
                ...

    Arguments:
        path: The root directory of the dataset.

    Returns:
        Loaded :class:`~tensorbay.dataset.dataset.Dataset` instance.

    """
    # Normalize once so "~" and relative paths resolve the same way for every segment.
    root_path = os.path.abspath(os.path.expanduser(path))

    dataset = Dataset(DATASET_NAME)

    for segment_name in SEGMENT_NAMES:
        # Pass by keyword: the helper's signature is (path, segment_name).
        # Passing the two strings positionally in the opposite order names the
        # segment after the root path and globs "<root>/*.png" instead of
        # "<root>/<segment_name>/*.png".
        dataset.add_segment(_get_segment(path=root_path, segment_name=segment_name))

    return dataset
def _get_segment(path: str, segment_name: str) -> Segment:
    """Build one segment from the PNG files under ``<path>/<segment_name>``.

    Arguments:
        path: The root directory of the dataset.
        segment_name: Name of the segment; also used as the sub-directory
            of ``path`` that is searched for ``*.png`` files.

    Returns:
        A :class:`~tensorbay.dataset.segment.Segment` named ``segment_name``
        containing one :class:`~tensorbay.dataset.data.Data` per matched file.

    """
    png_pattern = os.path.join(path, segment_name, "*.png")

    loaded_segment = Segment(segment_name)
    # NOTE(review): ``glob`` is the project's ``_utility.glob``; the order of the
    # returned paths is presumably defined there -- confirm before relying on it.
    for png_file in glob(png_pattern):
        loaded_segment.append(Data(png_file))

    return loaded_segment