From 60f24061c7d47a09a886ec1e3e5f351a43f71418 Mon Sep 17 00:00:00 2001
From: chenguowei01
Date: Fri, 7 Aug 2020 11:03:22 +0800
Subject: [PATCH] add ade20k dataset

---
Note: ade.py below was revised during review, so the blob id on its
"index" line no longer matches the content.

 dygraph/datasets/__init__.py |  4 +-
 dygraph/datasets/ade.py      | 82 ++++++++++++++++++++++++++++++++++++
 dygraph/datasets/voc.py      |  1 -
 3 files changed, 85 insertions(+), 2 deletions(-)
 create mode 100644 dygraph/datasets/ade.py

diff --git a/dygraph/datasets/__init__.py b/dygraph/datasets/__init__.py
index e78dd9f9..37d8da36 100644
--- a/dygraph/datasets/__init__.py
+++ b/dygraph/datasets/__init__.py
@@ -16,9 +16,11 @@ from .dataset import Dataset
 from .optic_disc_seg import OpticDiscSeg
 from .cityscapes import Cityscapes
 from .voc import PascalVOC
+from .ade import ADE20K
 
 DATASETS = {
     "OpticDiscSeg": OpticDiscSeg,
     "Cityscapes": Cityscapes,
-    "PascalVOC": PascalVOC
+    "PascalVOC": PascalVOC,
+    "ADE20K": ADE20K
 }
diff --git a/dygraph/datasets/ade.py b/dygraph/datasets/ade.py
new file mode 100644
index 00000000..e8a19256
--- /dev/null
+++ b/dygraph/datasets/ade.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from .dataset import Dataset
+from utils.download import download_file_and_uncompress
+
+DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
+URL = "https://paddleseg.bj.bcebos.com/dataset/ADEChallengeData2016.zip"
+
+
+class ADE20K(Dataset):
+    """ADE20K dataset `http://sceneparsing.csail.mit.edu/`.
+
+    Args:
+        data_dir: The dataset directory. If None, the dataset is downloaded
+            to DATA_HOME when download is True.
+        mode: Which part of the dataset to use. It is one of ('train', 'val'),
+            compared case-insensitively. Default: 'train'.
+        transforms: Transforms for image. Must not be None.
+        download: Whether to download the dataset if data_dir is None.
+
+    Raises:
+        Exception: If mode is invalid, transforms is None, or data_dir is
+            None while download is False.
+    """
+
+    def __init__(self,
+                 data_dir=None,
+                 mode='train',
+                 transforms=None,
+                 download=True):
+        # Normalize once so the validation and the split lookup below agree.
+        split = mode.lower()
+        self.data_dir = data_dir
+        self.transforms = transforms
+        self.mode = split
+        self.file_list = list()
+        # The ADEChallengeData2016 annotations cover 150 semantic classes.
+        self.num_classes = 150
+
+        if split not in ['train', 'val']:
+            raise Exception(
+                "mode should be one of ('train', 'val') in ADE20K dataset, but got {}."
+                .format(mode))
+
+        if self.transforms is None:
+            raise Exception("transforms is necessary, but it is None.")
+
+        if self.data_dir is None:
+            if not download:
+                raise Exception("data_dir not set and auto download disabled.")
+            self.data_dir = download_file_and_uncompress(
+                url=URL,
+                savepath=DATA_HOME,
+                extrapath=DATA_HOME,
+                extraname='ADEChallengeData2016')
+
+        if split == 'train':
+            img_dir = os.path.join(self.data_dir, 'images/training')
+            grt_dir = os.path.join(self.data_dir, 'annotations/training')
+        else:
+            img_dir = os.path.join(self.data_dir, 'images/validation')
+            grt_dir = os.path.join(self.data_dir, 'annotations/validation')
+        # os.listdir order is arbitrary; sort for a reproducible file list.
+        img_files = sorted(os.listdir(img_dir))
+        for img_file in img_files:
+            img_path = os.path.join(img_dir, img_file)
+            grt_path = os.path.join(grt_dir, img_file.replace('.jpg', '.png'))
+            self.file_list.append([img_path, grt_path])
diff --git a/dygraph/datasets/voc.py b/dygraph/datasets/voc.py
index 3527b6b5..56ece84a 100644
--- a/dygraph/datasets/voc.py
+++ b/dygraph/datasets/voc.py
@@ -57,7 +57,6 @@ class PascalVOC(Dataset):
                 savepath=DATA_HOME,
                 extrapath=DATA_HOME,
                 extraname='VOCdevkit')
-        print(self.data_dir)
 
         image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets',
                                      'Segmentation')
-- 
GitLab