diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc index 3ba055880cd3232a074a7a68f792fd9fa80c8ec9..eafb2bb59d2cc3c594f1d5b05877bbcd974f9087 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc @@ -100,7 +100,10 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic auto parallel_context = parallel::ParallelContext::GetInstance(); MS_EXCEPTION_IF_NULL(parallel_context); - const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); + std::vector split_indices; + if (!parallel_context->enable_parallel_optimizer()) { + split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); + } size_t segments = 0; if (split_indices.size() != 0) { diff --git a/mindspore/context.py b/mindspore/context.py index 0de6084caf520674a7dd4fee44657bb61e5fbfd2..551ec7b79a9ee05631c94601aeb39d0006ad7385 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -443,7 +443,7 @@ def _context(): @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, - strategy_ckpt_save_file=str, full_batch=bool) + strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) def set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -487,6 +487,9 @@ def set_auto_parallel_context(**kwargs): strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' full_batch (bool): Whether to load the whole batch on each device. Default: False. 
+ enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in + data parallel training for the benefit of time and memory saving. Default: False. + Raises: ValueError: If input key is not attribute in auto parallel context. @@ -532,6 +535,7 @@ def reset_auto_parallel_context(): - parameter_broadcast: False. - strategy_ckpt_load_file: "". - strategy_ckpt_save_file: "". + - enable_parallel_optimizer: False. """ _reset_auto_parallel_context() diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index cdf1565f349bc5cc2f22d80f85971b545b4c7c57..4b2ca1aee38fb633b641ebdb3e8b37e501808d92 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -28,8 +28,8 @@ from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from mindspore import log as logger from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train.parallel_utils import ParallelMode +from mindspore import context __all__ = ['Optimizer'] @@ -157,13 +157,12 @@ class Optimizer(Cell): self.param_length = len(self.parameters) self.map_ = C.Map() - use_parallel = auto_parallel_context().get_enable_parallel_optimizer() + use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer") self.use_parallel = use_parallel if use_parallel: if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]: raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name)) - if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL, - ParallelMode.AUTO_PARALLEL]: + if _get_parallel_mode() != ParallelMode.DATA_PARALLEL: raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format (_get_parallel_mode())) self.dev_num = _get_device_num() @@ -175,6 +174,7 @@ class 
Optimizer(Cell): self.param_names = [] for param in self.parameters: self.param_names.append(param.name) + else: self.optim_filter = (True,) * self.param_length diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py index 9354b42e55eff0fd0896c3439ef40cc0943edfdd..3d754977d453351242fb439f6d1fc045a2156773 100644 --- a/mindspore/nn/wrap/grad_reducer.py +++ b/mindspore/nn/wrap/grad_reducer.py @@ -13,107 +13,95 @@ # limitations under the License. # ============================================================================ """grad reducer cell for distributed training""" +from mindspore import context from mindspore.nn.cell import Cell from mindspore.communication.management import GlobalComm, get_group_size from mindspore.ops import functional as F, composite as C, operations as P -from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather +from mindspore.ops.operations.comm_ops import AllReduce, AllGather +from mindspore.parallel._auto_parallel_context import auto_parallel_context import mindspore.common.dtype as mstype reduce_opt = C.MultitypeFuncGraph("reduce_opt") -_all_reduce = AllReduce() -_all_gather = None - -def _init_optimizer_communication(): - global _all_reduce - global _all_gather - - _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) - _all_reduce.add_prim_attr('fusion', 1) - _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP) - - -@reduce_opt.register("Function", "Number", "Bool", "Tensor") -def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad): +def _init_allreduce_operators(length): + """ initialize allreduce communication operators""" + is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") + split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() + if is_parallel_optimizer and split_indices: + group = 1 + fusion = () + for i in range(length): + fusion = fusion + (group,) + if split_indices[group - 1] <= i + 1: + if group 
>= len(split_indices): + continue + group = group + 1 + index = tuple(range(1, length + 1)) + else: + fusion = (1,) * length + index = (0,) * length + opt_list = () + for i in range(length): + opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) + opt.add_prim_attr('fusion', fusion[i]) + opt.add_prim_attr('index', index[i]) + opt_list = opt_list + (opt,) + return opt_list + + +@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function") +def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce): """ - Apply mean and allreduce on gradient. Allreduce is a communication operation used for distributed deep learning. + Apply allreduce on gradient. Args: - mul (Primitive): Div operation. degree (int): The mean coefficient. + mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. + allgather (Primitive): The communication operator for sparse gradients. allreduce_filter (bool): When it is true, allreduce would apply. grad (Tensor): The gradient tensor before operation. + allreduce (Primitive): The communication operator for gradients. Returns: Tensor, the gradient tensor after operation. """ if allreduce_filter: - degree = F.scalar_cast(degree, F.dtype(grad)) - grad = _all_reduce(grad) - cast_op = P.Cast() - return mul(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad))) + grad = allreduce(grad) + if mean: + degree = F.scalar_cast(degree, F.dtype(grad)) + cast_op = P.Cast() + mul_op = P.Mul() + grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad))) + return grad return grad -@reduce_opt.register("Function", "Number", "Bool", "Tuple") -def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad): +@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tuple", "Function") +def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce): """ - Apply mean and allgather on gradient instead of allreduce for sparse feature. 
+ Apply allgather on gradient instead of allreduce for sparse feature. Allgather is a communication operation used for distributed deep learning. Args: - mul (Primitive): Div operation. degree (int): The mean coefficient. + mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. + allgather (Primitive): The communication operator for sparse gradients. allreduce_filter (bool): When it is true, allgather would apply. - grad (Tuple): The indices, gradient tensor and tensor_shape before operation. + grad (tuple): The indices, gradient tensor and tensor_shape before operation. + allreduce (Primitive): The communication operator for gradients. Returns: Tuple, include indices, the gradient tensor and tensor_shape after operation. """ if allreduce_filter: - indices = _all_gather(grad[0]) - degree = F.scalar_cast(degree, F.dtype(grad[1])) - dout = _all_gather(grad[1]) - cast_op = P.Cast() - dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) - grad = (indices, dout, grad[2]) - return grad - - -@reduce_opt.register("Bool", "Tensor") -def _tensors_allreduce(allreduce_filter, grad): - """ - Apply allreduce on gradient. - - Args: - allreduce_filter (bool): When it is true, allreduce would apply. - grad (Tensor): The gradient tensor before operation. - - Returns: - Tensor, the gradient tensor after operation. - """ - if allreduce_filter: - return _all_reduce(grad) - return grad - - -@reduce_opt.register("Bool", "Tuple") -def _tensors_allreduce_with_sparse(allreduce_filter, grad): - """ - Apply mean and allgather on gradient instead of allreduce for sparse feature. - Allgather is a communication operation used for distributed deep learning. - - Args: - allreduce_filter (bool): When it is true, allgather would apply. - grad (Tuple): The indices, gradient tensor and tensor_shape before operation. - - Returns: - Tuple, include indices, the gradient tensor and tensor_shape after operation. 
- """ - if allreduce_filter: - indices = _all_gather(grad[0]) - dout = _all_gather(grad[1]) + indices = allgather(grad[0]) + dout = allgather(grad[1]) + if mean: + degree = F.scalar_cast(degree, F.dtype(grad[1])) + cast_op = P.Cast() + mul_op = P.Mul() + dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) grad = (indices, dout, grad[2]) return grad @@ -259,7 +247,6 @@ class DistributedGradReducer(Cell): def __init__(self, parameters, mean=True, degree=None): super(DistributedGradReducer, self).__init__(auto_prefix=False) self.map_ = C.Map() - self.mul = P.Mul() if degree is None: self.degree = get_group_size() else: @@ -268,7 +255,8 @@ class DistributedGradReducer(Cell): self.degree = degree self.mean = mean self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) - _init_optimizer_communication() + self.opt_list = _init_allreduce_operators(len(parameters)) + self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) def construct(self, grads): """ @@ -284,11 +272,8 @@ class DistributedGradReducer(Cell): """ datatypes = self.map_(F.partial(_get_datatype), grads) grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) - - if self.mean: - new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads) - else: - new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads) + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), + self.allreduce_filter, grads, self.opt_list) new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) return new_grad diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 93fe23385575096d58bb243605692c93806cec60..3f6ce21cb9436d93acddb4125cfc2975aa490fca 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -513,7 +513,7 @@ def _set_auto_parallel_context(**kwargs): 
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' full_batch (bool): Whether to load the whole batch on each device. Default: False. - enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False. + enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False. Raises: ValueError: If input key is not attribute in auto parallel context. diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py index 6663e34871efb84cf99a4db6f8de85683c3a75ae..ee9291fb98a4887f4a6dcfaa68cd332052ffff2d 100644 --- a/tests/ut/python/parallel/test_parallel_optimizer.py +++ b/tests/ut/python/parallel/test_parallel_optimizer.py @@ -22,7 +22,6 @@ from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb from mindspore.ops import operations as P -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore import context @@ -54,8 +53,7 @@ class Net(nn.Cell): def test_AdamWeightDecayDynamicLR(): """ test_AdamWeightDecayDynamicLR """ - auto_parallel_context().set_enable_parallel_optimizer(True) - context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2) + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) inputs = Tensor(np.ones([32, 128]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32)) net = Net() @@ -70,8 +68,7 @@ def test_AdamWeightDecayDynamicLR(): def test_AdamWeightDecay(): """ test_AdamWeightDecayDynamicLR """ - auto_parallel_context().set_enable_parallel_optimizer(True) - context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2) + 
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) inputs = Tensor(np.ones([32, 128]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32)) net = Net() @@ -86,8 +83,7 @@ def test_AdamWeightDecay(): def test_lamb_compile(): """ test_Lamb_compile """ - auto_parallel_context().set_enable_parallel_optimizer(True) - context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2) + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) inputs = Tensor(np.ones([32, 128]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32)) net = Net() @@ -102,7 +98,7 @@ def test_lamb_compile(): def test_edge_case(): """ test_edge_case """ - auto_parallel_context().set_enable_parallel_optimizer(True) + context.set_auto_parallel_context(enable_parallel_optimizer=True) net = Net() with pytest.raises(RuntimeError): context.set_auto_parallel_context(parallel_mode="stand_alone") diff --git a/tests/ut/python/parallel/test_set_auto_parallel_context.py b/tests/ut/python/parallel/test_set_auto_parallel_context.py index c476b0cebc31ed0422ae282dddf4f7742d4fe4e8..19187cb262c10b76f8ed22304c1317888ef94c54 100644 --- a/tests/ut/python/parallel/test_set_auto_parallel_context.py +++ b/tests/ut/python/parallel/test_set_auto_parallel_context.py @@ -81,8 +81,8 @@ def test_set_auto_parallel_context(): with pytest.raises(ValueError): set_algo_parameters(tensor_slice_align_size=1025) - auto_parallel_context().set_enable_parallel_optimizer(True) - assert auto_parallel_context().get_enable_parallel_optimizer() is True + context.set_auto_parallel_context(enable_parallel_optimizer=True) + assert context.get_auto_parallel_context("enable_parallel_optimizer") assert not auto_parallel_context().get_all_reduce_fusion_split_indices()