Source code for tensorbay.opendataset.PASCALContext.loader

#!/usr/bin/env python3
#
# Copyright 2021 Graviti. Licensed under MIT License.
#
# pylint: disable=invalid-name

"""Dataloader of PASCALContext dataset."""

import os

from tensorbay.dataset import Data, Dataset
from tensorbay.exception import ModuleImportError
from tensorbay.label import SemanticMask
from tensorbay.opendataset._utility import glob

# Name given to the ``Dataset`` object built by the ``PASCALContext`` loader.
DATASET_NAME = "PASCALContext"


def PASCALContext(mask_path: str, image_path: str) -> Dataset:
    """`PASCALContext <https://cs.stanford.edu/~roozbeh/pascal-context/>`_ dataset.

    The file structure should be like::

        <mask_path>
            <image_name>.png
            ...

        <image_path>
            <image_name>.jpg
            ...

    Every file matching ``<mask_path>/*.png`` yields one data entry in the
    ``"trainval"`` segment. The entry points at ``<image_path>/<image_name>.jpg``
    and carries the png as its semantic mask. This function does not itself
    check that the jpg exists.

    Arguments:
        mask_path: The root directory of the dataset mask.
        image_path: The root directory of the dataset image.

    Returns:
        Loaded :class:`~tensorbay.dataset.dataset.Dataset` instance.

    """
    mask_root = os.path.abspath(os.path.expanduser(mask_path))
    image_root = os.path.abspath(os.path.expanduser(image_path))

    # The catalog file is looked up next to this loader module.
    catalog_file = os.path.join(os.path.dirname(__file__), "catalog.json")
    dataset = Dataset(DATASET_NAME)
    dataset.load_catalog(catalog_file)

    trainval = dataset.create_segment("trainval")
    mask_pattern = os.path.join(mask_root, "*.png")
    for mask_file in glob(mask_pattern):
        # Image and mask are paired by sharing the same file stem.
        image_stem, _ = os.path.splitext(os.path.basename(mask_file))
        sample = Data(os.path.join(image_root, f"{image_stem}.jpg"))
        sample.label.semantic_mask = SemanticMask(mask_file)
        trainval.append(sample)

    return dataset
def convert_mask(path: str, mask_path: str) -> None:
    """Convert the mat format labels of the PASCALContext dataset to masks.

    The file structure of the input path should be like::

        <path>
            <trainval>
                <image_name>.mat
                ...

    The ``"LabelMap"`` entry of each ``<path>/trainval/<image_name>.mat`` is
    written to ``<mask_path>/<image_name>.png``. ``mask_path`` is created,
    together with any missing parent directories, when it does not exist yet.

    Arguments:
        path: The root directory of the dataset.
        mask_path: The root directory where to save the masks.

    Raises:
        ModuleImportError: When the module "scipy" or "Pillow" can not be found.

    """
    # Both packages are optional dependencies, so they are imported lazily and a
    # missing one is reported through the project's own exception type.
    try:
        from PIL import Image  # pylint: disable=import-outside-toplevel
        from scipy.io import loadmat  # pylint: disable=import-outside-toplevel
    except ModuleNotFoundError as error:
        module_name = error.name
        # "PIL" is the import name; the installable package is called "Pillow".
        package_name = "Pillow" if module_name == "PIL" else None
        raise ModuleImportError(module_name=module_name, package_name=package_name) from error

    root_path = os.path.abspath(os.path.expanduser(path))
    root_mask_path = os.path.abspath(os.path.expanduser(mask_path))

    # ``Image.save`` does not create directories, so make sure the output root
    # exists instead of failing on the first mask with ``FileNotFoundError``.
    os.makedirs(root_mask_path, exist_ok=True)

    for mat_path in glob(os.path.join(root_path, "trainval", "*.mat")):
        stem = os.path.splitext(os.path.basename(mat_path))[0]
        mat = loadmat(mat_path)
        image = Image.fromarray(mat["LabelMap"])
        image.save(os.path.join(root_mask_path, f"{stem}.png"))