dataset.py 8.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
E
Eric 已提交
21
import mindspore.dataset.vision.c_transforms as C
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size

def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or evaluate cifar10 dataset for resnet50

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    # work out how many shards exist and which one this process reads
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    else:
        # any other target initialises the communication backend here and
        # takes the rank / group size from it
        init()
        rank_id = get_rank()
        device_num = get_group_size()

    if device_num == 1:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
                               num_shards=device_num, shard_id=rank_id)

    # define map operations
    trans = []
    if do_train:
        # random augmentation is applied to the training pipeline only
        trans += [
            C.RandomCrop((32, 32), (4, 4, 4, 4)),
            C.RandomHorizontalFlip(prob=0.5)
        ]

    # shared by train and eval: resize, rescale by 1/255, normalize, HWC -> CHW
    trans += [
        C.Resize((224, 224)),
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds


def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval imagenet2012 dataset for resnet50

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    # work out how many shards exist and which one this process reads
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    else:
        # any other target initialises the communication backend here and
        # takes the rank / group size from it
        init()
        rank_id = get_rank()
        device_num = get_group_size()

    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=rank_id)

    image_size = 224
    # mean / std are scaled by 255, i.e. given on the 0-255 pixel scale
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        # training: random crop + decode + resize, then random horizontal flip
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        # evaluation: decode, resize to 256, center crop to image_size
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds


R
RobinGrosman 已提交
141
def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval imagenet2012 dataset for resnet101

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Accepted for signature parity with the
            other create_dataset* functions; it is not read by this function.
            Default: Ascend

    Returns:
        dataset
    """
    # shard layout always comes from _get_rank_info(), whatever `target` is
    device_num, rank_id = _get_rank_info()

    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=rank_id)

    image_size = 224
    # mean / std are scaled by 255, i.e. given on the 0-255 pixel scale
    mean = [0.475 * 255, 0.451 * 255, 0.392 * 255]
    std = [0.275 * 255, 0.267 * 255, 0.278 * 255]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            # NOTE(review): the flip probability depends on the rank: it is 0.0
            # for rank 0 (so no flipping on a single device), 0.5 for rank 1,
            # and grows towards 1 for higher ranks. The sibling functions use a
            # fixed prob=0.5 - confirm this is intentional before changing it.
            C.RandomHorizontalFlip(rank_id / (rank_id + 1)),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        # evaluation: decode, resize to 256, center crop to image_size
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
192

Q
qujianwei 已提交
193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval imagenet2012 dataset for se-resnet50

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    # work out how many shards exist and which one this process reads
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    else:
        # without this branch device_num / rank_id were never assigned for a
        # non-Ascend target and the check below raised UnboundLocalError;
        # resolve them the same way create_dataset1 / create_dataset2 do
        init()
        rank_id = get_rank()
        device_num = get_group_size()

    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True,
                                   num_shards=device_num, shard_id=rank_id)

    image_size = 224
    # per-channel mean on the 0-255 pixel scale; std of 1.0 means the
    # normalize step only subtracts the mean
    mean = [123.68, 116.78, 103.94]
    std = [1.0, 1.0, 1.0]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        # evaluation resizes to 292 and center-crops to 256; image_size is
        # only used by the training branch
        trans = [
            C.Decode(),
            C.Resize(292),
            C.CenterCrop(256),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(input_columns="image", num_parallel_workers=12, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=12, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
246 247 248 249 250 251 252 253 254 255 256 257 258 259 260

def _get_rank_info():
    """
    get rank size and rank id

    Reads the RANK_SIZE environment variable (treated as 1 when unset).

    Returns:
        tuple, (rank_size, rank_id). (1, 0) when RANK_SIZE is not greater
        than 1, otherwise the values reported by get_group_size() and
        get_rank().
    """
    # single-device case: nothing to ask the communication layer
    if int(os.environ.get("RANK_SIZE", 1)) <= 1:
        return 1, 0

    # distributed case: the communication layer is the source of truth
    return get_group_size(), get_rank()