Commit 4535faa2, authored by mindspore-ci-bot, committed by Gitee

!5837 [AutoParallel]Rectification allreduce fusion api

Merge pull request !5837 from lichen/rectification_allreduce_fusion_api
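In short, !5837 makes the allreduce fusion split indices settable through the public `context.set_auto_parallel_context` interface (new `all_reduce_fusion_config` keyword) instead of the private `auto_parallel_context()` helper. A minimal before/after sketch; the `[85, 160]` indices are borrowed from the ResNet-50 script changed below, and the list is understood as parameter-index split points that group gradients into fused AllReduce operators:

    # Before: private module, separate call
    from mindspore.parallel._auto_parallel_context import auto_parallel_context
    auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])

    # After: one keyword argument on the public context API
    from mindspore import context
    context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])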
@@ -325,7 +325,8 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
-                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
+                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
+                 all_reduce_fusion_config=list)
 def set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -371,8 +372,9 @@ def set_auto_parallel_context(**kwargs):
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
-        enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
+        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in
             data parallel training in the benefit of time and memory saving.
+        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
......
@@ -462,7 +462,8 @@ _set_auto_parallel_context_func_map = {
     "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
     "full_batch": auto_parallel_context().set_full_batch,
-    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer}
+    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
+    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}

 _get_auto_parallel_context_func_map = {
@@ -477,13 +478,15 @@ _get_auto_parallel_context_func_map = {
     "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
     "full_batch": auto_parallel_context().get_full_batch,
-    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
+    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
+    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}

 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
-                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
+                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
+                 all_reduce_fusion_config=list)
 def _set_auto_parallel_context(**kwargs):
     """
@@ -526,6 +529,7 @@ def _set_auto_parallel_context(**kwargs):
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
         enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
+        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
......
@@ -47,8 +47,8 @@ def context_device_init(config):
         if config.run_distribute:
             context.set_auto_parallel_context(device_num=config.rank_size,
                                               parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              parameter_broadcast=True, gradients_mean=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+                                              parameter_broadcast=True, gradients_mean=True,
+                                              all_reduce_fusion_config=[140])
             init()
     else:
         raise ValueError("Only support CPU, GPU and Ascend.")
......
@@ -18,7 +18,6 @@ import argparse
 import ast
 from mindspore import context
 from mindspore import Tensor
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
@@ -78,9 +77,9 @@ if __name__ == '__main__':
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
         if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
-            auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
+            context.set_auto_parallel_context(all_reduce_fusion_config=[85, 150])
         else:
-            auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
+            context.set_auto_parallel_context(all_reduce_fusion_config=[180, 313])
         init()
     # GPU target
     else:
@@ -88,7 +87,7 @@ if __name__ == '__main__':
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
         if args_opt.net == "resnet50":
-            auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
+            context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
         ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"

     # create dataset
......
@@ -19,7 +19,6 @@ import argparse
 from mindspore import context
 from mindspore import Tensor
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
@@ -80,8 +79,7 @@ if __name__ == '__main__':
     init()
     context.set_auto_parallel_context(device_num=args_opt.device_num,
                                       parallel_mode=ParallelMode.DATA_PARALLEL,
-                                      gradients_mean=True)
-    auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
+                                      gradients_mean=True, all_reduce_fusion_config=[107, 160])

 # define network
 net = resnet50_quant(class_num=config.class_num)
......
@@ -20,7 +20,6 @@ import numpy as np
 from mindspore import context
 from mindspore import Tensor
 from mindspore.common import set_seed
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.context import ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
@@ -94,15 +93,13 @@ if __name__ == '__main__':
         device_id = int(os.getenv('DEVICE_ID'))
         context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
+                                          gradients_mean=True, all_reduce_fusion_config=[107])
         init()
     # GPU target
     else:
         init()
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
+                                          gradients_mean=True, all_reduce_fusion_config=[104])
         ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"

     # create dataset
......
@@ -87,17 +87,16 @@ def run_pretrain():
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                           device_num=device_num)
-        from mindspore.parallel._auto_parallel_context import auto_parallel_context
         if bert_net_cfg.num_hidden_layers == 12:
             if bert_net_cfg.use_relative_positions:
-                auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217])
+                context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
             else:
-                auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205])
+                context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
         elif bert_net_cfg.num_hidden_layers == 24:
             if bert_net_cfg.use_relative_positions:
-                auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421])
+                context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
             else:
-                auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
+                context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397])
     else:
         rank = 0
         device_num = 1
......
@@ -23,7 +23,6 @@ import numpy as np
 from mindspore import context, Tensor
 from mindspore.communication.management import init
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
 from mindspore.train.callback import Callback
@@ -137,8 +136,8 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
     os.environ['RANK_SIZE'] = str(device_num)
     if enable_hccl:
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
+                                          gradients_mean=True, parameter_broadcast=True,
+                                          all_reduce_fusion_config=[107, 160])
         init()

     # network
@@ -240,8 +239,8 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
     os.environ['RANK_SIZE'] = str(device_num)
     if enable_hccl:
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
+                                          gradients_mean=True, parameter_broadcast=True,
+                                          all_reduce_fusion_config=[107])
         init()

     # network
......
@@ -31,7 +31,6 @@ from mindspore import context
 from mindspore.communication.management import init
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
@@ -124,8 +123,8 @@ class CrossEntropyLoss(nn.Cell):
 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
-        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          all_reduce_fusion_config=[140])
         init()

     context.set_context(mode=context.GRAPH_MODE)
......
@@ -30,7 +30,6 @@ from mindspore import context
 from mindspore.communication.management import init
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.train.callback import Callback
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
@@ -154,8 +153,7 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
     os.environ['RANK_SIZE'] = str(device_num)
     if enable_hccl:
         context.set_auto_parallel_context(
-            device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+            device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, all_reduce_fusion_config=[140])
         init()
     context.set_context(mode=context.GRAPH_MODE)
     net = resnet50(batch_size, num_classes)
......
@@ -23,7 +23,6 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
 from mindspore.ops import operations as P
 from mindspore import context
-from mindspore.parallel._auto_parallel_context import auto_parallel_context

 class Net(nn.Cell):
     """Net definition"""
@@ -85,8 +84,8 @@ def test_lamb_compile():
 def test_lamb_split_fusion():
     """ test_Lamb_split_fusion """
-    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
-    auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8])
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True,
+                                      all_reduce_fusion_config=[2, 4, 6, 8])
     inputs = Tensor(np.ones([32, 128]).astype(np.float32))
     label = Tensor(np.zeros([32, 768]).astype(np.float32))
     net = Net()
......