From 86bc060d07f15809b906bb06ffa6481dd238448e Mon Sep 17 00:00:00 2001 From: root Date: Mon, 6 Jul 2020 11:39:57 +0000 Subject: [PATCH] add dataset base class --- dygraph/datasets/__init__.py | 1 + dygraph/datasets/cityscapes.py | 16 +---- dygraph/datasets/dataset.py | 97 ++++++++++++++++++++++++++++++ dygraph/datasets/optic_disc_seg.py | 16 +---- 4 files changed, 100 insertions(+), 30 deletions(-) create mode 100644 dygraph/datasets/dataset.py diff --git a/dygraph/datasets/__init__.py b/dygraph/datasets/__init__.py index 072a82f7..9a52eccf 100644 --- a/dygraph/datasets/__init__.py +++ b/dygraph/datasets/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .dataset import Dataset from .optic_disc_seg import OpticDiscSeg from .cityscapes import Cityscapes diff --git a/dygraph/datasets/cityscapes.py b/dygraph/datasets/cityscapes.py index 21f96782..006148ff 100644 --- a/dygraph/datasets/cityscapes.py +++ b/dygraph/datasets/cityscapes.py @@ -14,8 +14,7 @@ import os -from paddle.fluid.io import Dataset - +from dataset import Dataset from utils.download import download_file_and_uncompress DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') @@ -70,16 +69,3 @@ class Cityscapes(Dataset): image_path = os.path.join(self.data_dir, items[0]) grt_path = os.path.join(self.data_dir, items[1]) self.file_list.append([image_path, grt_path]) - - def __getitem__(self, idx): - image_path, grt_path = self.file_list[idx] - im, im_info, label = self.transforms(im=image_path, label=grt_path) - if self.mode == 'train': - return im, label - elif self.mode == 'eval': - return im, label - if self.mode == 'test': - return im, im_info, image_path - - def __len__(self): - return len(self.file_list) diff --git a/dygraph/datasets/dataset.py b/dygraph/datasets/dataset.py new file mode 100644 index 00000000..86f67d43 --- /dev/null +++ b/dygraph/datasets/dataset.py @@ -0,0 +1,97 @@ +# Copyright (c) 2020 
class Dataset(Dataset):
    """Segmentation dataset backed by a plain-text file list.

    Each non-empty line of the selected list file names one sample:
    ``image_name<separator>label_name``. Both names are joined onto
    ``data_dir``. In ``'test'`` mode a line that does not split into exactly
    two items is accepted and only its first item is used as the image path
    (the label path is then ``None``).

    NOTE(review): this class intentionally keeps the name ``Dataset`` and so
    shadows the base class imported from ``paddle.fluid.io`` in this module.

    Args:
        data_dir: Directory that the paths in the list file are relative to.
        num_classes: Number of classes; stored as ``self.num_classes``.
        train_list: Path of the list file used when ``mode`` is ``'train'``.
        val_list: Path of the list file used when ``mode`` is ``'eval'``.
        test_list: Path of the list file used when ``mode`` is ``'test'``.
        separator: String separating image and label names on each line.
        transforms: Callable invoked as ``transforms(im=..., label=...)``
            that returns a 3-tuple ``(im, im_info, label)``. Required.
        mode: ``'train'``, ``'eval'`` or ``'test'`` (case-insensitive).

    Raises:
        ValueError: If ``mode`` is not recognised, ``transforms`` is None,
            the list file required by ``mode`` was not given, or a line of a
            train/eval list does not contain exactly two items.
        IOError: If the list file required by ``mode`` does not exist.
    """

    def __init__(self,
                 data_dir,
                 num_classes,
                 train_list=None,
                 val_list=None,
                 test_list=None,
                 separator=' ',
                 transforms=None,
                 mode='train'):
        # Normalise the mode once. The original validated ``mode.lower()``
        # but then compared the raw value, so e.g. 'Train' passed validation
        # and was silently treated as 'test'.
        normalized_mode = mode.lower()
        if normalized_mode not in ('train', 'eval', 'test'):
            raise ValueError(
                "mode should be 'train', 'eval' or 'test', but got {}.".format(
                    mode))

        if transforms is None:
            raise ValueError("transform is necessary, but it is None.")

        self.data_dir = data_dir
        self.transforms = transforms
        self.file_list = list()
        self.mode = normalized_mode
        self.num_classes = num_classes

        list_path = self._select_file_list(normalized_mode, train_list,
                                           val_list, test_list)
        labels_required = normalized_mode in ('train', 'eval')

        with open(list_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing newline at EOF)
                    # instead of rejecting the file or adding an empty entry.
                    continue
                items = line.split(separator)
                if len(items) == 2:
                    image_path = os.path.join(self.data_dir, items[0])
                    grt_path = os.path.join(self.data_dir, items[1])
                elif labels_required:
                    raise ValueError(
                        "File list format incorrect! It should be"
                        " image_name{}label_name\\n".format(separator))
                else:
                    # Test lists may carry the image name only.
                    image_path = os.path.join(self.data_dir, items[0])
                    grt_path = None
                self.file_list.append([image_path, grt_path])

    @staticmethod
    def _select_file_list(mode, train_list, val_list, test_list):
        """Return the list file matching ``mode`` after checking it exists."""
        list_name, list_path = {
            'train': ('train_list', train_list),
            'eval': ('val_list', val_list),
            'test': ('test_list', test_list),
        }[mode]
        if list_path is None:
            raise ValueError(
                'When mode is "{}", {} is need, but it is None.'.format(
                    mode, list_name))
        if not os.path.exists(list_path):
            raise IOError('{} is not found: {}'.format(list_name, list_path))
        return list_path

    def __getitem__(self, idx):
        """Load and transform the sample at position ``idx``.

        Returns:
            ``(im, label)`` in 'train' and 'eval' mode, and
            ``(im, im_info, image_path)`` in 'test' mode. For any other value
            of ``self.mode`` (possible only if it was overwritten after
            construction) the method returns None.
        """
        image_path, grt_path = self.file_list[idx]
        im, im_info, label = self.transforms(im=image_path, label=grt_path)
        if self.mode in ('train', 'eval'):
            return im, label
        if self.mode == 'test':
            return im, im_info, image_path

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.file_list)