提交 7d30d27e 编写于 作者: W wandongdong

fix performance bug and codedex

上级 0d7eb2a0
......@@ -40,37 +40,35 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
rank_id = int(os.getenv("RANK_ID"))
if rank_size == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=16, shuffle=True)
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=16, shuffle=True,
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
resize_height = config.image_height
resize_width = config.image_width
rescale = 1.0 / 255.0
shift = 0.0
buffer_size = 1000
# define map operations
decode_op = C.Decode()
resize_crop_op = C.RandomResizedCrop(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
resize_crop_op = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
resize_op = C.Resize((256, 256))
center_crop = C.CenterCrop(resize_width)
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
change_swap_op = C.HWC2CHW()
if do_train:
trans = [decode_op, resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op]
trans = [resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op]
else:
trans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, change_swap_op]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=trans)
ds = ds.map(input_columns="label", operations=type_cast_op)
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
# apply shuffle operations
ds = ds.shuffle(buffer_size=buffer_size)
......
......@@ -16,9 +16,9 @@
import os
import sys
import json
import subprocess
from argparse import ArgumentParser
def parse_args():
"""
parse args .
......@@ -124,6 +124,8 @@ def main():
sys.stdout.flush()
# spawn the processes
processes = []
cmds = []
for rank_id in range(0, args.nproc_per_node):
device_id = visible_devices[rank_id]
device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id))
......@@ -136,7 +138,13 @@ def main():
script=args.training_script
)
rank_process += ' '.join(args.training_script_args) + ' > log{}.log 2>&1 &'.format(rank_id)
os.system(rank_process)
process = subprocess.Popen(rank_process, shell=True)
processes.append(process)
cmds.append(rank_process)
for process, cmd in zip(processes, cmds):
process.wait()
if process.returncode != 0:
raise subprocess.CalledProcessError(returncode=process, cmd=cmd)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册