diff --git a/example/resnet50_quant/eval.py b/example/resnet50_quant/eval.py
index e7ec439d69a8c028021c3f5d5bc0139cb6f1fd0c..df0aec3c422746e1a1eb448015887bdc6746179e 100755
--- a/example/resnet50_quant/eval.py
+++ b/example/resnet50_quant/eval.py
@@ -17,7 +17,7 @@ eval.
 """
 import os
 import argparse
-from src.dataset import create_dataset
+from src.dataset import create_dataset_py
 from src.config import config
 from src.crossentropy import CrossEntropy
 from src.utils import _load_param_into_net
@@ -49,8 +49,8 @@ if __name__ == '__main__':
     loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
                         num_classes=config.class_num)
     if args_opt.do_eval:
-        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
-                                 target=target)
+        dataset = create_dataset_py(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
+                                    target=target)
         step_size = dataset.get_dataset_size()
 
     if args_opt.checkpoint_path:
diff --git a/example/resnet50_quant/src/dataset.py b/example/resnet50_quant/src/dataset.py
index fc767a90233f9599acb361c34b57b11a91e1ee02..a77e1e81cfd9634b285f6606fdb992b3a8e14fbd 100755
--- a/example/resnet50_quant/src/dataset.py
+++ b/example/resnet50_quant/src/dataset.py
@@ -20,6 +20,7 @@
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine as de
 import mindspore.dataset.transforms.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
+import mindspore.dataset.transforms.vision.py_transforms as P
 from mindspore.communication.management import init, get_rank, get_group_size
 def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
@@ -83,3 +84,63 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
     ds = ds.repeat(repeat_num)
 
     return ds
+
+def create_dataset_py(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
+    """
+    create a train or eval dataset
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        batch_size(int): the batch size of dataset. Default: 32
+        target(str): the device target. Default: Ascend
+
+    Returns:
+        dataset
+    """
+    if target == "Ascend":
+        device_num = int(os.getenv("RANK_SIZE"))
+        rank_id = int(os.getenv("RANK_ID"))
+    else:
+        init("nccl")
+        rank_id = get_rank()
+        device_num = get_group_size()
+
+    if do_train:
+        if device_num == 1:
+            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+        else:
+            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
+                                         num_shards=device_num, shard_id=rank_id)
+    else:
+        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
+
+    image_size = 224
+
+    # define map operations
+    decode_op = P.Decode()
+    resize_crop_op = P.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333))
+    horizontal_flip_op = P.RandomHorizontalFlip(prob=0.5)
+
+    resize_op = P.Resize(256)
+    center_crop = P.CenterCrop(image_size)
+    to_tensor = P.ToTensor()
+    normalize_op = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    # select transforms: random augmentation for training, deterministic resize/center-crop for eval
+    if do_train:
+        trans = [decode_op, resize_crop_op, horizontal_flip_op, to_tensor, normalize_op]
+    else:
+        trans = [decode_op, resize_op, center_crop, to_tensor, normalize_op]
+
+    compose = P.ComposeOp(trans)
+    ds = ds.map(input_columns="image", operations=compose(), num_parallel_workers=8, python_multiprocessing=True)
+
+    # apply batch operations
+    ds = ds.batch(batch_size, drop_remainder=True)
+
+    # apply dataset repeat operation
+    ds = ds.repeat(repeat_num)
+
+    return ds
diff --git a/example/resnet50_quant/train.py b/example/resnet50_quant/train.py
index 5a103af2b6a5b012660f563b6d75c71e37644749..77be1e9f088925780d871a3edc55c92dbf76f23c 100755
--- a/example/resnet50_quant/train.py
+++ b/example/resnet50_quant/train.py
@@ -27,7 +27,7 @@ from mindspore.communication.management import init
 import mindspore.nn as nn
 import mindspore.common.initializer as weight_init
 from models.resnet_quant import resnet50_quant
-from src.dataset import create_dataset
+from src.dataset import create_dataset_py
 from src.lr_generator import get_lr
 from src.config import config
 from src.crossentropy import CrossEntropy
@@ -85,8 +85,8 @@ if __name__ == '__main__':
     loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
                         num_classes=config.class_num)
     if args_opt.do_train:
-        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
-                                 repeat_num=epoch_size, batch_size=config.batch_size, target=target)
+        dataset = create_dataset_py(dataset_path=args_opt.dataset_path, do_train=True,
+                                    repeat_num=epoch_size, batch_size=config.batch_size, target=target)
         step_size = dataset.get_dataset_size()
 
         loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)