Source code for tensorbay.opendataset.SegTrack2.loader

#!/usr/bin/env python3
#
# Copytright 2021 Graviti. Licensed under MIT License.
#
# pylint: disable=invalid-name

"""Dataloader of SegTrack2 dataset."""

import os
from typing import Callable, Dict

import numpy as np

from tensorbay.dataset import Data, Dataset
from tensorbay.label import InstanceMask
from tensorbay.opendataset._utility import glob

try:
    from PIL import Image
except ModuleNotFoundError:
    from tensorbay.opendataset._utility.mocker import Image  # pylint:disable=ungrouped-imports

DATASET_NAME = "SegTrack2"

_SEGMENTS_INFO = {
    1: ("bird_of_paradise", "birdfall", "frog", "girl", "monkey", "parachute", "soldier", "worm"),
    2: ("bmx", "cheetah", "drift", "hummingbird", "monkeydog"),
    6: ("penguin",),
}
_FILENAME_REFORMATTERS = (
    lambda filename: filename,
    lambda filename: f"{os.path.splitext(filename)[0]}.png",
)
_MASK_GETTER = Callable[[str, str, str, int, Callable[[str], str]], InstanceMask]


[docs]def SegTrack2(path: str) -> Dataset: """`SegTrack2 <https://web.engr.oregonstate.edu/~lif/SegTrack2/dataset.html>`_ dataset. The file structure of SegTrack looks like:: <path> GroundTruth/ bird_of_paradise/ bird_of_paradise_00000.png ... bmx/ 1/ bmx_06668.png ... 2/ bmx_06668.png ... ... JPEGImages/ bird_of_paradise/ bird_of_paradise_00000.png ... ... Arguments: path: The root directory of the dataset. Returns: Loaded :class:`~tensorbay.dataset.dataset.Dataset` instance. """ root_path = os.path.join(os.path.abspath(os.path.expanduser(path))) image_dir = os.path.join(root_path, "JPEGImages") original_mask_dir = os.path.join(root_path, "GroundTruth") mask_dir = os.path.join(root_path, "Masks") dataset = Dataset(DATASET_NAME) dataset.notes.is_continuous = True dataset.load_catalog(os.path.join(os.path.dirname(__file__), "catalog.json")) for instance_num, segment_names in _SEGMENTS_INFO.items(): for segment_name in segment_names: segment = dataset.create_segment(segment_name) mask_subdir = os.path.join(mask_dir, segment_name) os.makedirs(mask_subdir, exist_ok=True) all_mask_subdirs = { "original_mask_subdir": os.path.join(original_mask_dir, segment_name), "mask_subdir": mask_subdir, } filename_reformatter = ( _FILENAME_REFORMATTERS[1] if segment_name in {"penguin", "monkeydog"} else _FILENAME_REFORMATTERS[0] ) mask_getter: _MASK_GETTER = ( _get_cheetah_instance_mask if segment_name == "cheetah" else _get_instance_mask ) for image_path in glob(os.path.join(image_dir, segment_name, "*")): segment.append( _get_data( image_path, all_mask_subdirs, instance_num, filename_reformatter, mask_getter, ) ) return dataset
def _get_data( image_path: str, all_mask_subdirs: Dict[str, str], instance_num: int, filename_reformatter: Callable[[str], str], mask_getter: _MASK_GETTER, ) -> Data: data = Data(image_path) data.label.instance_mask = mask_getter( image_path, all_mask_subdirs["mask_subdir"], all_mask_subdirs["original_mask_subdir"], instance_num, filename_reformatter, ) return data def _get_instance_mask( image_path: str, mask_subdir: str, original_mask_subdir: str, instance_num: int, filename_reformatter: Callable[[str], str], ) -> InstanceMask: filename = filename_reformatter(os.path.basename(image_path)) mask_path = os.path.join(mask_subdir, f"{os.path.splitext(filename)[0]}.png") if instance_num == 1: mask = _get_reformatted_mask(os.path.join(original_mask_subdir, filename)) else: mask = _get_reformatted_mask(os.path.join(original_mask_subdir, "1", filename)) for instance_id in range(2, instance_num + 1): alter_mask = np.array( Image.open(os.path.join(original_mask_subdir, str(instance_id), filename)), )[:, :, 0] mask[alter_mask == 255] = instance_id Image.fromarray(mask).save(mask_path) return InstanceMask(mask_path) def _get_cheetah_instance_mask( image_path: str, mask_subdir: str, original_mask_subdir: str, _: int, __: Callable[[str], str] ) -> InstanceMask: filename = os.path.basename(image_path) new_filename = f"{os.path.splitext(filename)[0]}.png" mask_path = os.path.join(mask_subdir, new_filename) mask = _get_reformatted_mask(os.path.join(original_mask_subdir, "1", filename)) alter_mask = np.array( Image.open(os.path.join(original_mask_subdir, "2", new_filename)), )[:, :, 0] mask[alter_mask == 255] = 2 Image.fromarray(mask).save(mask_path) return InstanceMask(mask_path) def _get_reformatted_mask(original_mask_path: str) -> np.ndarray: mask = np.array(Image.open(original_mask_path))[:, :, 0] # reformat mask # from {background: 0, overlap: 1~254, target: 255} # to {background: 0, target: 1, overlap: 255} overlap = np.logical_and(mask > 0, mask < 255) mask[mask == 255] = 1 mask[overlap] = 255 return mask