提交 68c3c73f 编写于 作者: P panfengfeng

update mobilenetV2 dataset codes

上级 7ffcc606
......@@ -37,24 +37,31 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
if platform == "Ascend":
rank_size = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
if rank_size == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
if do_train:
if rank_size == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
elif platform == "GPU":
if do_train:
from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
else:
raise ValueError("Unsupport platform.")
resize_height = config.image_height
resize_width = config.image_width
buffer_size = 1000
if do_train:
buffer_size = 20480
# apply shuffle operations
ds = ds.shuffle(buffer_size=buffer_size)
# define map operations
decode_op = C.Decode()
......@@ -63,23 +70,23 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
resize_op = C.Resize((256, 256))
center_crop = C.CenterCrop(resize_width)
rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
random_color_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
change_swap_op = C.HWC2CHW()
transform_uniform = [horizontal_flip_op, random_color_op]
uni_aug = C.UniformAugment(operations=transform_uniform, num_ops=2)
if do_train:
trans = [resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op]
trans = [resize_crop_op, uni_aug, normalize_op, change_swap_op]
else:
trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=16)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
# apply shuffle operations
ds = ds.shuffle(buffer_size=buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册