Commit d85d2f5a, authored by mindspore-ci-bot, committed by Gitee

!818 fix launch bug

Merge pull request !818 from wandongdong/master
@@ -15,7 +15,6 @@
 """launch train script"""
 import os
 import sys
-import subprocess
 import json
 from argparse import ArgumentParser
@@ -125,25 +124,19 @@ def main():
     sys.stdout.flush()
     # spawn the processes
-    current_env = os.environ.copy()
-    current_env["RANK_SIZE"] = str(args.nproc_per_node)
-    if args.nproc_per_node > 1:
-        current_env["MINDSPORE_HCCL_CONFIG_PATH"] = table_fn
-    processes = []
-    cmds = []
     for rank_id in range(0, args.nproc_per_node):
-        current_env["RANK_ID"] = str(rank_id)
-        current_env["DEVICE_ID"] = visible_devices[rank_id]
-        cmd = [sys.executable, "-u"]
-        cmd.append(args.training_script)
-        cmd.extend(args.training_script_args)
-        process = subprocess.Popen(cmd, env=current_env)
-        processes.append(process)
-        cmds.append(cmd)
-    for process, cmd in zip(processes, cmds):
-        process.wait()
-        if process.returncode != 0:
-            raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
+        device_id = visible_devices[rank_id]
+        device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id))
+        rank_process = 'export RANK_SIZE={} && export RANK_ID={} && export DEVICE_ID={} && '.format(args.nproc_per_node,
+                                                                                                    rank_id, device_id)
+        if args.nproc_per_node > 1:
+            rank_process += 'export MINDSPORE_HCCL_CONFIG_PATH={} && '.format(table_fn)
+            rank_process += 'export RANK_TABLE_FILE={} && '.format(table_fn)
+        rank_process += 'rm -rf {dir} && mkdir {dir} && cd {dir} && python {script} '.format(dir=device_dir,
+                                                                                             script=args.training_script)
+        rank_process += ' '.join(args.training_script_args) + ' > log{}.log 2>&1 &'.format(rank_id)
+        os.system(rank_process)

 if __name__ == "__main__":
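For reference, the new spawn loop builds one backgrounded shell command per device: it exports the rank and device environment variables, recreates a per-rank working directory, and redirects output to a per-rank log file. Below is a minimal sketch of the command constructed for one rank, assuming hypothetical values for the world size, device ID, rank table path, training script, and script arguments; it prints the command instead of passing it to os.system:

    # Sketch: reconstruct the shell command the new loop would run for rank 0.
    # All concrete values here are hypothetical placeholders; the real launcher
    # takes them from its parsed command-line arguments.
    import os

    nproc_per_node = 2                      # hypothetical world size
    rank_id = 0
    device_id = '0'                         # hypothetical entry from visible_devices
    table_fn = '/tmp/rank_table_2p.json'    # hypothetical HCCL rank table path
    training_script = 'train.py'            # hypothetical training script
    training_script_args = ['--epoch_size', '200']

    device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id))
    rank_process = 'export RANK_SIZE={} && export RANK_ID={} && export DEVICE_ID={} && '.format(
        nproc_per_node, rank_id, device_id)
    if nproc_per_node > 1:
        rank_process += 'export MINDSPORE_HCCL_CONFIG_PATH={} && '.format(table_fn)
        rank_process += 'export RANK_TABLE_FILE={} && '.format(table_fn)
    rank_process += 'rm -rf {dir} && mkdir {dir} && cd {dir} && python {script} '.format(
        dir=device_dir, script=training_script)
    rank_process += ' '.join(training_script_args) + ' > log{}.log 2>&1 &'.format(rank_id)

    print(rank_process)                     # inspect instead of os.system(rank_process)

One behavioral consequence: the trailing '&' backgrounds each training process, so os.system() returns immediately and the launcher no longer waits on its children or propagates a non-zero exit code the way the removed subprocess.Popen/wait loop did.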
@@ -23,6 +23,7 @@ from lr_generator import get_lr
 from config import config
 from mindspore import context
 from mindspore import Tensor
+from mindspore import nn
 from mindspore.model_zoo.mobilenet import mobilenet_v2
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.nn.optim.momentum import Momentum
@@ -110,16 +111,17 @@ class Monitor(Callback):
 if __name__ == '__main__':
     if run_distribute:
         context.set_context(enable_hccl=True)
         context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           parameter_broadcast=True, mirror_mean=True)
         auto_parallel_context().set_all_reduce_fusion_split_indices([140])
         init()
     else:
         context.set_context(enable_hccl=False)
     epoch_size = config.epoch_size
     net = mobilenet_v2(num_classes=config.num_classes)
+    net.add_flags_recursive(fp16=True)
+    for _, cell in net.cells_and_names():
+        if isinstance(cell, nn.Dense):
+            cell.add_flags_recursive(fp32=True)
     loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
     print("train args: ", args_opt, "\ncfg: ", config,
@@ -135,8 +137,7 @@ if __name__ == '__main__':
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                    config.weight_decay, config.loss_scale)
-    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level='O0',
-                  keep_batchnorm_fp32=False)
+    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
     cb = None
     if rank_id == 0:
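With precision now controlled per cell by the flags added earlier in train.py, the Model is constructed without amp_level='O0' and keep_batchnorm_fp32=False, falling back to Model's defaults; presumably those explicit arguments became redundant once the fp16/fp32 split moved into the network itself.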