mobilenetv2_imagenet2012: run every training rank in its own device<rank_id>
directory with a per-rank log file, export RANK_TABLE_FILE next to
MINDSPORE_HCCL_CONFIG_PATH, and set fp16/fp32 flags on the network cells in
train.py instead of passing amp_level/keep_batchnorm_fp32 to Model.

Review notes:
  * launch.py keeps spawning workers with subprocess.Popen (argument list,
    explicit env and cwd) rather than an os.system shell string, so arguments
    are not re-parsed by a shell, the launcher's own interpreter is used, and
    a non-zero worker exit still raises CalledProcessError.
  * The post-image blob id on the launch.py "index" line predates this
    revision of the patch; the pre-image id is unchanged.
  * Indentation and blank context lines were restored from the hunk line
    counts; verify against the target tree before applying.

diff --git a/example/mobilenetv2_imagenet2012/launch.py b/example/mobilenetv2_imagenet2012/launch.py
index 5a8977c64b24decc3c407a1d943bd11579e97e70..bd28e20149ccabdfda6db8842bd2e60b93092a29 100644
--- a/example/mobilenetv2_imagenet2012/launch.py
+++ b/example/mobilenetv2_imagenet2012/launch.py
@@ -15,6 +15,7 @@
 """launch train script"""
 import os
 import sys
+import shutil
 import subprocess
 import json
 from argparse import ArgumentParser
@@ -129,15 +130,27 @@ def main():
     current_env["RANK_SIZE"] = str(args.nproc_per_node)
     if args.nproc_per_node > 1:
         current_env["MINDSPORE_HCCL_CONFIG_PATH"] = table_fn
+        current_env["RANK_TABLE_FILE"] = table_fn
     processes = []
     cmds = []
     for rank_id in range(0, args.nproc_per_node):
         current_env["RANK_ID"] = str(rank_id)
         current_env["DEVICE_ID"] = visible_devices[rank_id]
+        # Run each rank in its own freshly created device<rank_id> directory
+        # so that its outputs and log<rank_id>.log do not mix with other ranks.
+        device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id))
+        shutil.rmtree(device_dir, ignore_errors=True)
+        os.mkdir(device_dir)
         cmd = [sys.executable, "-u"]
         cmd.append(args.training_script)
         cmd.extend(args.training_script_args)
-        process = subprocess.Popen(cmd, env=current_env)
+        # Argument list + explicit env/cwd instead of a shell string: paths and
+        # script arguments are passed through verbatim, and the wait loop below
+        # still reports a failing rank.
+        log_path = os.path.join(device_dir, 'log{}.log'.format(rank_id))
+        with open(log_path, 'w') as log_file:
+            process = subprocess.Popen(cmd, env=current_env, cwd=device_dir,
+                                       stdout=log_file, stderr=subprocess.STDOUT)
         processes.append(process)
         cmds.append(cmd)
     for process, cmd in zip(processes, cmds):
diff --git a/example/mobilenetv2_imagenet2012/train.py b/example/mobilenetv2_imagenet2012/train.py
index 584e89fe431f26e8f4efe7991ba6611c602ae02f..d97eab5f04bac4e69610df8563ec58be70181cdc 100644
--- a/example/mobilenetv2_imagenet2012/train.py
+++ b/example/mobilenetv2_imagenet2012/train.py
@@ -23,6 +23,7 @@ from lr_generator import get_lr
 from config import config
 from mindspore import context
 from mindspore import Tensor
+from mindspore import nn
 from mindspore.model_zoo.mobilenet import mobilenet_v2
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.nn.optim.momentum import Momentum
@@ -110,16 +111,17 @@ class Monitor(Callback):
 
 if __name__ == '__main__':
     if run_distribute:
-        context.set_context(enable_hccl=True)
         context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           parameter_broadcast=True, mirror_mean=True)
         auto_parallel_context().set_all_reduce_fusion_split_indices([140])
         init()
-    else:
-        context.set_context(enable_hccl=False)
 
     epoch_size = config.epoch_size
     net = mobilenet_v2(num_classes=config.num_classes)
+    net.add_flags_recursive(fp16=True)
+    for _, cell in net.cells_and_names():
+        if isinstance(cell, nn.Dense):
+            cell.add_flags_recursive(fp32=True)
     loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
 
     print("train args: ", args_opt, "\ncfg: ", config,
@@ -135,8 +137,7 @@ if __name__ == '__main__':
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                    config.weight_decay, config.loss_scale)
 
-    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level='O0',
-                  keep_batchnorm_fp32=False)
+    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
 
     cb = None
     if rank_id == 0: