From eca6c9c62ca0601b2abc2c5b43d578d7365d7812 Mon Sep 17 00:00:00 2001 From: moran Date: Mon, 7 Sep 2020 14:21:56 +0800 Subject: [PATCH] Fix wizard template module to fit new operator API --- .../wizard/conf/templates/network/alexnet/train.py-tpl | 4 ++-- mindinsight/wizard/conf/templates/network/lenet/train.py-tpl | 2 +- .../wizard/conf/templates/network/resnet50/train.py-tpl | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mindinsight/wizard/conf/templates/network/alexnet/train.py-tpl b/mindinsight/wizard/conf/templates/network/alexnet/train.py-tpl index 0cc7b35..f6b3be3 100644 --- a/mindinsight/wizard/conf/templates/network/alexnet/train.py-tpl +++ b/mindinsight/wizard/conf/templates/network/alexnet/train.py-tpl @@ -60,14 +60,14 @@ if __name__ == "__main__": device_id = int(os.getenv('DEVICE_ID')) context.set_context(device_id=device_id, enable_auto_mixed_precision=True) context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) + gradients_mean=True) init() # GPU target else: init("nccl") context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) + gradients_mean=True) ckpt_save_dir = cfg.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" diff --git a/mindinsight/wizard/conf/templates/network/lenet/train.py-tpl b/mindinsight/wizard/conf/templates/network/lenet/train.py-tpl index f831109..a696d5c 100644 --- a/mindinsight/wizard/conf/templates/network/lenet/train.py-tpl +++ b/mindinsight/wizard/conf/templates/network/lenet/train.py-tpl @@ -60,7 +60,7 @@ if __name__ == "__main__": raise ValueError('Distribute running is no supported on %s' % args.device_target) context.reset_auto_parallel_context() context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) + gradients_mean=True) data_path = args.dataset_path do_train = True diff --git a/mindinsight/wizard/conf/templates/network/resnet50/train.py-tpl b/mindinsight/wizard/conf/templates/network/resnet50/train.py-tpl index 7ca6ed2..2ec7153 100644 --- a/mindinsight/wizard/conf/templates/network/resnet50/train.py-tpl +++ b/mindinsight/wizard/conf/templates/network/resnet50/train.py-tpl @@ -65,7 +65,7 @@ if __name__ == '__main__': device_id = int(os.getenv('DEVICE_ID')) context.set_context(device_id=device_id, enable_auto_mixed_precision=True) context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) + gradients_mean=True) auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) init() @@ -73,7 +73,7 @@ if __name__ == '__main__': else: init("nccl") context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) + gradients_mean=True) ckpt_save_dir = cfg.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" # create dataset -- GitLab