md_dataset.py 3.6 KB
Newer Older
Y
yangyongjie 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
P
panbingao 已提交
16
import numpy as np
Y
yangyongjie 已提交
17 18 19 20
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C

U
modify  
unknown 已提交
21
from .ei_dataset import HwVocRawDataset
Y
yangyongjie 已提交
22 23 24
from .utils import custom_transforms as tr


Y
yangyongjie 已提交
25
class DataTransform:
    """
    Callable that preprocesses one (image, label) pair for DeepLabV3.

    Args:
        args: Parameter object; its ``base_size`` and ``crop_size`` attributes
            are read by the pipelines below.
        usage (str): "train" selects the training pipeline and "eval" the
            evaluation pipeline. Any other value makes a call return None.
    """

    def __init__(self, args, usage):
        self.usage = usage
        self.args = args

    def __call__(self, image, label):
        """Send the pair to the pipeline that matches ``self.usage``."""
        mode = self.usage
        result = None  # stays None when usage is neither "train" nor "eval"
        if mode == "train":
            result = self._train(image, label)
        elif mode == "eval":
            result = self._eval(image, label)
        return result

    @staticmethod
    def _as_float32(transformed):
        """Convert a transform output into a float32 numpy array."""
        return np.array(transformed).astype(np.float32)

    def _train(self, image, label):
        """
        Process training data.

        The pair is wrapped with ``Image.fromarray`` and passed through
        ``tr.RandomScaleCrop`` followed by ``tr.RandomHorizontalFlip``.

        Args:
            image: Image data in a form accepted by ``Image.fromarray``.
            label: Label data in a form accepted by ``Image.fromarray``.

        Returns:
            tuple, the transformed (image, label) as float32 numpy arrays.
        """
        pil_image = Image.fromarray(image)
        pil_label = Image.fromarray(label)

        # Both transforms receive image and label together, in this order.
        scale_crop = tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size)
        pil_image, pil_label = scale_crop(pil_image, pil_label)

        horizontal_flip = tr.RandomHorizontalFlip()
        pil_image, pil_label = horizontal_flip(pil_image, pil_label)

        return self._as_float32(pil_image), self._as_float32(pil_label)

    def _eval(self, image, label):
        """
        Process eval data.

        The pair is wrapped with ``Image.fromarray`` and passed through
        ``tr.FixScaleCrop``.

        Args:
            image: Image data in a form accepted by ``Image.fromarray``.
            label: Label data in a form accepted by ``Image.fromarray``.

        Returns:
            tuple, the transformed (image, label) as float32 numpy arrays.
        """
        pil_image = Image.fromarray(image)
        pil_label = Image.fromarray(label)

        fixed_crop = tr.FixScaleCrop(crop_size=self.args.crop_size)
        pil_image, pil_label = fixed_crop(pil_image, pil_label)

        return self._as_float32(pil_image), self._as_float32(pil_label)


U
unknown 已提交
81
def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train", shuffle=True):
    """
    Build the batched DeepLabV3 data pipeline.

    Args:
        args: Train parameters, forwarded unchanged to DataTransform.
        data_url (str): Dataset path handed to HwVocRawDataset.
        epoch_num (int): Repeat count applied after batching (default=1).
        batch_size (int): Batch size of dataset (default=1).
        usage (str): Whether is use to train or eval (default='train').
        shuffle (bool): Shuffle the samples; only honoured when usage is
            "train" (default=True).

    Returns:
        Dataset.
    """
    is_training = usage == "train"

    # Raw iterable source; its length is needed before it gets wrapped.
    raw_source = HwVocRawDataset(data_url, usage=usage)
    sample_count = len(raw_source)

    # Wrap with GeneratorDataset and attach the per-sample preprocessing.
    pipeline = de.GeneratorDataset(raw_source, ["image", "label"], sampler=None)
    pipeline.set_dataset_size(sample_count)
    pipeline = pipeline.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))

    # Only the image column goes through the HWC -> CHW op.
    pipeline = pipeline.map(input_columns="image", operations=C.HWC2CHW())

    # Worked example kept from the original notes:
    #   1464 samples / batch_size 8 = 183 batches
    #   3658 steps / 183 = 20 epochs
    # The original notes also say "epoch_num is num of steps" - TODO confirm
    # how callers compute the value they pass in.
    # NOTE(review): the shuffle buffer is hard-coded to that 1464 sample count
    # instead of using sample_count - confirm this is intended for other datasets.
    if is_training and shuffle:
        pipeline = pipeline.shuffle(1464)

    # Incomplete trailing batches are dropped for training only.
    pipeline = pipeline.batch(batch_size, drop_remainder=is_training)
    pipeline = pipeline.repeat(count=epoch_num)

    # NOTE(review): nothing visible in this file reads map_model; purpose
    # unknown - confirm against the consumers of the returned dataset.
    pipeline.map_model = 4

    return pipeline