From 8d785cff5ca0283097a96e20f5a41c85da32bf40 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Thu, 6 Aug 2020 16:29:05 +0800 Subject: [PATCH] update datasets --- dygraph/datasets/dataset.py | 39 ++++++++++++------------------ dygraph/datasets/optic_disc_seg.py | 16 ++++-------- dygraph/datasets/voc.py | 32 ++++++++++-------------- dygraph/infer.py | 9 +------ dygraph/train.py | 2 +- dygraph/val.py | 2 +- 6 files changed, 37 insertions(+), 63 deletions(-) diff --git a/dygraph/datasets/dataset.py b/dygraph/datasets/dataset.py index a7ad9aa8..316caa9f 100644 --- a/dygraph/datasets/dataset.py +++ b/dygraph/datasets/dataset.py @@ -25,8 +25,7 @@ class Dataset(fluid.io.Dataset): Args: data_dir: The dataset directory. num_classes: Number of classes. - image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'test'). Default: 'train'. - mode: Dataset usage. it is one of ('train', 'eva', 'test'). Default: 'train'. + mode: which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'. train_list: The train dataset file. When image_set is 'train', train_list is necessary. The contents of train_list file are as follow: image1.jpg ground_truth1.png @@ -46,7 +45,6 @@ class Dataset(fluid.io.Dataset): def __init__(self, data_dir, num_classes, - image_set='train', mode='train', train_list=None, val_list=None, @@ -59,21 +57,16 @@ class Dataset(fluid.io.Dataset): self.mode = mode self.num_classes = num_classes - if image_set.lower() not in ['train', 'val', 'test']: + if mode.lower() not in ['train', 'val', 'test']: raise Exception( - "image_set should be one of ('train', 'val', 'test'), but got {}." 
- .format(image_set)) - - if mode.lower() not in ['train', 'eval', 'test']: - raise Exception( - "mode should be 'train', 'eval' or 'test', but got {}.".format( + "mode should be 'train', 'val' or 'test', but got {}.".format( mode)) if self.transforms is None: raise Exception("transforms is necessary, but it is None.") self.data_dir = data_dir - if image_set == 'train': + if mode == 'train': if train_list is None: raise Exception( 'When mode is "train", train_list is necessary, but it is None.' @@ -83,10 +76,10 @@ class Dataset(fluid.io.Dataset): 'train_list is not found: {}'.format(train_list)) else: file_list = train_list - elif image_set == 'eval': + elif mode == 'val': if val_list is None: raise Exception( - 'When mode is "eval", val_list is necessary, but it is None.' + 'When mode is "val", val_list is necessary, but it is None.' ) elif not os.path.exists(val_list): raise Exception('val_list is not found: {}'.format(val_list)) @@ -106,9 +99,9 @@ class Dataset(fluid.io.Dataset): for line in f: items = line.strip().split(separator) if len(items) != 2: - if mode == 'train' or mode == 'eval': + if mode == 'train' or mode == 'val': raise Exception( - "File list format incorrect! It should be" + "File list format incorrect! In training or evaluation task it should be" " image_name{}label_name\\n".format(separator)) image_path = os.path.join(self.data_dir, items[0]) grt_path = None @@ -119,19 +112,19 @@ class Dataset(fluid.io.Dataset): def __getitem__(self, idx): image_path, grt_path = self.file_list[idx] - if self.mode == 'train': - im, im_info, label = self.transforms(im=image_path, label=grt_path) - return im, label - elif self.mode == 'eval': + if self.mode == 'test': + im, im_info, _ = self.transforms(im=image_path) + im = im[np.newaxis, ...] + return im, im_info, image_path + elif self.mode == 'val': im, im_info, _ = self.transforms(im=image_path) im = im[np.newaxis, ...] 
label = np.asarray(Image.open(grt_path)) label = label[np.newaxis, np.newaxis, :, :] return im, im_info, label - if self.mode == 'test': - im, im_info, _ = self.transforms(im=image_path) - im = im[np.newaxis, ...] - return im, im_info, image_path + else: + im, im_info, label = self.transforms(im=image_path, label=grt_path) + return im, label def __len__(self): return len(self.file_list) diff --git a/dygraph/datasets/optic_disc_seg.py b/dygraph/datasets/optic_disc_seg.py index 8608056f..e1b88176 100644 --- a/dygraph/datasets/optic_disc_seg.py +++ b/dygraph/datasets/optic_disc_seg.py @@ -25,7 +25,6 @@ class OpticDiscSeg(Dataset): def __init__(self, data_dir=None, transforms=None, - image_set='train', mode='train', download=True): self.data_dir = data_dir @@ -34,14 +33,9 @@ class OpticDiscSeg(Dataset): self.mode = mode self.num_classes = 2 - if image_set.lower() not in ['train', 'val', 'test']: + if mode.lower() not in ['train', 'val', 'test']: raise Exception( - "image_set should be one of ('train', 'val', 'test'), but got {}." - .format(image_set)) - - if mode.lower() not in ['train', 'eval', 'test']: - raise Exception( - "mode should be 'train', 'eval' or 'test', but got {}.".format( + "mode should be 'train', 'val' or 'test', but got {}.".format( mode)) if self.transforms is None: @@ -53,9 +47,9 @@ class OpticDiscSeg(Dataset): self.data_dir = download_file_and_uncompress( url=URL, savepath=DATA_HOME, extrapath=DATA_HOME) - if image_set == 'train': + if mode == 'train': file_list = os.path.join(self.data_dir, 'train_list.txt') - elif image_set == 'val': + elif mode == 'val': file_list = os.path.join(self.data_dir, 'val_list.txt') else: file_list = os.path.join(self.data_dir, 'test_list.txt') @@ -64,7 +58,7 @@ class OpticDiscSeg(Dataset): for line in f: items = line.strip().split() if len(items) != 2: - if mode == 'train' or mode == 'eval': + if mode == 'train' or mode == 'val': raise Exception( "File list format incorrect! 
It should be" " image_name label_name\\n") diff --git a/dygraph/datasets/voc.py b/dygraph/datasets/voc.py index f0614091..3527b6b5 100644 --- a/dygraph/datasets/voc.py +++ b/dygraph/datasets/voc.py @@ -25,15 +25,13 @@ class PascalVOC(Dataset): please run the voc_augment.py in tools. Args: data_dir: The dataset directory. - image_set: Which part of dataset to use. Generally, image_set is of ('train', 'val', 'trainval', 'trainaug). Default: 'train'. - mode: Dataset usage. it is one of ('train', 'eva', 'test'). Default: 'train'. + mode: Which part of dataset to use. It is one of ('train', 'trainval', 'trainaug', 'val'). Default: 'train'. transforms: Transforms for image. download: Whether to download dataset if data_dir is None. """ def __init__(self, data_dir=None, - image_set='train', mode='train', transforms=None, download=True): @@ -43,22 +41,17 @@ class PascalVOC(Dataset): self.file_list = list() self.num_classes = 21 - if image_set.lower() not in ['train', 'val', 'trainval', 'trainaug']: + if mode.lower() not in ['train', 'trainval', 'trainaug', 'val']: raise Exception( - "image_set should be one of ('train', 'val', 'trainval', 'trainaug'), but got {}." - .format(image_set)) - - if mode.lower() not in ['train', 'eval', 'test']: - raise Exception( - "mode should be 'train', 'eval' or 'test', but got {}.".format( - mode)) + "mode should be one of ('train', 'trainval', 'trainaug', 'val') in PascalVOC dataset, but got {}." 
+ .format(mode)) if self.transforms is None: raise Exception("transforms is necessary, but it is None.") if self.data_dir is None: if not download: - raise Exception("data_file not set and auto download disabled.") + raise Exception("data_dir not set and auto download disabled.") self.data_dir = download_file_and_uncompress( url=URL, savepath=DATA_HOME, @@ -68,19 +61,19 @@ class PascalVOC(Dataset): image_set_dir = os.path.join(self.data_dir, 'VOC2012', 'ImageSets', 'Segmentation') - if image_set == 'train': + if mode == 'train': file_list = os.path.join(image_set_dir, 'train.txt') - elif image_set == 'val': + elif mode == 'val': file_list = os.path.join(image_set_dir, 'val.txt') - elif image_set == 'trainval': + elif mode == 'trainval': file_list = os.path.join(image_set_dir, 'trainval.txt') - elif image_set == 'trainaug': + elif mode == 'trainaug': file_list = os.path.join(image_set_dir, 'train.txt') file_list_aug = os.path.join(image_set_dir, 'aug.txt') if not os.path.exists(file_list_aug): raise Exception( - "When image_set is 'trainaug', Pascal Voc dataset should be augmented, " + "When mode is 'trainaug', Pascal Voc dataset should be augmented, " "Please make sure voc_augment.py has been properly run when using this mode." ) @@ -95,10 +88,11 @@ class PascalVOC(Dataset): image_path = os.path.join(img_dir, ''.join([line, '.jpg'])) grt_path = os.path.join(grt_dir, ''.join([line, '.png'])) self.file_list.append([image_path, grt_path]) - if image_set == 'trainaug': + if mode == 'trainaug': with open(file_list_aug, 'r') as f: for line in f: line = line.strip() image_path = os.path.join(img_dir, ''.join([line, '.jpg'])) - grt_path = os.path.join(grt_dir, ''.join([line, '.png'])) + grt_path = os.path.join(grt_dir_aug, ''.join([line, + '.png'])) self.file_list.append([image_path, grt_path]) diff --git a/dygraph/infer.py b/dygraph/infer.py index 29d773bf..6cb8f6d6 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -13,20 +13,13 @@ # limitations under the License. 
import argparse -import os -from paddle.fluid.dygraph.base import to_variable -import numpy as np import paddle.fluid as fluid from paddle.fluid.dygraph.parallel import ParallelEnv -import cv2 -import tqdm from datasets import DATASETS import transforms as T from models import MODELS -import utils -import utils.logging as logging from utils import get_environ_info from core import infer @@ -43,7 +36,7 @@ def parse_args(): type=str, default='UNet') - # params of dataset + # params of infer parser.add_argument( '--dataset', dest='dataset', diff --git a/dygraph/train.py b/dygraph/train.py index decf3f72..cf3ad5b6 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -153,7 +153,7 @@ def main(args): eval_transforms = T.Compose( [T.Resize(args.input_size), T.Normalize()]) - eval_dataset = dataset(transforms=eval_transforms, mode='eval') + eval_dataset = dataset(transforms=eval_transforms, mode='val') if args.model_name not in MODELS: raise Exception( diff --git a/dygraph/val.py b/dygraph/val.py index 6d3b2de1..a453bd84 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -87,7 +87,7 @@ def main(args): with fluid.dygraph.guard(places): eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) - eval_dataset = dataset(transforms=eval_transforms, mode='eval') + eval_dataset = dataset(transforms=eval_transforms, mode='val') if args.model_name not in MODELS: raise Exception( -- GitLab