Source code for tensorbay.opendataset.OxfordIIITPet.loader

#!/usr/bin/env python3
#
# Copyright 2021 Graviti. Licensed under MIT License.
#
# pylint: disable=invalid-name

"""Dataloader of OxfordIIITPet dataset."""

import os
from typing import Any, List

from tensorbay.dataset import Data, Dataset
from tensorbay.label import Classification, LabeledBox2D, SemanticMask
from tensorbay.opendataset._utility.glob import glob

try:
    import xmltodict
except ModuleNotFoundError:
    from tensorbay.opendataset._utility.mocker import xmltodict  # pylint:disable=ungrouped-imports

# Name passed to ``Dataset(...)`` when the loader in this module builds the dataset.
DATASET_NAME = "OxfordIIITPet"


def OxfordIIITPet(path: str) -> Dataset:
    """`OxfordIIITPet <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ dataset.

    The file structure should be like::

        <path>
            annotations/
                trimaps/
                    Bombay_113.png
                    Bombay_114.png
                    ...
                xmls/
                    Birman_174.xml
                    Birman_175.xml
                    ...
                list.txt
                test.txt
                trainval.txt
                README
            images/
                Bombay_117.jpg
                Bombay_118.jpg
                ...

    Every ``images/*.jpg`` file becomes one data item carrying a classification
    label and a semantic mask. Images that have a matching file under
    ``annotations/xmls`` additionally get 2D boxes and are put into the
    "trainval" segment; all remaining images are put into the "test" segment.

    Arguments:
        path: The root directory of the dataset.

    Returns:
        Loaded :class:`~tensorbay.dataset.dataset.Dataset` instance.

    """
    dataset_root = os.path.abspath(os.path.expanduser(path))
    annotations_dir = os.path.join(dataset_root, "annotations")
    trimaps_dir = os.path.join(annotations_dir, "trimaps")
    xmls_dir = os.path.join(annotations_dir, "xmls")

    dataset = Dataset(DATASET_NAME)
    # The catalog file is looked up next to this module, not inside the dataset directory.
    dataset.load_catalog(os.path.join(os.path.dirname(__file__), "catalog.json"))

    trainval_segment = dataset.create_segment("trainval")
    test_segment = dataset.create_segment("test")

    for image_path in glob(os.path.join(dataset_root, "images", "*.jpg")):
        stem, _ = os.path.splitext(os.path.basename(image_path))
        # Title-cased file stems are labeled "Cat", every other stem "Dog".
        species = "Cat" if stem.istitle() else "Dog"
        # Only the last "_" separates the category from the image number.
        breed, index = stem.rsplit("_", 1)

        # The remote file name pads the image number with zeros to at least three digits.
        data = Data(image_path, target_remote_path=f"{breed}_{index.zfill(3)}.jpg")
        annotation = data.label
        annotation.classification = Classification(category=f"{species}.{breed}")
        annotation.semantic_mask = SemanticMask(os.path.join(trimaps_dir, f"{stem}.png"))

        xml_path = os.path.join(xmls_dir, f"{stem}.xml")
        if not os.path.exists(xml_path):
            # No box annotation file: the image goes to the "test" segment.
            test_segment.append(data)
            continue

        annotation.box2d = _get_box_label(xml_path)
        trainval_segment.append(data)

    return dataset
def _get_box_label(file_path: str) -> List[LabeledBox2D]:
    """Read the 2D boxes described by an annotation xml file.

    Arguments:
        file_path: The path of the xml annotation file.

    Returns:
        A list with one :class:`LabeledBox2D` for each ``object`` entry found
        under the ``annotation`` root, built from its ``bndbox`` coordinates.

    """
    with open(file_path, encoding="utf-8") as fp:
        parsed: Any = xmltodict.parse(fp.read())

    entries = parsed["annotation"]["object"]
    # The parsed value is not a list when the file holds a single <object>
    # (presumably xmltodict's mapping for a non-repeated element), so wrap it.
    if not isinstance(entries, list):
        entries = [entries]

    # Coordinates are passed positionally in this exact order.
    corner_keys = ("xmin", "ymin", "xmax", "ymax")
    return [
        LabeledBox2D(*(float(bndbox[key]) for key in corner_keys))
        for bndbox in (entry["bndbox"] for entry in entries)
    ]