# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Process Dataset."""
import abc
import os
import time

from .utils.adapter import get_raw_samples, read_image


class BaseDataset(abc.ABC):
    """
    Abstract base for map-style datasets of (image, mask) pairs.

    Subclasses implement ``_load_samples`` to fill ``self.samples``. Each
    sample is indexed by ``_next_data`` as ``sample[0]`` (image path) and
    ``sample[1]`` (mask image path); both are loaded with ``read_image``.

    Inheriting from ``abc.ABC`` makes ``@abc.abstractmethod`` effective, so
    this base class itself cannot be instantiated.

    Args:
        data_url (str): The path of data.
        usage (str): Whether to use train or eval.

    Returns:
        Dataset.
    """

    def __init__(self, data_url, usage):
        self.data_url = data_url
        self.usage = usage
        # Iteration cursor kept for external callers; not read within this module.
        self.cur_index = 0
        self.samples = []
        # perf_counter is monotonic, so the reported duration is not skewed by
        # wall-clock adjustments while samples load.
        start = time.perf_counter()
        self._load_samples()
        elapsed = time.perf_counter() - start
        print(f"load samples success~, time cost = {elapsed}")

    def __getitem__(self, item):
        """Return ``[image, mask_image]`` for the sample at index ``item``."""
        sample = self.samples[item]
        return self._next_data(sample)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    @staticmethod
    def _next_data(sample):
        """
        Load the image and its mask for one sample.

        Args:
            sample: Indexable whose element 0 is the image path and element 1
                is the mask image path.

        Returns:
            list: ``[image, mask_image]`` as returned by ``read_image``
            (presumably decoded image arrays; confirm against utils.adapter).
        """
        image_path = sample[0]
        mask_image_path = sample[1]
        image = read_image(image_path)
        mask_image = read_image(mask_image_path)
        return [image, mask_image]

    @abc.abstractmethod
    def _load_samples(self):
        """Populate ``self.samples``; must be implemented by subclasses."""


class HwVocRawDataset(BaseDataset):
    """
    Create dataset with raw data.

    Samples are obtained from ``get_raw_samples`` applied to the path
    ``os.path.join(data_url, usage)``.

    Args:
        data_url (str): The path of data.
        usage (str): Whether to use train or eval (default='train').

    Returns:
        Dataset.
    """

    def __init__(self, data_url, usage="train"):
        super().__init__(data_url, usage)

    def _load_samples(self):
        """
        Fill ``self.samples`` from the ``usage`` sub-path of ``data_url``.

        Raises:
            Exception: Any error from path joining or ``get_raw_samples`` is
                reported with a message and then propagated unchanged.
        """
        try:
            self.samples = get_raw_samples(os.path.join(self.data_url, self.usage))
        except Exception:
            print("load HwVocRawDataset failed!!!")
            # Bare ``raise`` re-raises the active exception with its original traceback.
            raise