Source code for tensorbay.opendataset.SVHN.loader

#!/usr/bin/env python3
#
# Copyright 2021 Graviti. Licensed under MIT License.
#
# pylint: disable=invalid-name

"""Dataloader of SVHN dataset."""

import os
from typing import Any

from tensorbay.dataset import Data, Dataset
from tensorbay.exception import ModuleImportError
from tensorbay.label import LabeledBox2D

# Names of the segments built by the loader; each one is also the name of the
# sub-directory under "FullNumbers/" that holds that segment's files.
_SEGMENTS = ("extra", "test", "train")

# Name given to the Dataset instance created by the loader.
DATASET_NAME = "SVHN"


def SVHN(path: str) -> Dataset:
    """`SVHN <http://ufldl.stanford.edu/housenumbers>`_ dataset.

    The file structure should be like::

        <path>
            Cropped/
                extra_32x32.mat
                test_32x32.mat
                train_32x32.mat
            FullNumbers/
                extra/
                    116507.png
                    116508.png
                    ...
                    digitStruct.mat
                    see_bboxes.m
                test/
                train/

    Only the "FullNumbers" directory is read by this loader.

    Arguments:
        path: The root directory of the dataset.

    Raises:
        ModuleImportError: When the module "h5py" can not be found.

    Returns:
        Loaded :class:`~tensorbay.dataset.dataset.Dataset` instance.

    """
    # h5py is an optional dependency, so it is imported lazily and a missing
    # module is reported through the project's own exception type.
    try:
        from h5py import File  # pylint: disable=import-outside-toplevel
    except ModuleNotFoundError as error:
        raise ModuleImportError(module_name=error.name) from error

    root_path = os.path.join(os.path.abspath(os.path.expanduser(path)), "FullNumbers")
    dataset = Dataset(DATASET_NAME)
    dataset.load_catalog(os.path.join(os.path.dirname(__file__), "catalog.json"))

    for segment_name in _SEGMENTS:
        segment = dataset.create_segment(segment_name)
        file_path = os.path.join(root_path, segment_name)
        # Open the annotation file read-only and close it as soon as the segment
        # is built, so no HDF5 handle is leaked and the file can never be
        # created or modified by the loader.
        with File(os.path.join(file_path, "digitStruct.mat"), "r") as mat:
            digit_struct = mat["digitStruct"]
            for name, bbox in zip(digit_struct["name"], digit_struct["bbox"]):
                segment.append(_get_data(mat, name, bbox, file_path))
    return dataset
def _get_data(mat: Any, name: Any, bbox: Any, file_path: str) -> Data:
    """Build the data of one image together with its 2D box labels.

    Arguments:
        mat: The opened "digitStruct.mat" file object; it is indexed with the
            references found in ``name`` and ``bbox``.
        name: One entry of ``mat["digitStruct"]["name"]``; its first element is
            the key used to look up the encoded image file name in ``mat``.
        bbox: One entry of ``mat["digitStruct"]["bbox"]``; its first element is
            the key used to look up the box fields of the image in ``mat``.
        file_path: The directory which contains the image file.

    Returns:
        The ``Data`` instance whose ``label.box2d`` holds one ``LabeledBox2D``
        per box entry of the image.

    """
    # The file name is stored as a sequence of character codes.
    filename = "".join(chr(code[0]) for code in mat[name[0]])
    data = Data(
        os.path.join(file_path, filename),
        target_remote_path=filename.zfill(10),
    )
    data.label.box2d = []

    box_group = mat[bbox[0]]
    fields = {}
    if box_group["label"].shape[0] == 1:
        # The "label" field has a single row: each field stores its value directly.
        for field, entry in box_group.items():
            fields[field] = [entry[0][0]]
    else:
        # Several rows: each row of a field is a key that is resolved through "mat".
        for field, entries in box_group.items():
            fields[field] = [mat[entry[0]][0][0] for entry in entries]

    columns = zip(
        fields["left"],
        fields["top"],
        fields["width"],
        fields["height"],
        fields["label"],
    )
    for left, top, width, height, digit in columns:
        # A label value of 10 is mapped to the category "0".
        if digit == 10:
            category = "0"
        else:
            category = str(int(digit))
        data.label.box2d.append(
            LabeledBox2D.from_xywh(left, top, width, height, category=category)
        )
    return data