diff --git a/dygraph/infer.py b/dygraph/infer.py index 2efe6743be588576fc65425d8f036d42cff7e9a6..ad8c485dd3281a40a0d59b8af3bffd98ca8b987f 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -111,7 +111,7 @@ def infer(model, data_dir=None, test_list=None, model_dir=None, for file in tqdm.tqdm(files): file = file.strip() im_file = osp.join(data_dir, file) - im, im_info = transforms(im_file) + im, im_info, _ = transforms(im_file) im = np.expand_dims(im, axis=0) im = to_variable(im) @@ -140,17 +140,8 @@ def infer(model, data_dir=None, test_list=None, model_dir=None, cv2.imwrite(pred_saved_path, pred_im) -def arrange_transform(transforms, mode='train'): - arrange_transform = T.ArrangeSegmenter - if type(transforms.transforms[-1]).__name__.startswith('Arrange'): - transforms.transforms[-1] = arrange_transform(mode=mode) - else: - transforms.transforms.append(arrange_transform(mode=mode)) - - def main(args): test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) - arrange_transform(test_transforms, mode='test') if args.model_name == 'UNet': model = models.UNet(num_classes=args.num_classes) diff --git a/dygraph/train.py b/dygraph/train.py index 324393e008269b0a8dbe4f2824986117760f5c6c..d979576c5cace22d3ce384958caf13d8fafea47d 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -143,7 +143,7 @@ def train(model, for epoch in range(num_epochs): for step, data in enumerate(data_generator()): images = np.array([d[0] for d in data]) - labels = np.array([d[1] for d in data]).astype('int64') + labels = np.array([d[2] for d in data]).astype('int64') images = to_variable(images) labels = to_variable(labels) loss = model(images, labels, mode='train') @@ -175,21 +175,12 @@ def train(model, model.train() -def arrange_transform(transforms, mode='train'): - arrange_transform = T.ArrangeSegmenter - if type(transforms.transforms[-1]).__name__.startswith('Arrange'): - transforms.transforms[-1] = arrange_transform(mode=mode) - else: - transforms.transforms.append(arrange_transform(mode=mode)) - - def main(args): # Creat dataset reader train_transforms = T.Compose( [T.Resize(args.input_size), T.RandomHorizontalFlip(), T.Normalize()]) - arrange_transform(train_transforms, mode='train') train_dataset = Dataset( data_dir=args.data_dir, file_list=args.train_list, @@ -200,7 +191,6 @@ def main(args): shuffle=True) if args.val_list is not None: eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) - arrange_transform(eval_transforms, mode='eval') eval_dataset = Dataset( data_dir=args.data_dir, file_list=args.val_list, diff --git a/dygraph/transforms/transforms.py b/dygraph/transforms/transforms.py index fa793126b4b0971e6505c900f0b22769113d177f..1ff7eb1b47c222902ed6a76a9d31afbe35690417 100644 --- a/dygraph/transforms/transforms.py +++ b/dygraph/transforms/transforms.py @@ -74,7 +74,10 @@ class Compose: im_info = outputs[1] if len(outputs) == 3: label = outputs[2] - return outputs + im = permute(im) + if len(outputs) == 3: + label = label[np.newaxis, :, :] + return (im, im_info, label) class RandomHorizontalFlip: @@ -873,42 +876,3 @@ class RandomDistort: return (im, im_info) else: return (im, im_info, label) - - -class ArrangeSegmenter: - """获取训练/验证/预测所需的信息。 - - Args: - mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。 - - Raises: - ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内 - """ - - def __init__(self, mode): - if mode not in ['train', 'eval', 'test', 'quant']: - raise ValueError( - "mode should be defined as one of ['train', 'eval', 'test', 'quant']!" - ) - self.mode = mode - - def __call__(self, im, im_info, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当mode为'train'或'eval'时,返回的tuple为(im, label),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当mode为'test'时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;当mode为 - 'quant'时,返回的tuple为(im,),为图像np.ndarray数据。 - """ - im = permute(im) - if self.mode == 'train' or self.mode == 'eval': - label = label[np.newaxis, :, :] - return (im, label) - elif self.mode == 'test': - return (im, im_info) - else: - return (im, ) diff --git a/dygraph/val.py b/dygraph/val.py index e760ec2c4d4dbb152eaef841d40b3f7293496aaa..e00b5ad720b97ae8bc68fa21c45c6c9fc78cf2da 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -21,7 +21,7 @@ from paddle.fluid.dygraph.base import to_variable import numpy as np import paddle.fluid as fluid -from datasets.dataset import Dataset +from datasets import Dataset import transforms as T import models import utils.logging as logging @@ -112,7 +112,7 @@ def evaluate(model, eval_dataset.num_samples, total_steps)) for step, data in enumerate(data_generator()): images = np.array([d[0] for d in data]) - labels = np.array([d[1] for d in data]).astype('int64') + labels = np.array([d[2] for d in data]).astype('int64') images = to_variable(images) pred, _ = model(images, labels, mode='eval') @@ -134,17 +134,8 @@ def evaluate(model, logging.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa())) -def arrange_transform(transforms, mode='train'): - arrange_transform = T.ArrangeSegmenter - if type(transforms.transforms[-1]).__name__.startswith('Arrange'): - transforms.transforms[-1] = arrange_transform(mode=mode) - else: - transforms.transforms.append(arrange_transform(mode=mode)) - - def main(args): eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) - arrange_transform(eval_transforms, mode='eval') eval_dataset = Dataset( data_dir=args.data_dir, file_list=args.val_list,