diff --git a/example/mobilenetv2_imagenet2012/launch.py b/example/mobilenetv2_imagenet2012/launch.py index 5a8977c64b24decc3c407a1d943bd11579e97e70..bd28e20149ccabdfda6db8842bd2e60b93092a29 100644 --- a/example/mobilenetv2_imagenet2012/launch.py +++ b/example/mobilenetv2_imagenet2012/launch.py @@ -15,7 +15,6 @@ """launch train script""" import os import sys -import subprocess import json from argparse import ArgumentParser @@ -125,25 +124,19 @@ def main(): sys.stdout.flush() # spawn the processes - current_env = os.environ.copy() - current_env["RANK_SIZE"] = str(args.nproc_per_node) - if args.nproc_per_node > 1: - current_env["MINDSPORE_HCCL_CONFIG_PATH"] = table_fn - processes = [] - cmds = [] for rank_id in range(0, args.nproc_per_node): - current_env["RANK_ID"] = str(rank_id) - current_env["DEVICE_ID"] = visible_devices[rank_id] - cmd = [sys.executable, "-u"] - cmd.append(args.training_script) - cmd.extend(args.training_script_args) - process = subprocess.Popen(cmd, env=current_env) - processes.append(process) - cmds.append(cmd) - for process, cmd in zip(processes, cmds): - process.wait() - if process.returncode != 0: - raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) + device_id = visible_devices[rank_id] + device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id)) + rank_process = 'export RANK_SIZE={} && export RANK_ID={} && export DEVICE_ID={} && '.format(args.nproc_per_node, + rank_id, device_id) + if args.nproc_per_node > 1: + rank_process += 'export MINDSPORE_HCCL_CONFIG_PATH={} && '.format(table_fn) + rank_process += 'export RANK_TABLE_FILE={} && '.format(table_fn) + rank_process += 'rm -rf {dir} && mkdir {dir} && cd {dir} && python {script} '.format(dir=device_dir, + script=args.training_script + ) + rank_process += ' '.join(args.training_script_args) + ' > log{}.log 2>&1 &'.format(rank_id) + os.system(rank_process) if __name__ == "__main__": diff --git a/example/mobilenetv2_imagenet2012/train.py b/example/mobilenetv2_imagenet2012/train.py index 584e89fe431f26e8f4efe7991ba6611c602ae02f..5152705749f311ccba2e3e029e6d4b56c79111fe 100644 --- a/example/mobilenetv2_imagenet2012/train.py +++ b/example/mobilenetv2_imagenet2012/train.py @@ -135,8 +135,7 @@ if __name__ == '__main__': opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, config.loss_scale) - model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level='O0', - keep_batchnorm_fp32=False) + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale) cb = None if rank_id == 0: